Sampling & Top-P Logic
The vlm_from_scratch library provides flexible text generation strategies to control the balance between coherence and creativity. Users can toggle between deterministic Greedy Decoding and stochastic Nucleus (Top-P) Sampling via the inference interface.
Generation Strategies
During inference, the model produces logits for the next token. How the next token is selected depends on the configuration passed to the generation loop:
1. Greedy Decoding
By default (or when do_sample=False), the model selects the token with the highest probability. This is best for factual tasks where the most "likely" answer is desired.
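In code, greedy selection reduces to an argmax over the final logits. A minimal sketch (the logit values below are illustrative, not produced by the model):

```python
import torch

# Hypothetical logits over a 5-token vocabulary.
logits = torch.tensor([1.0, 3.5, 0.2, 2.1, -1.0])

# Greedy decoding: always take the single most likely token.
next_token = torch.argmax(logits, dim=-1)
print(next_token.item())  # → 1, the index of the largest logit
```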
2. Temperature Scaling
When sampling is enabled, the temperature parameter modifies the logit distribution before the softmax layer:
- Low Temperature (< 1.0): Makes the distribution "sharper," increasing the likelihood of high-probability tokens and making the model more confident/conservative.
- High Temperature (> 1.0): Flattens the distribution, giving more weight to less likely tokens, resulting in more "creative" or diverse output.
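The effect of temperature is easy to verify numerically. A small sketch comparing the same logits at three temperatures (values are illustrative):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])

base = torch.softmax(logits, dim=-1)         # temperature = 1.0 (unchanged)
sharp = torch.softmax(logits / 0.5, dim=-1)  # low temperature: sharper
flat = torch.softmax(logits / 2.0, dim=-1)   # high temperature: flatter

# The top token's probability mass grows as temperature drops.
print(sharp[0].item(), base[0].item(), flat[0].item())
```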
3. Nucleus (Top-P) Sampling
Nucleus sampling further refines the selection by filtering the vocabulary. It identifies the smallest set of tokens whose cumulative probability exceeds the threshold p. Tokens outside this "nucleus" are zeroed out, preventing the model from picking extremely low-probability (and often nonsensical) tokens.
Usage Example
You can configure sampling parameters directly through the main inference entry point or when calling the test_inference function.
```python
# Example: Creative sampling with high temperature and Top-P
from inference import main

main(
    model_path="./paligemma-weight-folder",
    prompt="Describe this image in detail",
    image_file_path="example.jpg",
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    max_tokens_to_generate=100,
)
```
API Reference
test_inference Parameters
The sampling logic is encapsulated within the test_inference function in inference.py.
| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| do_sample | bool | False | Whether or not to use sampling; use greedy decoding otherwise. |
| temperature | float | 0.8 | The value used to modulate the next-token probabilities. |
| top_p | float | 0.9 | The cumulative probability threshold for nucleus sampling. |
| max_tokens_to_generate | int | 100 | The maximum number of tokens to append to the prompt. |
Internal Logic: _sample_top_p
While this is an internal utility, it defines the core behavior of the sampling mechanism:
```python
import torch

def _sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """
    Args:
        probs (torch.Tensor): Softmaxed probabilities of shape (Batch, Vocab_Size).
        p (float): Cumulative probability threshold (0.0 < p <= 1.0).
    Returns:
        torch.Tensor: The sampled token index.
    """
    # Sort descending and compute the cumulative distribution.
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds p.
    sorted_probs = sorted_probs.masked_fill((cumulative - sorted_probs) > p, 0.0)
    # Renormalize the nucleus and sample from it.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    next_token = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_token)
```
Process Flow:
- Sort: Probabilities are sorted in descending order.
- Cumulative Sum: The cumulative distribution is calculated.
- Mask: Tokens that fall outside the cumulative threshold p are masked (set to 0.0).
- Renormalize: The remaining probabilities are rescaled to sum to 1.0.
- Multinomial: A token is sampled from the remaining filtered distribution.
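These steps can be traced end-to-end on a toy distribution (the probability values and variable names below are illustrative, not taken from inference.py):

```python
import torch

torch.manual_seed(0)
p = 0.9
probs = torch.tensor([[0.45, 0.25, 0.15, 0.10, 0.05]])  # shape (Batch=1, Vocab=5)

# 1. Sort probabilities in descending order, keeping original indices.
sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
# 2. Cumulative sum: [0.45, 0.70, 0.85, 0.95, 1.00]
cumulative = torch.cumsum(sorted_probs, dim=-1)
# 3. Mask tokens whose preceding cumulative mass already exceeds p.
mask = (cumulative - sorted_probs) > p
sorted_probs = sorted_probs.masked_fill(mask, 0.0)
# 4. Renormalize the surviving nucleus so it sums to 1.0.
sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
# 5. Sample, then map back to the original vocabulary index.
sampled = torch.multinomial(sorted_probs, num_samples=1)
next_token = torch.gather(sorted_idx, -1, sampled)
print(next_token.item())  # one of tokens 0-3; token 4 falls outside the nucleus
```

Here only the last token (preceding cumulative mass 0.95 > 0.9) is excluded, so the sampler draws from the first four tokens with rescaled probabilities.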