PaliGemma Architecture
PaliGemma is a Vision-Language Model (VLM) composed of three primary architectural components: a SigLIP Vision Encoder, a Language Decoder (Gemma), and a Linear Projector that bridges the two.
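The architecture follows a "prefix-LM" approach: visual information is converted into a sequence of tokens that is prepended to the text tokens, so the language model "sees" the image before it processes the prompt. The image tokens and the prompt together form the prefix, which is attended to in full (without a causal mask); only the generated output is decoded autoregressively.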
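To make the prefix-LM idea concrete, here is a minimal sketch of the attention pattern it implies. It assumes the prefix is the image tokens plus the prompt and the suffix is the generated text. The function name `build_prefix_lm_mask` is hypothetical and not part of the repository.

```python
def build_prefix_lm_mask(prefix_len: int, total_len: int) -> list[list[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Positions [0, prefix_len) hold the image tokens plus the prompt and attend to
    each other bidirectionally; later positions (the generated suffix) are causal.
    """
    return [
        [j < prefix_len or j <= i for j in range(total_len)]
        for i in range(total_len)
    ]


# 2 image tokens + 1 prompt token form the prefix; 2 generated tokens follow.
for row in build_prefix_lm_mask(prefix_len=3, total_len=5):
    print("".join("x" if allowed else "." for allowed in row))
# xxx..
# xxx..
# xxx..
# xxxx.
# xxxxx
```

The top-left 3×3 block is fully filled because every prefix position can see every other prefix position. The last two rows form the usual lower-triangular causal pattern.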
Architectural Components
- Vision Encoder (SigLIP): Based on the SiglipVisionModel, it processes input images of size $224 \times 224$ (or another configured size) by splitting them into fixed-size patches. Each patch is embedded and passed through a stack of transformer layers, producing one visual feature vector per patch.
- Linear Projector: A single linear layer that maps each output hidden state of the SigLIP encoder, whose width generally differs from Gemma's, to the hidden dimension of the Gemma language model.
- Language Decoder (Gemma): A decoder-only transformer (GemmaForCausalLM) that receives the combined sequence of projected image embeddings and text embeddings and generates the response autoregressively.
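The projector's only job is to change the width of each patch vector. In PyTorch it would typically be a single `nn.Linear(vision_hidden_size, projection_dim)`. The plain-Python sketch below spells out the same y = Wx + b arithmetic on toy sizes. The function name and the sizes are illustrative assumptions.

```python
import random


def linear_project(features, weight, bias):
    """Apply y = W x + b to every patch feature vector.

    features: num_patches vectors, each of length d_vision
    weight:   d_text rows, each of length d_vision
    bias:     vector of length d_text
    Returns num_patches vectors, each of length d_text.
    """
    return [
        [sum(w * x_j for w, x_j in zip(row, x)) + b for row, b in zip(weight, bias)]
        for x in features
    ]


# Toy sizes, far smaller than the real model (hypothetical values).
num_patches, d_vision, d_text = 4, 3, 5
random.seed(0)
features = [[random.random() for _ in range(d_vision)] for _ in range(num_patches)]
weight = [[random.random() for _ in range(d_vision)] for _ in range(d_text)]
bias = [0.0] * d_text

out = linear_project(features, weight, bias)
print(len(out), len(out[0]))  # 4 5
```

The number of vectors (one per patch) is unchanged. Only their dimension moves from the vision width to the text width, so the decoder can treat them like ordinary token embeddings.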
Model Configuration
The architecture is controlled via the PaliGemmaConfig, which encapsulates the configurations for both the vision and text modules.
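```python
from modeling_gemma import PaliGemmaConfig

# Example: initializing a configuration.
# With image_size=224 and patch_size=14, the encoder emits (224 // 14) ** 2 = 256 patches,
# which matches num_image_tokens=256 used by the processor.
config = PaliGemmaConfig(
    vision_config={
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "image_size": 224,
        "patch_size": 14,
    },
    text_config={
        "hidden_size": 2048,
        "intermediate_size": 16384,
        "num_hidden_layers": 18,
        "num_attention_heads": 8,
        "vocab_size": 257152,
    },
    projection_dim=2048,  # must equal text_config["hidden_size"]
)
```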
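The processor example in the next section is built with `num_image_tokens=256`. That value has to equal the number of patch embeddings the vision encoder emits, which is (image_size // patch_size) ** 2. The small helper below makes the arithmetic explicit. It is hypothetical and not part of `modeling_gemma`.

```python
def num_image_tokens(image_size: int, patch_size: int) -> int:
    """Number of patch embeddings a ViT-style encoder emits for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2


vision_config = {"image_size": 224, "patch_size": 14}
print(num_image_tokens(vision_config["image_size"], vision_config["patch_size"]))  # 256
```

With `patch_size=16`, the same 224-pixel image would give 14 × 14 = 196 tokens, which would no longer match a processor configured for 256.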
Data Processing Flow
To use the model, inputs must be prepared with the PaliGemmaProcessor. The processor resizes and normalizes the image and merges the text tokens and the image placeholder tokens into a single input sequence.
- Image Path: Images are resized to the configured size, rescaled to [0, 1], and normalized with the "ImageNet standard" mean and std (0.5 per channel), which maps pixel values to [-1, 1]. The result is returned as pixel_values.
- Tokenization: The text prompt is tokenized and prepended with a fixed number of <image> tokens, equal to the number of patches the vision encoder produces.
- Format: The final input to the model's embedding layer looks like:
[<image> * N] + [BOS] + [Prompt] + [\n]
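Both preprocessing steps can be sketched in plain Python. This sketch assumes a rescale factor of 1/255, mean and std of 0.5, and `"<bos>"` as Gemma's BOS string. The function names are hypothetical. The real processor operates on image arrays and then runs the tokenizer on the resulting string.

```python
def normalize_pixels(pixels, mean=0.5, std=0.5):
    """Rescale 8-bit values to [0, 1], then standardize: (x / 255 - mean) / std."""
    return [(p / 255.0 - mean) / std for p in pixels]


def build_prompt(prompt: str, num_image_tokens: int,
                 image_token: str = "<image>", bos_token: str = "<bos>") -> str:
    """Lay out N image placeholders, then BOS, then the prompt, then a newline."""
    return f"{image_token * num_image_tokens}{bos_token}{prompt}\n"


print(normalize_pixels([0, 255]))  # [-1.0, 1.0]
print(repr(build_prompt("describe this image", num_image_tokens=2)))
# '<image><image><bos>describe this image\n'
```

In practice N is the encoder's patch count (256 for a 224-pixel image with 14-pixel patches). Each `<image>` placeholder is later swapped for one projected image embedding.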
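```python
# Usage via the Processor
from PIL import Image
from transformers import AutoTokenizer
from processing_paligemma import PaliGemmaProcessor

# Hypothetical local paths: point these at your PaliGemma weights and an input image.
tokenizer = AutoTokenizer.from_pretrained("path/to/paligemma", padding_side="right")
image_pil = Image.open("example.jpg").convert("RGB")

processor = PaliGemmaProcessor(tokenizer, num_image_tokens=256, image_size=224)
model_inputs = processor(text=["describe this image"], images=[image_pil])
# model_inputs contains:
# - input_ids: Tensor of token IDs including image placeholders
# - pixel_values: Preprocessed image tensor
# - attention_mask: Mask for the combined sequence
```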
The PaliGemma Class Interface
The main interface for the model is PaliGemmaForConditionalGeneration. It manages the forward pass through the vision encoder, the projection layer, and the language decoder.
forward Method
The forward method is the primary entry point for both training and inference.
Inputs:
- input_ids (torch.LongTensor): Indices of input sequence tokens in the vocabulary.
- pixel_values (torch.FloatTensor): Preprocessed image pixels.
- attention_mask (torch.Tensor): Mask to avoid performing attention on padding tokens.
- kv_cache (KVCache, optional): Cached key and value states for efficient autoregressive generation.
Outputs:
- A dictionary containing logits (prediction scores from the language modeling head) and, when a kv_cache is passed in, the updated kv_cache.
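```python
import torch
from modeling_gemma import PaliGemmaForConditionalGeneration

# Randomly initialized weights; load a pretrained checkpoint for meaningful output.
model = PaliGemmaForConditionalGeneration(config).eval()

# Standard forward pass (no KV cache); no_grad skips storing activations for backprop.
with torch.no_grad():
    outputs = model(
        input_ids=model_inputs["input_ids"],
        pixel_values=model_inputs["pixel_values"],
        attention_mask=model_inputs["attention_mask"],
    )
logits = outputs["logits"]  # shape: (batch, seq_len, vocab_size)
```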
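Inside the forward pass, the projected image features have to replace the `<image>` placeholder positions before the sequence reaches the decoder. The plain-Python sketch below shows that merge step. It assumes one image feature per placeholder, in order, and uses 256000 as the placeholder ID. The function name is hypothetical, and a tensor implementation would typically do the same selection with masks rather than a loop.

```python
def merge_image_features(input_ids, text_embeds, image_features, image_token_index=256000):
    """Build the decoder input: image features at placeholder positions, text embeddings elsewhere."""
    num_placeholders = sum(1 for t in input_ids if t == image_token_index)
    if num_placeholders != len(image_features):
        raise ValueError(
            f"{num_placeholders} image placeholders but {len(image_features)} image features"
        )
    image_iter = iter(image_features)
    return [
        next(image_iter) if token_id == image_token_index else embed
        for token_id, embed in zip(input_ids, text_embeds)
    ]


# Two image placeholders followed by two arbitrary text token IDs (toy 1-D embeddings).
ids = [256000, 256000, 5, 6]
text = [[0.0], [0.0], [5.0], [6.0]]
image = [[-1.0], [-2.0]]
print(merge_image_features(ids, text, image))  # [[-1.0], [-2.0], [5.0], [6.0]]
```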
Efficient Inference with KV-Caching
For text generation, the model uses a KVCache object to store the attention keys and values computed at earlier steps. This avoids recomputing them for the image tokens and the previously generated text tokens at every step.
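```python
import torch
from modeling_gemma import KVCache

kv_cache = KVCache()
input_ids = model_inputs["input_ids"]  # first step: the full prefix (image + prompt)
attention_mask = model_inputs["attention_mask"]
pixel_values = model_inputs["pixel_values"]
with torch.no_grad():
    for _ in range(100):  # maximum number of new tokens
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
        )
        # The model updates the cache internally and returns it
        kv_cache = outputs["kv_cache"]
        next_token_logits = outputs["logits"][:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # greedy; assumes batch size 1
        if next_token.item() == tokenizer.eos_token_id:
            break
        # Later steps feed only the new token; the mask grows by one position
        input_ids = next_token
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
```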
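The bookkeeping a KV cache performs can be shown with a toy class. This is a hypothetical stand-in, not the repository's `KVCache`. It keeps per-layer lists of key and value vectors, so the cached length equals the number of token positions processed so far. The prefill writes every prefix position at once, and each decode step appends exactly one.

```python
class ToyKVCache:
    """Hypothetical minimal KV cache: per-layer lists of key/value vectors."""

    def __init__(self):
        self.keys = []    # keys[layer_idx] -> list of key vectors, one per token position
        self.values = []  # values[layer_idx] -> list of value vectors, one per token position

    def num_items(self) -> int:
        """Number of cached token positions (read from layer 0)."""
        return len(self.keys[0]) if self.keys else 0

    def update(self, layer_idx, new_keys, new_values):
        """Append this step's keys/values for one layer and return that layer's full history."""
        while len(self.keys) <= layer_idx:
            self.keys.append([])
            self.values.append([])
        self.keys[layer_idx].extend(new_keys)
        self.values[layer_idx].extend(new_values)
        return self.keys[layer_idx], self.values[layer_idx]


cache = ToyKVCache()
num_layers = 2
# Prefill: 4 prefix positions (e.g. image + prompt tokens) are written to every layer.
for layer in range(num_layers):
    cache.update(layer, [[0.0]] * 4, [[0.0]] * 4)
print(cache.num_items())  # 4
# Decode step: only the new token's key/value is computed and appended.
for layer in range(num_layers):
    keys, values = cache.update(layer, [[1.0]], [[1.0]])
print(cache.num_items(), len(keys))  # 5 5
```

At each decode step, attention for the new token runs over the returned history. The earlier keys and values are reused instead of being recomputed.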
Key Constants
- image_token_index: The specific ID in the vocabulary representing the <image> placeholder (default: 256000).
- Special Tokens: PaliGemma uses specialized tokens for advanced tasks:
  - <loc0000> to <loc1023>: Used for object detection bounding boxes.
  - <seg000> to <seg127>: Used for object segmentation.
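The sketch below shows how detection output built from `<loc>` tokens can be turned into pixel boxes. It assumes the commonly documented PaliGemma format: four `<locNNNN>` tokens in the order y_min, x_min, y_max, x_max, each a bin in [0, 1023] of the normalized coordinate, followed by a label, with multiple detections separated by `;`. It also assumes bins convert to coordinates as value / 1024. The function name and the sample string are hypothetical.

```python
import re

# Four location tokens, then the label text up to the next detection or separator.
LOC_PATTERN = re.compile(r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^<;]*)")


def parse_detections(text: str, width: int, height: int):
    """Return [{'label': str, 'box': (x_min, y_min, x_max, y_max)}] in pixel units."""
    detections = []
    for m in LOC_PATTERN.finditer(text):
        y_min, x_min, y_max, x_max = (int(g) / 1024 for g in m.groups()[:4])
        detections.append({
            "label": m.group(5).strip(),
            "box": (x_min * width, y_min * height, x_max * width, y_max * height),
        })
    return detections


output = "<loc0256><loc0128><loc0768><loc0896> cat ; <loc0000><loc0000><loc0512><loc0512> dog"
for det in parse_detections(output, width=448, height=224):
    print(det)
# {'label': 'cat', 'box': (56.0, 56.0, 392.0, 168.0)}
# {'label': 'dog', 'box': (0.0, 0.0, 224.0, 112.0)}
```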