Weight Loading & Safetensors Compatibility
This project is designed to be fully compatible with official PaliGemma weights provided by Google via the HuggingFace Hub. It uses the safetensors format for efficient, memory-safe loading of model parameters into the custom "from scratch" implementation.
The load_hf_model Utility
The primary entry point for initializing the model with pre-trained weights is the load_hf_model function located in utils.py. This function automates the process of discovery, configuration, and state-dict mapping.
API Definition
def load_hf_model(model_path: str, device: str) -> Tuple[PaliGemmaForConditionalGeneration, AutoTokenizer]
Parameters:
- model_path (str): The local directory path containing the HuggingFace model files (e.g., config.json, *.safetensors).
- device (str): The target device for the model tensors (e.g., "cuda", "mps", or "cpu").
Returns:
Tuple[PaliGemmaForConditionalGeneration, AutoTokenizer]: A tuple containing the initialized VLM model and its corresponding tokenizer.
Usage Example
To load a model from a local checkpoint directory:
from utils import load_hf_model
# Path to the directory where you downloaded the HF model
model_directory = "./paligemma-3b-pt-224"
device = "cuda"
model, tokenizer = load_hf_model(model_directory, device)
# The model is now ready for inference
model.eval()
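The example above hardcodes "cuda". A small guard can pick whichever backend is actually available at runtime; this is a generic sketch, not part of utils.py:

```python
import torch

# Prefer CUDA, then Apple's MPS backend, then fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(device)
```

The resulting string can be passed directly as the `device` argument of `load_hf_model`.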
Loading Workflow
The weight loading process follows these internal steps to bridge the gap between the HuggingFace format and the scratch implementation:
- Tokenizer Initialization: Loads the AutoTokenizer with padding_side="right".
- Shard Aggregation: Scans the directory for all *.safetensors files and aggregates their tensors into a unified dictionary.
- Config Mapping: Reads config.json to instantiate a PaliGemmaConfig, ensuring the model architecture (layers, heads, dimensions) matches the weights.
- State Dict Injection: Loads the tensors into the PaliGemmaForConditionalGeneration instance using load_state_dict.
- Weight Tying: Calls model.tie_weights() so that the input embeddings and the output linear layer share the same parameters, which is critical for the Gemma architecture.
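The shard-aggregation step is the least obvious part of the workflow: large checkpoints are split across several files, and their tensors must be merged into one flat dictionary before load_state_dict can run. The stdlib-only sketch below illustrates the merge logic, using small JSON files as stand-ins for safetensors shards (the file names and keys are illustrative, not the repository's exact code):

```python
import glob
import json
import os
import tempfile

def aggregate_shards(model_path: str, pattern: str = "*.json") -> dict:
    """Merge every shard file in model_path into one flat dict,
    mirroring how *.safetensors shards are combined into a state dict."""
    merged = {}
    for shard in sorted(glob.glob(os.path.join(model_path, pattern))):
        with open(shard) as f:
            merged.update(json.load(f))
    return merged

# Demo: two fake shards, each holding part of the "state dict".
tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "model-00001-of-00002.json"), "w") as f:
    json.dump({"vision_tower.embed.weight": [0.1]}, f)
with open(os.path.join(tmp, "model-00002-of-00002.json"), "w") as f:
    json.dump({"language_model.lm_head.weight": [0.2]}, f)

state = aggregate_shards(tmp)
print(sorted(state))  # → ['language_model.lm_head.weight', 'vision_tower.embed.weight']
```

In the real loader the per-shard read is done with safetensors.torch.load_file, but the dict-merge structure is the same.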
Supported Weight Formats
The implementation specifically looks for .safetensors files. If your model checkpoint is in the older .bin (PyTorch pickle) format, you must first convert it to safetensors with the HuggingFace transformers library:
# Quick conversion snippet if needed
from transformers import PaliGemmaForConditionalGeneration
# from_pretrained also accepts a local directory containing the .bin checkpoint
model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma-3b-pt-224")
model.save_pretrained("./paligemma-safetensors", safe_serialization=True)
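One reason safetensors is "memory-safe" is that the file format contains no pickled code: each file begins with an 8-byte little-endian header length followed by a plain JSON header, so you can sanity-check a converted checkpoint's tensor names with the stdlib alone. The snippet below hand-builds a minimal valid file to demonstrate the parse; the path and tensor name are illustrative:

```python
import json
import os
import struct
import tempfile

def read_safetensors_keys(path: str) -> list:
    """List tensor names from a .safetensors header (no pickle involved)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # 8-byte LE header size
        header = json.loads(f.read(header_len))         # header is plain JSON
    return [k for k in header if k != "__metadata__"]

# Build a minimal valid file: one float32 scalar tensor named "w".
header = json.dumps(
    {"w": {"dtype": "F32", "shape": [1], "data_offsets": [0, 4]}}
).encode()
path = os.path.join(tempfile.mkdtemp(), "demo.safetensors")
with open(path, "wb") as f:
    f.write(struct.pack("<Q", len(header)) + header + struct.pack("<f", 1.0))

print(read_safetensors_keys(path))  # → ['w']
```

Running the same function over the files produced by save_pretrained lets you confirm the conversion succeeded before loading the full model.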