LLaVA on a Budget: Multimodal AI with Limited Resources

Over the past few years I have mostly worked with large language models: training, fine-tuning, prompting, and so on, since this is what the market and users demand most. But I believe that LLMs, which work mainly on text, are just the beginning of GenAI. At some point, everyone will want physical AI: models that can see, hear, feel and reason in a more grounded, human way.
So let's get started with multimodality. In this notebook, I introduce LLaVA, an architecture capable of interpreting both images and text to generate multimodal responses.
In this tutorial, we will use lightweight components so that the notebook can run in a free-tier environment like Google Colab.
The components we are going to use are:
1️⃣ CLIP ViT-B/32 as the image encoder
2️⃣ TinyLlama-1.1B as the language model
3️⃣ A 2-layer MLP adapter to bridge the two (see the sketch after this list for how the pieces fit together)
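Before touching any real code, here is a rough, purely illustrative sketch of the data flow we are aiming for. The class, argument names, and hidden sizes below are assumptions for illustration only; the real implementation lives inside transformers' LlavaForConditionalGeneration, which we will use later.

import torch
from torch import nn

# Illustrative sketch of a LLaVA-style model: image features are projected into the
# LLM's embedding space by a small MLP adapter, then consumed together with the text.
class MultimodalSketch(nn.Module):
    def __init__(self, vision_encoder, language_model, vision_dim=768, text_dim=2048):
        super().__init__()
        self.vision_encoder = vision_encoder      # e.g. CLIP ViT-B/32, kept frozen
        self.language_model = language_model      # e.g. TinyLlama-1.1B, kept frozen
        self.adapter = nn.Sequential(             # the only piece we will train
            nn.Linear(vision_dim, text_dim),
            nn.GELU(),
            nn.Linear(text_dim, text_dim),
        )

    def forward(self, pixel_values, text_embeds):
        patch_features = self.vision_encoder(pixel_values)   # (batch, num_patches, vision_dim)
        image_embeds = self.adapter(patch_features)          # projected into the LLM embedding space
        # The projected image "tokens" are concatenated with the text embeddings and fed to the LLM.
        return self.language_model(torch.cat([image_embeds, text_embeds], dim=1))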
Setup
Before we dive into the code, let's set up our environment.
First, let's install the datasets library.
!pip install -U datasets
Now we need to import the required packages from Hugging Face and PyTorch. These imports provide pre-trained models and utilities for multimodal processing.
import json
from pathlib import Path
import requests
import safetensors
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
    AutoConfig,
    AutoTokenizer,
    LlamaTokenizer,
    LlavaConfig,
    LlavaForConditionalGeneration,
    LlavaProcessor,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from transformers.models.clip.modeling_clip import CLIPVisionModel
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
Download pre-trained model components
Our LLaVA model will be built from these two pre-trained backbones. We use hf_hub_download to retrieve their weights from the Hugging Face Hub:
vision_backbone_name = "openai/clip-vit-base-patch32"
text_backbone_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
_ = hf_hub_download(
    vision_backbone_name, filename="pytorch_model.bin", local_dir="/content"
)
_ = hf_hub_download(
    text_backbone_name, filename="model.safetensors", local_dir="/content"
)
Model
Instantiate a new LLaVA model
Now, let's instantiate a new LLaVA model. As mentioned above, it is composed of the two parts we just downloaded: a vision encoder and a text decoder.
vision_config = AutoConfig.from_pretrained(vision_backbone_name).vision_config
text_config = AutoConfig.from_pretrained(text_backbone_name)
We specify the backbone models in the LLaVA configuration, then build the full model with LlavaForConditionalGeneration(llava_config).
llava_config = LlavaConfig(vision_config=vision_config, text_config=text_config)
model = LlavaForConditionalGeneration(llava_config).cuda()
model
Perform some model surgery

Earlier we said we would build our LLaVA model from a pre-trained image encoder and a pre-trained LLM. Let's do exactly that!
The original LLaVA model is initialized from a pre-trained CLIP ViT-L/14 and Vicuna v1.5 7B. To keep things manageable with the resources offered by Google Colab's free plan, we will use CLIP ViT-B/32 and TinyLlama 1.1B instead.
The only component we will train is the 2-layer MLP adapter between them.
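As a quick sanity check, you can inspect the adapter that transformers created when we instantiated the model. It should be a small two-layer projection module; the attribute and class names below reflect recent transformers versions and may differ slightly in yours.

# Inspect the multimodal projector: in recent transformers versions this prints a
# LlavaMultiModalProjector made of linear_1 -> activation -> linear_2.
print(model.multi_modal_projector)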
To use the CLIP and TinyLlama models, we need to load their pre-trained weights. These weights can come in different formats, such as .safetensors or .bin. The load_weights function handles this for us: it checks the file type and calls the appropriate loading function.
def load_weights(path_to_weights: str):
    if path_to_weights.endswith(".safetensors"):
        return load_safetensors_weights(path_to_weights)
    elif path_to_weights.endswith(".bin"):
        return load_bin_weights(path_to_weights)
    else:
        raise ValueError(f"Unsupported weights file: {path_to_weights}")

def load_bin_weights(path_to_weights: str):
    return torch.load(path_to_weights, weights_only=True)

def load_safetensors_weights(path_to_weights: str):
    return safetensors.torch.load_file(path_to_weights)
vision_backbone_state_dict = load_weights("/content/pytorch_model.bin")
text_backbone_state_dict = load_weights("/content/model.safetensors")
Inject the vision backbone weights into the model 💉
The following lines load the weights into the vision part of the model. We set strict=False to stay flexible: it lets us skip any weights in the checkpoint that don't match the structure the model expects (here, mostly the CLIP text-encoder weights, which we don't need).
incompatible_keys = model.vision_tower.load_state_dict(
    vision_backbone_state_dict, strict=False
)
assert len(incompatible_keys.missing_keys) == 0, (
    f"Missing keys in state dict: {incompatible_keys.missing_keys}"
)
incompatible_keys.unexpected_keys
Inject the text backbone weights into the model 💉
The same logic as before, applied to the text model.
incompatible_keys = model.language_model.load_state_dict(
    text_backbone_state_dict, strict=True
)
Freeze the pre-trained components ❄️
We now freeze the vision and text backbones because we don't want to update their weights during training.
We will only train the small adapter (the MLP connecting vision and language), which makes training lighter and faster.
_ = model.vision_tower.requires_grad_(False)
_ = model.language_model.requires_grad_(False)
# Then we define a helper function to count model parameters
def count_parameters(model, trainable_only=False):
    return sum(
        p.numel()
        for p in model.parameters()
        if not trainable_only or p.requires_grad
    )
print(f"Total parameters: {count_parameters(model)}")
print(f"Trainable parameters: {count_parameters(model, trainable_only=True)}")
Processor
Before feeding any text to our model, we need to convert words into numbers. That is what the tokenizer is for.
tokenizer = LlamaTokenizer.from_pretrained(
    text_backbone_name, additional_special_tokens=["<image>", "<pad>"]
)
tokenizer.pad_token_id = 32001
Here is the chat format we will use with our LLaVA model.
The first part is the system prompt, which contains general guidelines on how the model should respond to the user.
The second part is a Jinja template (basically code) that determines how a conversation is rendered from some structured input (see the example below).
LLAVA_CHAT_TEMPLATE = (
    "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. "
    "{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"
)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
sample_messages = [
    {
        "content": [
            {
                "index": 0,
                "text": None,
                "type": "image",
            },
            {
                "index": None,
                "text": "\nWhat potential activities might be popular at this location?",
                "type": "text",
            },
        ],
        "role": "user",
    },
    {
        "content": [
            {
                "index": None,
                "text": (
                    "At this location, with a sandy path leading to the ocean where multiple boats, including "
                    "sailboats, are moored, popular activities might include boating, sailing, swimming, and "
                    "beachcombing. Additionally, the sandy path and shoreline provide an ideal setting for leisurely "
                    "strolls and picnics, while the ocean view offers a serene environment for relaxation and "
                    "photography. Depending on the specific area and available facilities, other water sports such as "
                    "kayaking, paddleboarding, and snorkeling could also be prevalent."
                ),
                "type": "text",
            },
        ],
        "role": "assistant",
    },
]
Let’s apply the chat template to our sample.
tokenizer.apply_chat_template(
    sample_messages, tokenize=False, add_generation_prompt=False
)
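For reference, the rendered string should look roughly like this (abridged; the exact spacing follows the template above):

A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>
What potential activities might be popular at this location? ASSISTANT: At this location, with a sandy path leading to the ocean [...] could also be prevalent.</s>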
At this point, we have set up the tokenizer and downloaded the vision model. We bring them together in a LlavaProcessor.
processor = LlavaProcessor(
    image_processor=CLIPImageProcessor.from_pretrained(vision_backbone_name),
    tokenizer=tokenizer,
    patch_size=model.config.vision_config.patch_size,
)
processor.chat_template = LLAVA_CHAT_TEMPLATE
Since we added the special tokens <image> and <pad> to our tokenizer, the model needs to resize its token embeddings so it knows about them too.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
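As an optional check, you can verify that the id assigned to the <image> token matches the image token index the model expects. This assumes the tokens were added in the order above; depending on your transformers version the config attribute may be called image_token_index or image_token_id.

# <image> should map to the first id added after TinyLlama's original vocabulary,
# which is also the default image token index in LlavaConfig.
print(tokenizer.convert_tokens_to_ids("<image>"), model.config.image_token_index)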
Dataset
Let's download the dataset we are going to use from the Hugging Face Hub.
A publicly available dataset of image-text conversation samples can be found here.
train_dataset = load_dataset(
    "HuggingFaceH4/llava-instruct-mix-vsft", split="train", streaming=True
)
What do our training examples look like?
next(iter(train_dataset))
How do we build a batch of examples?
The following function takes the raw image-text examples and converts them into model-ready inputs. It formats the messages with the chat template, runs them through the LlavaProcessor we defined earlier, and creates the training labels, making sure padding tokens are ignored in the loss.
def get_data_collator(processor, ignore_index):
    def collate_examples(examples):
        # Extract texts and images from the raw examples
        texts = []
        images = []
        for example in examples:
            messages = example["messages"]
            text = processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            images.append(example["images"][0])

        # Process the inputs (tokenize text and transform images)
        batch = processor(texts, images, return_tensors="pt", padding=True)

        # Create labels
        labels = batch["input_ids"].clone()
        if processor.tokenizer.pad_token_id is not None:
            labels[labels == processor.tokenizer.pad_token_id] = ignore_index
        batch["labels"] = labels

        return batch

    return collate_examples

# NOTE: this does a bit more than a collate function should...
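Before launching training, it can be reassuring to run the collator on a couple of streamed examples. This is a quick check of my own; the exact tensor shapes depend on the processor settings and the sampled images.

# Build the collator and try it on two streamed examples.
data_iter = iter(train_dataset)
collate = get_data_collator(processor, ignore_index=-100)  # -100 is the usual ignore index for the loss
sample_batch = collate([next(data_iter), next(data_iter)])
print(sample_batch["input_ids"].shape, sample_batch["pixel_values"].shape, sample_batch["labels"].shape)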
Train
Finally, let's define the training arguments, including batch size, learning rate, total number of steps, and the use of mixed precision (fp16) for speed. We also avoid saving checkpoints to keep things light. Then we wrap everything in a Seq2SeqTrainer, passing it the model, the dataset, and our custom collator for image-text inputs.
args = Seq2SeqTrainingArguments(
    output_dir="/content/training_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    max_steps=350,
    lr_scheduler_type="cosine_with_min_lr",
    lr_scheduler_kwargs={"min_lr": 2e-5},
    warmup_ratio=0.05,
    logging_strategy="steps",
    logging_steps=5,
    fp16=True,
    remove_unused_columns=False,  # Important!
    optim="adamw_torch",
    report_to="none",
    save_strategy="no",  # let's not save the checkpoint to disk, otherwise it'll take 5 mins
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=get_data_collator(
        processor, ignore_index=model.config.ignore_index,
    ),
    train_dataset=train_dataset,
)
trainer.train()
Inference
Keep in mind that, for inference to work as expected, you should use heavier models and train them for longer.
We will use this image for inference:

conversation = [
    {
        "content": [
            {
                "type": "image",
            },
            {
                "text": "\nWhat is represented in the image?",
                "type": "text",
            },
        ],
        "role": "user",
    }
]
In this cell, we load the image from a URL and format the conversation using the chat template. The processor turns both into tensors. We then move the inputs to the model's device and generate a response, letting the model describe the image according to the user's prompt.
image_url = "
inputs_for_generation = processor(
    images=Image.open(requests.get(image_url, stream=True).raw),
    text=processor.apply_chat_template(conversation, add_generation_prompt=True),
    return_tensors="pt",
)
inputs_for_generation = inputs_for_generation.to(device=model.device)
output = trainer.model.generate(
    **inputs_for_generation, max_new_tokens=200, do_sample=False
)
print(processor.decode(output[0], skip_special_tokens=True))
Expand and improve
- Use a larger image encoder (e.g. CLIP ViT-L) and a larger LLM (e.g. Llama 3.1 8B)
- Train for longer. It takes some time for the model to figure out how to follow instructions in the presence of image features
- Follow the multi-stage training procedure adopted by the original LLaVA:
  - Stage 1: Pre-training for feature alignment -> train the model on single-turn instruction data, where it is asked to briefly describe the picture. The image encoder and the LLM are frozen
  - Stage 2: Fine-tuning end-to-end -> train the model on multi-turn instruction data. Only the image encoder is frozen (see the sketch right after this list)
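Here is a minimal sketch, under my own assumptions rather than the original LLaVA recipe, of how that second stage could look with the model we built: keep the vision tower frozen, unfreeze the language model alongside the adapter, and continue training with a lower learning rate.

# Stage 2 (end-to-end fine-tuning) sketch: the image encoder stays frozen,
# the adapter and the LLM are trained on multi-turn instruction data.
_ = model.vision_tower.requires_grad_(False)
_ = model.language_model.requires_grad_(True)
print(f"Trainable parameters: {count_parameters(model, trainable_only=True)}")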
Working demo: huggingface.co/spaces/badayvedat/llava
Conclusion
I think this small project is fun and gives a better understanding of how multimodal models like LLaVA work. Even though we used smaller models, the main idea is the same: combine vision and language into a single system that can understand images and talk about them.
Of course, the results of this toy example are not great, and there is plenty of room for improvement. Still, getting LLaVA to work in a resource-constrained environment is quite the challenge.
If you liked this article, please follow me! 😁
💼 LinkedIn | 🐦 X (Twitter) | 💻 Website