# KV Cache Mechanism
The KV (Key-Value) Cache is a performance optimization used during autoregressive token generation. Instead of recomputing the Key and Value vectors for all previous tokens on every iteration, the `KVCache` class stores these states and retrieves them as needed, reducing the per-step cost of generation from $O(n^2)$ to $O(n)$.
## Overview
In a Vision Language Model (VLM) like PaliGemma, the model first processes the image and the input prompt (the "prefill" phase). During the subsequent generation phase, the model produces one token at a time. The `KVCache` retains the key and value states of all previous tokens, so each new step only needs to compute the representations for the single newest token.
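The saving can be illustrated with a toy cost model (illustrative only, not code from this repository): without a cache, every decode step recomputes K/V vectors for the entire sequence, while with a cache each step computes them only for the newest token.

```python
def kv_vectors_computed(seq_len: int, new_tokens: int, use_cache: bool) -> int:
    """Toy cost model: count K/V vector computations for a single layer
    while generating new_tokens tokens after a prompt of seq_len tokens."""
    total = 0
    length = seq_len
    for _ in range(new_tokens):
        if use_cache:
            total += 1           # only the newest token's K/V
        else:
            total += length + 1  # recompute K/V for every token so far
        length += 1
    return total
```

For a 10-token prompt and 3 generated tokens, the cached path computes 3 K/V vectors versus 36 without caching, and the gap widens quadratically with sequence length.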
## Usage in Inference
To use the KV Cache, instantiate the `KVCache` class and pass it into the model's forward method. The model will automatically update the cache with new states and return the updated object.
```python
from modeling_gemma import KVCache, PaliGemmaForConditionalGeneration

# 1. Initialize the cache
kv_cache = KVCache()

# 2. In your generation loop, pass the cache to the model
for _ in range(max_tokens_to_generate):
    outputs = model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        kv_cache=kv_cache,  # Pass the cache here
    )

    # 3. Retrieve the updated cache and the next-token logits
    kv_cache = outputs["kv_cache"]
    next_token_logits = outputs["logits"][:, -1, :]

    # ... logic to select next_token ...

    # 4. Prepare for the next step: input_ids becomes just the single new token
    input_ids = next_token.unsqueeze(-1)
```
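The "logic to select `next_token`" is left open above; one common choice is greedy decoding, which picks the highest-probability token at each step. A minimal sketch (the helper name is illustrative, not part of the repository):

```python
import torch

def select_next_token_greedy(next_token_logits: torch.Tensor) -> torch.Tensor:
    """Greedy decoding: pick the highest-scoring token id per batch element.

    next_token_logits: [Batch_Size, Vocab_Size] (logits at the last position).
    Returns a [Batch_Size] tensor of token ids.
    """
    return torch.argmax(next_token_logits, dim=-1)
```

The returned ids can then be fed back as `next_token.unsqueeze(-1)` to form the `[Batch_Size, 1]` input for the next step. Temperature sampling or top-p sampling are common alternatives to greedy selection.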
## API Reference
### `KVCache()`
Initializes a new, empty Key-Value cache.
### `num_items() -> int`
Returns the current sequence length stored within the cache. This is useful for tracking how many tokens (including image tokens and text tokens) have been processed so far.
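For example, the cached length can be used to size the attention mask during decoding, since the single new query token attends to every cached token plus itself. A hedged sketch (the helper name and usage are illustrative assumptions, not the repository's API):

```python
import torch

def build_decode_attention_mask(cached_len: int, batch_size: int) -> torch.Tensor:
    """Mask for one decode step, assuming the new token may attend to all
    cached_len cached tokens plus itself (hypothetical helper)."""
    return torch.ones(batch_size, cached_len + 1, dtype=torch.long)
```

Here `cached_len` would come from `kv_cache.num_items()`; the actual model may construct its mask differently.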
### `update(key_states, value_states, layer_idx) -> Tuple[torch.Tensor, torch.Tensor]`
This is an internal method used by the model's attention layers to store new states and retrieve the full history.
- Inputs:
  - `key_states` (`torch.Tensor`): The new key tensors for the current token. Shape: `[Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]`.
  - `value_states` (`torch.Tensor`): The new value tensors for the current token. Shape: `[Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]`.
  - `layer_idx` (`int`): The index of the transformer layer performing the update.
- Returns:
  - A tuple containing the concatenated history of keys and values for that specific layer.
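To make the behavior of `update` and `num_items` concrete, here is a minimal sketch of how such a cache could be implemented. This is a hypothetical re-implementation for illustration; the actual `KVCache` in `modeling_gemma` may differ in details.

```python
from typing import List, Tuple

import torch

class KVCacheSketch:
    """Illustrative per-layer key/value cache (not the repository's class)."""

    def __init__(self) -> None:
        # One tensor per transformer layer, created lazily on first update.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        # Number of tokens cached so far (0 for an empty cache).
        if len(self.key_cache) == 0:
            return 0
        # Shape: [Batch_Size, Num_Heads_KV, Seq_Len, Head_Dim]
        return self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if len(self.key_cache) <= layer_idx:
            # First call for this layer (prefill): start its cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Decode step: append the new states along the sequence dimension.
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )
        return self.key_cache[layer_idx], self.value_cache[layer_idx]
```

Concatenating along `dim=-2` (the sequence dimension) is what lets each attention layer see the full key/value history while only computing states for the newest token.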
## Key Benefits
- Speed: Eliminates redundant calculations for previous tokens in the sequence.
- Memory Efficiency: While it consumes VRAM to store the tensors, it trades that memory for compute, preventing the quadratic growth of per-step computation as the generated sequence gets longer.
- Seamless Integration: The `PaliGemmaForConditionalGeneration` model is designed to detect the presence of the `kv_cache` and switch between full-sequence processing and incremental single-token processing automatically.