Introduction
Overview
vlm_from_scratch is a focused, pedagogical implementation of a Vision Language Model (VLM) based on the PaliGemma architecture. By combining a SigLIP vision encoder with a Gemma language model, this project provides a clean-room implementation for understanding and running multimodal inference.
This repository is designed for users who want a transparent, minimal-dependency approach to multimodal processing, supporting tasks such as image captioning, visual question answering (VQA), and object detection.
Key Features
- Modular Architecture: Independent implementations of SigLIP (Vision) and Gemma (Text).
- Hugging Face Compatibility: Utilities to load weights directly from the Hugging Face Hub.
- Efficient Inference: Supports KV-caching for faster token generation and Top-P sampling.
- Rich Tokenization: Pre-configured support for specialized tokens like bounding boxes (`<loc0000>`) and segmentation masks (`<seg000>`).
Quick Start
The following example demonstrates how to load a model and run inference on an image and text prompt.
```python
import torch
from PIL import Image

from utils import load_hf_model
from processing_paligemma import PaliGemmaProcessor
from modeling_gemma import PaliGemmaForConditionalGeneration

# 1. Load the model and tokenizer
model_path = "./path-to-paligemma-weights"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, tokenizer = load_hf_model(model_path, device)

# 2. Initialize the processor
# num_image_tokens and image_size are typically found in the model config
processor = PaliGemmaProcessor(
    tokenizer=tokenizer,
    num_image_tokens=model.config.vision_config.num_image_tokens,
    image_size=model.config.vision_config.image_size,
)

# 3. Prepare inputs
prompt = "Describe this image."
image_path = "sample_image.jpg"
model_inputs = processor(text=[prompt], images=[Image.open(image_path)])
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

# 4. Generate tokens
with torch.no_grad():
    outputs = model(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        attention_mask=model_inputs["attention_mask"],
    )
# Process logits to get the next token...
```
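The final comment above is where decoding begins. As a sketch of how the logits from a forward pass could be turned into a next token using the Top-P (nucleus) sampling mentioned under Key Features, here is a pure-Python version; the function name and signature are illustrative, and the repository's actual sampling code may differ:

```python
import math
import random

def top_p_sample(logits, p=0.9, temperature=1.0, rng=random.random):
    """Nucleus (top-p) sampling over a list of raw logits.

    Keeps the smallest set of tokens whose cumulative probability
    reaches `p`, then samples a token id from that set.
    """
    # Softmax with temperature (subtract the max for numerical stability)
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids sorted by probability, descending
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # The nucleus: smallest prefix whose cumulative probability >= p
    cumulative, nucleus = 0.0, []
    for i in order:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break

    # Renormalize within the nucleus and sample from it
    mass = sum(probs[i] for i in nucleus)
    r = rng() * mass
    for i in nucleus:
        r -= probs[i]
        if r <= 0:
            return i
    return nucleus[-1]
```

With a small `p`, the nucleus collapses to the single most likely token, so sampling becomes greedy; larger `p` values admit more candidates and more diverse output.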
Core API Reference
utils.load_hf_model
A utility function to instantiate the model architecture and populate it with pre-trained weights.
Parameters:
- `model_path` (str): Local path to the directory containing `.safetensors` and `config.json`.
- `device` (str): The target device for model loading (e.g., `"cuda"`, `"mps"`, `"cpu"`).
Returns:
- `Tuple[PaliGemmaForConditionalGeneration, AutoTokenizer]`: The initialized model and its corresponding tokenizer.
processing_paligemma.PaliGemmaProcessor
The primary interface for preparing data. It handles image resizing, normalization, and prepending the required number of `<image>` tokens to the text prompt.
Constructor:
- `tokenizer`: An instance of a Gemma tokenizer.
- `num_image_tokens` (int): The number of visual tokens the vision encoder produces.
- `image_size` (int): The required input resolution for the vision backbone.
Input (__call__):
- `text` (List[str]): A list of prompts.
- `images` (List[Image.Image]): A list of PIL images.
Output:
- `dict`: Contains `input_ids`, `pixel_values`, and `attention_mask` as PyTorch tensors.
modeling_gemma.PaliGemmaForConditionalGeneration
The main model class representing the full VLM.
Key Methods:
- `forward(...)`: Performs a forward pass. Accepts `input_ids`, `pixel_values`, `attention_mask`, and an optional `kv_cache`.
- `tie_weights()`: Shares the input embedding weights with the output projection, reducing parameter count and memory usage.
Technical Considerations
KV-Caching
For efficient autoregressive generation, the project includes a KVCache class. Instead of re-processing the entire sequence for every new token, the model stores previous Key and Value states.
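The mechanics can be illustrated with a minimal sketch. The real `KVCache` stores per-layer torch tensors (batch, heads, seq_len, head_dim); plain Python lists are used here only to keep the append-and-reuse pattern visible, and the method names are illustrative rather than the repository's exact API:

```python
class KVCacheSketch:
    """Illustrative key/value cache for autoregressive decoding.

    Each generation step appends its new Key/Value states and reads
    back the full history, so attention can cover every previous
    position without recomputing it.
    """
    def __init__(self, num_layers):
        self.keys = [[] for _ in range(num_layers)]
        self.values = [[] for _ in range(num_layers)]

    def update(self, layer, new_keys, new_values):
        # Append this step's states, then return the whole history
        # for the attention computation.
        self.keys[layer].extend(new_keys)
        self.values[layer].extend(new_values)
        return self.keys[layer], self.values[layer]

    def seq_len(self, layer=0):
        # Number of cached positions (grows by one per decoded token)
        return len(self.keys[layer])
```

After the prompt is processed once, each subsequent forward pass only needs the single newest token as input; everything earlier is read from the cache.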
Image Processing
Images are rescaled and then normalized before entering the vision encoder:
- Rescaling: Pixel values are first rescaled to the `[0, 1]` range.
- Mean: `[0.5, 0.5, 0.5]`
- Std: `[0.5, 0.5, 0.5]`
Note that these are not the ImageNet statistics: a mean and standard deviation of 0.5 maps pixel values into `[-1, 1]`.
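For a single 8-bit pixel value, the two steps compose as follows (a sketch; the actual preprocessing operates on whole arrays):

```python
def preprocess_pixel(value, mean=0.5, std=0.5):
    """Rescale an 8-bit pixel value to [0, 1], then normalize.

    With mean = std = 0.5, the output lands in [-1, 1]:
    0 -> -1.0, 255 -> 1.0.
    """
    rescaled = value / 255.0          # [0, 255] -> [0, 1]
    return (rescaled - mean) / std    # [0, 1]   -> [-1, 1]
```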
Special Tokens
The model is configured to handle:
- `<image>`: Placeholders for visual embeddings.
- `<loc0000>`-`<loc1023>`: Tokens representing coordinates for object detection.
- `<seg000>`-`<seg127>`: Tokens representing segments for mask generation.
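As an example of how the loc tokens are consumed, the following sketch parses a detection-style output string into pixel-space boxes. It assumes the PaliGemma convention of four loc tokens per box in (y_min, x_min, y_max, x_max) order; dividing by 1023 so the last bin maps to 1.0 is a simplifying assumption here, and the function itself is not part of the repository's API:

```python
import re

LOC_RE = re.compile(r"<loc(\d{4})>")

def parse_detection(output, width, height):
    """Turn an output like '<loc0000><loc0000><loc1023><loc1023> cat'
    into pixel-space (x_min, y_min, x_max, y_max) boxes plus the label.
    """
    bins = [int(b) for b in LOC_RE.findall(output)]
    label = LOC_RE.sub("", output).strip()
    boxes = []
    for i in range(0, len(bins) - 3, 4):
        # Each bin indexes one of 1024 positions over the normalized image
        y0, x0, y1, x1 = (b / 1023 for b in bins[i:i + 4])
        boxes.append((x0 * width, y0 * height, x1 * width, y1 * height))
    return boxes, label
```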