# Gemma Language Decoder
The Gemma Language Decoder serves as the autoregressive backbone of the PaliGemma architecture. It consumes the projected visual tokens together with the text tokens and generates coherent language responses. The implementation follows the Gemma transformer architecture, featuring Rotary Positional Embeddings (RoPE), RMSNorm, and a KV-cache mechanism for efficient inference.
## Gemma Configuration
The architecture of the decoder is defined via the GemmaConfig class. This allows users to specify the depth, width, and attention mechanisms of the language model.
### GemmaConfig Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| vocab_size | int | Total size of the vocabulary. |
| hidden_size | int | Dimensionality of the decoder's hidden states. |
| intermediate_size | int | Dimensionality of the "intermediate" (feed-forward) layer. |
| num_hidden_layers | int | Number of hidden layers in the Transformer decoder. |
| num_attention_heads | int | Number of attention heads for each attention layer. |
| num_key_value_heads | int | Number of key-value heads for Grouped Query Attention. |
| head_dim | int | Dimensionality of each attention head (default: 256). |
| max_position_embeddings | int | The maximum sequence length that this model might ever be used with. |
| rms_norm_eps | float | The epsilon used by the RMSNorm layers. |
```python
from modeling_gemma import GemmaConfig

# Example configuration for a small Gemma instance
config = GemmaConfig(
    vocab_size=257152,
    hidden_size=2048,
    intermediate_size=16384,
    num_hidden_layers=18,
    num_attention_heads=8,
    num_key_value_heads=1,
    rms_norm_eps=1e-6
)
```
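The example above sets `num_attention_heads=8` and `num_key_value_heads=1`, meaning all eight query heads attend over a single shared key/value head (Grouped Query Attention). A common way to realize this is to tile the KV heads up to the query-head count before the attention product; the `repeat_kv` helper below is an illustrative sketch, not necessarily the function name used in `modeling_gemma`:

```python
import torch

def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile [batch, num_kv_heads, seq, head_dim] so each KV head serves
    num_attention_heads // num_key_value_heads query heads."""
    batch, num_kv_heads, seq, head_dim = hidden.shape
    if n_rep == 1:
        return hidden
    # expand() adds a repetition axis without copying; reshape materializes it.
    hidden = hidden[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq, head_dim)
    return hidden.reshape(batch, num_kv_heads * n_rep, seq, head_dim)
```

With the configuration above, `n_rep` would be `8 // 1 = 8`, so the single KV head is shared by every query head, shrinking the KV cache by 8x.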
## KV-Cache Management
To enable efficient token-by-token generation, the model utilizes a KVCache. This object stores the Key and Value states of previous tokens, preventing redundant computations during the autoregressive decoding process.
### API Reference
- `__init__()`: Initializes empty lists for key and value caches across all layers.
- `update(key_states, value_states, layer_idx)`:
  - Inputs: `key_states` (Tensor), `value_states` (Tensor), `layer_idx` (int).
  - Outputs: the concatenated (historical + new) key and value tensors for the specified layer.
- `num_items()`: Returns the current sequence length stored in the cache.
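A minimal sketch of a cache matching this interface might look like the following, assuming key/value states of shape `[batch, num_kv_heads, seq_len, head_dim]`; the actual implementation in `modeling_gemma` may differ in detail:

```python
import torch

class KVCache:
    """Sketch of a per-layer key/value cache for autoregressive decoding."""

    def __init__(self):
        self.key_cache = []    # one tensor per layer
        self.value_cache = []

    def num_items(self) -> int:
        # Sequence length currently stored (0 if the cache is empty).
        return 0 if not self.key_cache else self.key_cache[0].shape[-2]

    def update(self, key_states, value_states, layer_idx):
        if len(self.key_cache) <= layer_idx:
            # Prefill step for this layer: start the cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Append the new tokens' K/V along the sequence axis.
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```

Because `update` returns the full (historical + new) tensors, each attention layer can attend over the entire context while only computing K/V projections for the newest token.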
```python
from modeling_gemma import KVCache

# Initialize cache
kv_cache = KVCache()

# During inference, the model updates and returns the cache
outputs = model(
    input_ids=input_ids,
    pixel_values=pixel_values,
    attention_mask=attention_mask,
    kv_cache=kv_cache
)
new_kv_cache = outputs["kv_cache"]
```
## Architectural Components
While these components are used internally by the transformer blocks, they define the specific behavior of the Gemma decoder:
### GemmaRMSNorm
Unlike standard LayerNorm, Gemma uses Root Mean Square Layer Normalization. This implementation includes a unit-offset weight scaling (1.0 + weight) which is specific to the Gemma/Llama family of models to improve training stability.
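A sketch of this normalization, assuming the weight is zero-initialized so the `(1.0 + weight)` scale starts at the identity:

```python
import torch
import torch.nn as nn

class GemmaRMSNorm(nn.Module):
    """Sketch of RMSNorm with Gemma's unit-offset weight scaling."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-init: (1.0 + weight) == 1.0 at the start of training.
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square over the last dimension;
        # unlike LayerNorm, no mean is subtracted and no bias is added.
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * (1.0 + self.weight)
```

Skipping the mean subtraction makes RMSNorm cheaper than LayerNorm, and the unit offset keeps the learned scale centered around 1 rather than 0.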
### GemmaRotaryEmbedding (RoPE)
The model employs Rotary Positional Embeddings to encode positional information directly into the Query and Key tensors. This allows the model to handle long sequences effectively by capturing relative distances between tokens.
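A self-contained sketch of applying rotary embeddings to a query or key tensor. This version pairs adjacent channels (the interleaved convention); the actual Gemma implementation may pair the two halves of the head dimension instead, which is mathematically equivalent up to a channel permutation:

```python
import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate channel pairs of x ([..., seq, head_dim]) by position-dependent angles."""
    head_dim = x.shape[-1]
    # One frequency per channel pair, decaying geometrically with channel index.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]   # [seq, head_dim/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]                       # paired channels
    # 2D rotation of each (x1, x2) pair by its angle.
    rotated = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.flatten(-2)
```

Because a rotation's effect on the Q·K dot product depends only on the difference of the two positions, attention scores encode relative distances between tokens.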
## Inference Integration
The language decoder is typically accessed through the PaliGemmaForConditionalGeneration wrapper. In a generation loop, the decoder takes the last generated token and the existing KVCache to produce the next set of logits.
```python
import torch

# Standard generation step
with torch.no_grad():
    outputs = model(
        input_ids=input_ids,  # Shape: [Batch, 1] for subsequent steps
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        kv_cache=kv_cache
    )
next_token_logits = outputs["logits"][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
```
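The single step above extends naturally to a full greedy decoding loop. The `generate` helper below is a hypothetical sketch assuming the model call signature shown in this document; after the prefill step, only the newest token is fed back in while the KV cache carries the rest of the context:

```python
import torch

def generate(model, input_ids, pixel_values, attention_mask, kv_cache,
             max_new_tokens, eos_token_id):
    """Greedy decoding loop sketch (signature assumed, not from modeling_gemma)."""
    generated = []
    for _ in range(max_new_tokens):
        outputs = model(input_ids=input_ids, pixel_values=pixel_values,
                        attention_mask=attention_mask, kv_cache=kv_cache)
        kv_cache = outputs["kv_cache"]
        next_token = outputs["logits"][:, -1, :].argmax(dim=-1, keepdim=True)
        generated.append(next_token)
        if (next_token == eos_token_id).all():
            break
        input_ids = next_token        # [Batch, 1]: only the new token is processed
        pixel_values = None           # image features are only needed at prefill
        # Grow the mask by one position for the newly cached token.
        attention_mask = torch.cat(
            [attention_mask, torch.ones_like(next_token)], dim=-1)
    return torch.cat(generated, dim=-1)
```

Sampling strategies (temperature, top-k, top-p) would replace the `argmax` line; the caching logic is unchanged.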
## Input/Output Tensors
When interacting with the decoder:
- `input_ids` (torch.LongTensor): Indices of input sequence tokens in the vocabulary.
- `pixel_values` (torch.FloatTensor): Processed image tensors (passed through the Vision Encoder first).
- `attention_mask` (torch.Tensor): Mask to avoid performing attention on padding token indices.
- `logits` (torch.FloatTensor): Prediction scores of the language modeling head (one score per vocabulary token at each position).