Autoregressive Generation
The vlm_from_scratch library implements autoregressive generation to produce text responses based on visual and textual prompts. This process involves predicting one token at a time, where each new token is appended to the sequence and used as context for the next prediction.
Overview of the Generation Loop
The generation logic resides primarily in inference.py within the test_inference function. The workflow follows these steps:
- Input Encoding: The `PaliGemmaProcessor` converts the image and text prompt into tensors.
- Prefilling (Initial Forward Pass): The model processes the image and the entire prompt at once to populate the KV Cache.
- Iterative Decoding: The model generates tokens one by one until a stop condition (maximum length or EOS token) is met.
- Decoding: The generated token IDs are converted back into human-readable text.
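Stripped of model details, the workflow above reduces to a simple loop. In this sketch, `step_fn` is a hypothetical stand-in for the model's forward pass plus token selection:

```python
def generate(step_fn, prompt_ids, eos_id, max_new_tokens):
    """Minimal autoregressive loop: step_fn maps the token sequence so far
    to the next token id (stand-in for forward pass + sampling)."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        next_id = step_fn(ids)
        ids.append(next_id)       # new token becomes context for the next step
        if next_id == eos_id:     # stop condition: EOS token
            break
    return ids
```

For example, `generate(lambda ids: ids[-1] + 1, [0], eos_id=3, max_new_tokens=10)` returns `[0, 1, 2, 3]`, stopping early at EOS.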
Preparing Inputs with PaliGemmaProcessor
Before generation begins, the PaliGemmaProcessor prepares the multimodal input. It resizes and normalizes the image, prepends the necessary <image> tokens to the prompt, and handles tokenization.
```python
from PIL import Image

from processing_paligemma import PaliGemmaProcessor

# Initialize the processor (256 image tokens for a 224x224 input)
processor = PaliGemmaProcessor(tokenizer, num_image_tokens=256, image_size=224)

# Process image and text (the processor expects PIL images, not file paths)
image = Image.open(image_file_path)
model_inputs = processor(text=["Identify the object in this image."], images=[image])
```
Input Parameters:
- `text` (List[str]): The text prompts.
- `images` (List[Image.Image]): The input images.
Output:
- `input_ids`: Tokenized text with prepended image tokens.
- `pixel_values`: Normalized image tensors.
- `attention_mask`: Mask for the self-attention mechanism.
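As a rough sanity check on the sequence length, note that PaliGemma's prompt format places the image placeholder tokens first, then a BOS token, then the text, then a newline separator. The helper below is purely illustrative (not part of the library):

```python
def expected_input_length(num_image_tokens: int, num_text_tokens: int) -> int:
    # <image> placeholders + BOS + text tokens + trailing "\n" separator
    return num_image_tokens + 1 + num_text_tokens + 1
```

With `num_image_tokens=256` and a 6-token prompt, `input_ids` and `attention_mask` would each have length 264.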
Using the KV Cache
To optimize inference, the model utilizes a Key-Value (KV) Cache. During the iterative generation process, instead of re-processing the entire sequence for every new token, the model stores the keys and values of previous tokens.
```python
from modeling_gemma import KVCache

kv_cache = KVCache()

# Inside the generation loop:
outputs = model(
    input_ids=input_ids,
    pixel_values=pixel_values,
    attention_mask=attention_mask,
    kv_cache=kv_cache,
)
kv_cache = outputs["kv_cache"]  # Updated cache for the next iteration
```
After the initial prefill pass, `input_ids` contains only the single most recently generated token, which significantly reduces computation per step.
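The caching mechanics can be illustrated with a toy single-layer cache (a sketch for intuition, not the library's `KVCache` class):

```python
import torch

class ToyKVCache:
    """Single-layer sketch: stores keys/values of shape (batch, seq, dim)."""

    def __init__(self):
        self.k = None
        self.v = None

    def update(self, k_new, v_new):
        # First call (prefill): store keys/values for the whole prompt.
        # Later calls (decode): append one position along the seq dimension.
        if self.k is None:
            self.k, self.v = k_new, v_new
        else:
            self.k = torch.cat([self.k, k_new], dim=1)
            self.v = torch.cat([self.v, v_new], dim=1)
        return self.k, self.v
```

Attention at each decode step then runs against the full cached sequence while the model itself only processes the newest token.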
Sampling Strategies
The generation loop supports both deterministic and probabilistic sampling via the test_inference function.
Greedy Search
By setting do_sample=False, the model always selects the token with the highest probability.
```python
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
```
Top-P (Nucleus) Sampling
By setting do_sample=True, you can use Top-P sampling and temperature to increase the diversity and creativity of the output.
| Parameter | Type | Description |
| :--- | :--- | :--- |
| temperature | float | Controls the "sharpness" of the probability distribution. Lower values make the model more confident. |
| top_p | float | The cumulative probability threshold; only the smallest set of tokens whose cumulative probability exceeds p is considered. |
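A common way to implement this temperature-plus-nucleus combination looks like the following (a sketch; the library's exact helper may differ):

```python
import torch

def sample_top_p(logits, temperature=0.8, top_p=0.9):
    # Scale logits by temperature, then convert to probabilities
    probs = torch.softmax(logits / temperature, dim=-1)
    # Sort descending and compute the cumulative probability
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens outside the nucleus; the shift keeps the first
    # token that pushes the cumulative mass past top_p
    mask = cumulative - sorted_probs > top_p
    sorted_probs[mask] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    # Sample from the renormalized nucleus, then map back to vocab ids
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted)
```

Lower `temperature` concentrates mass on the top tokens before the nucleus cut, so the two parameters interact: a low temperature with a high `top_p` can still behave almost greedily.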
Execution Example
The generation loop continues until the tokenizer.eos_token_id is encountered or the max_tokens_to_generate limit is reached.
```shell
# Example of running the inference script via CLI
python inference.py \
    --model_path "./paligemma-weight-folder" \
    --prompt "Describe the scene." \
    --image_file_path "example.jpg" \
    --max_tokens_to_generate 100 \
    --do_sample True \
    --temperature 0.8 \
    --top_p 0.9
```
Technical Workflow of the Loop
- Logits Extraction: The model outputs logits for the entire sequence; only the last position (`logits[:, -1, :]`) is used to predict the next token.
- Token Selection: The next token is chosen according to the selected sampling strategy.
- State Update:
  - The new token becomes the `input_ids` for the next step.
  - The `attention_mask` is concatenated with a `1` to account for the growing sequence.
  - The `kv_cache` is updated internally within the model's forward pass.
- Termination: The loop breaks if the model generates the End-of-Sequence (EOS) token.
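Putting these steps together, a condensed greedy version of the loop might look like this. It is a sketch assuming batch size 1, with a generic `model_step` callable standing in for the model's forward pass (the KV cache update is assumed to happen inside it):

```python
import torch

def decode_loop(model_step, input_ids, attention_mask, eos_id, max_new_tokens):
    """model_step(input_ids, attention_mask) returns logits (batch, seq, vocab)."""
    generated = []
    for _ in range(max_new_tokens):
        logits = model_step(input_ids, attention_mask)
        # Greedy selection from the last position; sampling would go here instead
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        generated.append(next_token)
        if next_token.item() == eos_id:
            break
        # Only the new token is fed in next step (the KV cache holds the rest)
        input_ids = next_token
        attention_mask = torch.cat(
            [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype)], dim=-1
        )
    return torch.cat(generated, dim=-1)
```

The returned token ids would then be passed to the tokenizer's decode method to produce the final text.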