Rotary Embeddings & RMSNorm
Root Mean Square Layer Normalization (RMSNorm)
In the Gemma architecture, standard Layer Normalization is replaced by RMSNorm. Unlike LayerNorm, RMSNorm skips mean subtraction and rescales the input only by the root mean square of the activations, which is cheaper to compute while remaining stable in training.
GemmaRMSNorm
The GemmaRMSNorm class implements the normalization used throughout the text-processing layers of the model.
Usage
```python
from modeling_gemma import GemmaRMSNorm

# Initialize for a specific hidden dimension
norm = GemmaRMSNorm(dim=2048, eps=1e-6)

# Apply to hidden states
normalized_states = norm(hidden_states)
```
API Reference
| Parameter | Type | Description |
| :--- | :--- | :--- |
| dim | int | The dimension of the input tensor (usually hidden_size). |
| eps | float | A small value added to the denominator to prevent division by zero. Defaults to 1e-6. |
Input:
x (torch.Tensor): The input hidden states to be normalized.
Output:
torch.Tensor: A normalized tensor of the same shape and dtype as the input.
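To make the computation concrete, here is a minimal sketch of an RMSNorm layer. This is illustrative, not the exact library code; the `(1.0 + weight)` scaling and the zero-initialized weight follow Gemma's convention of learning an offset from an identity scale.

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Minimal RMSNorm sketch: rescale by the reciprocal root mean square."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialized learned scale; the layer starts as a pure normalization.
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute in float32 for numerical stability, then cast back.
        dtype = x.dtype
        x = x.float()
        # Divide by the RMS of the last dimension; eps guards against division by zero.
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Gemma-style scaling: multiply by (1 + weight) rather than weight alone.
        return (x * (1.0 + self.weight.float())).to(dtype)
```

With the weight at its zero initialization, the output is simply the input divided by its per-row RMS.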
Rotary Positional Embeddings (RoPE)
Gemma utilizes Rotary Positional Embeddings (RoPE) to incorporate positional information into the model. Unlike traditional additive positional embeddings, RoPE applies a rotation to the Query (Q) and Key (K) representations in the attention mechanism, allowing the model to better capture relative distances between tokens.
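The rotation described above can be sketched as follows. The helper names `rotate_half` and `apply_rotary` are illustrative, not the library's API; the convention of rotating paired halves of the head dimension is the one commonly used in RoPE implementations.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Pair each dimension i with dimension i + d/2, negating the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q: torch.Tensor, k: torch.Tensor,
                 cos: torch.Tensor, sin: torch.Tensor):
    # Rotate query/key vectors by the position-dependent angles in cos/sin.
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```

Because each pair of dimensions undergoes a 2D rotation, the norm of every query and key vector is preserved; only their relative angles encode position.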
GemmaRotaryEmbedding
The GemmaRotaryEmbedding class manages the calculation of the rotation frequencies based on the sequence length and embedding dimensions.
Usage
```python
from modeling_gemma import GemmaRotaryEmbedding

# Initialize with head dimension and base
rope = GemmaRotaryEmbedding(
    dim=256,
    max_position_embeddings=8192,
    base=10000.0,
)

# Generate frequency components for specific positions
# Position IDs typically range from 0 to current_seq_len - 1
cos, sin = rope(hidden_states, position_ids)
```
API Reference
| Parameter | Type | Description |
| :--- | :--- | :--- |
| dim | int | The dimension of the attention heads (head_dim). |
| max_position_embeddings | int | The maximum sequence length supported. |
| base | float | The base value for calculating the rotation frequencies (theta). Defaults to 10000.0. |
Input/Output
Forward Input:
x (torch.Tensor): The input states, used primarily to determine the current device and data type.
position_ids (torch.LongTensor): The indices representing the position of each token in the sequence.
Forward Output:
Tuple[torch.Tensor, torch.Tensor]: Two tensors, cos and sin, which are applied to the Query and Key states within the attention layer.
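How the cos/sin tables are derived from the base and head dimension can be sketched as below. The function name `rope_cos_sin` is hypothetical; it mirrors the standard RoPE recipe of one inverse frequency per dimension pair, not the exact library code.

```python
import torch

def rope_cos_sin(position_ids: torch.LongTensor, dim: int, base: float = 10000.0):
    """Sketch of building RoPE cos/sin tables for the given positions."""
    # One rotation rate per pair of dimensions: base^(-2i/dim) for i = 0..dim/2-1.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Outer product of positions and frequencies -> angle per (position, pair).
    angles = position_ids.float()[..., None] * inv_freq   # (..., seq_len, dim/2)
    # Duplicate so the table covers the full head dimension.
    emb = torch.cat((angles, angles), dim=-1)             # (..., seq_len, dim)
    return emb.cos(), emb.sin()
```

At position 0 every angle is zero, so cos is all ones and sin all zeros; the first token is left unrotated.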
Integration Note
While GemmaRotaryEmbedding is an internal component of the attention mechanism, its behavior is governed by the GemmaConfig. You can control its properties by adjusting rope_theta and head_dim in your model configuration.
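As a hedged illustration of that note: the field names rope_theta and head_dim come from the text above, but the remaining constructor arguments shown here are assumptions about GemmaConfig, not a documented signature.

```python
from modeling_gemma import GemmaConfig

# Hypothetical configuration fragment; only rope_theta and head_dim are
# taken from the note above, the other fields are illustrative.
config = GemmaConfig(
    hidden_size=2048,
    num_attention_heads=8,
    head_dim=256,        # dimension rotated by RoPE
    rope_theta=10000.0,  # base for the rotation frequencies
)
```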