
Video understanding remains one of the most challenging frontiers in computer vision. Unlike static images, videos exhibit rich temporal dynamics, including human actions, object interactions, and scene transitions. Conventional video classification methods rely on architectures such as 3D CNNs and video Transformers (TimeSformer, ViViT). These methods, while effective, offer no way to incorporate natural-language guidance into the classification process.
Simultaneously, the rapid progress in Vision-Language Models (VLMs) — such as CLIP, BLIP, BLIP-2, and Flamingo — has demonstrated that jointly learning visual and linguistic representations yields powerful, generalizable features. BLIP-2 (Bootstrapping Language-Image Pre-training) introduced an efficient architecture, the Querying Transformer (Q-Former), that bridges a frozen image encoder and a frozen large language model, achieving state-of-the-art performance on image-text tasks with minimal trainable parameters.
However, BLIP-2 was designed for static images. Extending it to video understanding and further conditioning it on natural-language prompts (instructions) opens the door to a flexible, instruction-driven video classification paradigm — where the model not only sees the video but is also told what to look for.
In this article, we’ll walk through a complete, generic pipeline for fine-tuning a BLIP-2 model for video classification.
What is BLIP-2?
The BLIP-2 model was proposed in BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models by Junnan Li, Dongxu Li, Silvio Savarese, and Steven Hoi. BLIP-2 leverages frozen pre-trained image encoders and large language models (LLMs) by training a lightweight, 12-layer Transformer encoder (the Q-Former) in between them, achieving state-of-the-art performance on various vision-language tasks.
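To make the architecture concrete, the optional sketch below loads the same checkpoint used later in this article and prints the size of each component (note that this downloads several gigabytes of weights); the attribute names vision_model, qformer, and language_model are the ones exposed by the transformers implementation.
import torch
from transformers import Blip2ForConditionalGeneration

blip2 = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

def count_params(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

# Frozen ViT image encoder, lightweight Q-Former bridge, frozen OPT language model.
print(f"vision_model:   {count_params(blip2.vision_model):,}")
print(f"qformer:        {count_params(blip2.qformer):,}")
print(f"language_model: {count_params(blip2.language_model):,}")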

This work presents a method to fine-tune BLIP-2 for prompt-instructed video classification by:
- Extending the BLIP-2 vision encoder to handle video inputs (temporal frame sequences)
- Leveraging the pretrained Q-Former for visual feature aggregation
- Conditioning classification on a natural language prompt via a cross-attention fusion mechanism
- Training only a small set of parameters (Q-Former + fusion layers + classifier) while keeping the vision encoder frozen
InstructBLIP demonstrated that conditioning the Q-Former on task-specific instructions significantly improves performance across diverse vision-language tasks. Our work takes a similar approach but applies it specifically to video classification, where temporal reasoning is critical and the instruction guides spatial-temporal feature aggregation.
Python Implementation
Import Libraries
import os
import av
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset, DatasetDict
from transformers import (
    Blip2Processor,
    BatchEncoding,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix,
)
from typing import Any, Callable, Optional
import logging
import pandas as pd
from dataclasses import dataclass
from transformers import (
    Blip2Config,
    Blip2ForConditionalGeneration,
    Blip2QFormerModel,
    Blip2VisionModel,
)
from transformers.modeling_outputs import BaseModelOutputWithPooling
Training Hyperparameters
MODEL_NAME = "Salesforce/blip2-opt-2.7b"
NUM_FRAMES = 8
BATCH_SIZE = 8
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
OUTPUT_DIR = "./videoblip_classification_output"
RANDOM_SEED = 42
IMG_SIZE = 224
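Two task-specific values are referenced later in the pipeline but depend on your dataset; the placeholders below are examples to adjust, not values from the original setup.
NUM_CLASSES = 5  # placeholder: number of target classes in your dataset
PROMPT = "What action is the person performing in this video?"  # example instruction prompt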
Read Video
def read_video_pyav(filepath: str, num_frames: int = 8) -> np.ndarray:
    """
    Read a video file and uniformly sample `num_frames` frames.
    :returns: np.ndarray of shape (num_frames, height, width, 3), dtype uint8
    """
    container = av.open(filepath)
    stream = container.streams.video[0]
    total_frames = stream.frames
    if total_frames == 0:
        # Some containers do not expose a frame count; decode everything first.
        frames = [f.to_ndarray(format="rgb24") for f in container.decode(video=0)]
        container.close()
        total_frames = len(frames)
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        return np.stack([frames[i] for i in indices])
    indices = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())
    frames = []
    for i, frame in enumerate(container.decode(video=0)):
        if i in indices:
            frames.append(frame.to_ndarray(format="rgb24"))
        if len(frames) == num_frames:
            break
    container.close()
    # Pad with the last decoded frame if fewer frames than requested were collected.
    while len(frames) < num_frames:
        frames.append(frames[-1])
    return np.stack(frames)
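A quick usage check of the sampler (the clip path is a placeholder):
frames = read_video_pyav("example_clip.mp4", num_frames=NUM_FRAMES)
print(frames.shape, frames.dtype)  # expected: (8, H, W, 3) uint8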
Function to Process Video and Prompt Jointly
def process(
    processor: Blip2Processor,
    video: np.ndarray | None = None,
    text: str | list[str] | None = None,
) -> BatchEncoding:
    if video is not None:
        # The processor treats each frame as an independent image.
        video_list = [frame for frame in video]
        inputs = processor(images=video_list, text=text, return_tensors="pt")
        if isinstance(inputs['pixel_values'], list):
            pixel_values = torch.stack(inputs['pixel_values'])
        else:
            pixel_values = inputs['pixel_values']
        num_frames = video.shape[0]
        c, h, w = pixel_values.shape[1:]
        # (T, C, H, W) -> (C, T, H, W)
        inputs["pixel_values"] = pixel_values.view(num_frames, c, h, w).permute(1, 0, 2, 3)
    else:
        inputs = processor(text=text, return_tensors="pt")
    return inputs
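A shape check of the joint processing, assuming a placeholder clip and question (the processor is also instantiated later in the data-preparation step):
processor = Blip2Processor.from_pretrained(MODEL_NAME)
frames = read_video_pyav("example_clip.mp4", num_frames=NUM_FRAMES)  # (T, H, W, 3)
inputs = process(processor, video=frames, text="What action is shown?")
print(inputs["pixel_values"].shape)  # expected: (C, T, H, W) = (3, 8, 224, 224)
print(inputs["input_ids"].shape)     # tokenized prompt, shape (1, L)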
Customized BLIP2 Model for Videos
@dataclass
class VideoClassificationOutput:
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple] = None
    attentions: Optional[tuple] = None
class VideoBlipVisionModel(Blip2VisionModel):
    """Video-aware BLIP2 vision model: frames are folded into the batch dimension."""
    def forward(
        self,
        pixel_values: torch.FloatTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutputWithPooling:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch, _, time, _, _ = pixel_values.size()
        # (B, C, T, H, W) -> (B, T, C, H, W) -> (B*T, C, H, W)
        flat_pixel_values = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1)
        vision_outputs: BaseModelOutputWithPooling = super().forward(
            pixel_values=flat_pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        seq_len = vision_outputs.last_hidden_state.size(1)
        # Fold the temporal dimension back into the token sequence: (B, T*seq_len, hidden).
        last_hidden_state = vision_outputs.last_hidden_state.view(
            batch, time * seq_len, -1
        )
        pooler_output = vision_outputs.pooler_output.view(batch, time, -1)
        hidden_states = (
            tuple(
                hidden.view(batch, time * seq_len, -1)
                for hidden in vision_outputs.hidden_states
            )
            if vision_outputs.hidden_states is not None
            else None
        )
        attentions = (
            tuple(
                attn.view(batch, time, -1, seq_len, seq_len)
                for attn in vision_outputs.attentions
            )
            if vision_outputs.attentions is not None
            else None
        )
        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooler_output,
                hidden_states=hidden_states,
                attentions=attentions,
            )
        return (last_hidden_state, pooler_output, hidden_states, attentions)
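The reshaping can be verified without loading any weights. The per-frame token count and hidden size below are assumptions based on the ViT-g/14 encoder that blip2-opt-2.7b uses at 224x224 input (16x16 patches plus one CLS token).
B, C, T, H, W = 2, 3, 8, 224, 224
seq_len, hidden = 257, 1408  # assumed per-frame token count and ViT-g hidden size
pixel_values = torch.randn(B, C, T, H, W)
flat = pixel_values.permute(0, 2, 1, 3, 4).flatten(end_dim=1)  # (B*T, C, H, W)
per_frame = torch.randn(B * T, seq_len, hidden)                # stand-in for the ViT outputs
folded = per_frame.view(B, T * seq_len, hidden)                # (B, T*seq_len, hidden)
print(flat.shape, folded.shape)  # torch.Size([16, 3, 224, 224]) torch.Size([2, 2056, 1408])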
class VideoBlipForClassification(Blip2ForConditionalGeneration):
    """
    Prompt-instructed VideoBLIP classifier.
    Inputs:
    - pixel_values: (B, C, T, H, W) video clip
    - labels: (B,) optional class indices
    The learned query tokens cross-attend to the flattened spatio-temporal video
    features inside the Q-Former, and the pooled query output feeds the classification
    head. The instruction prompt is tokenized by the processor during preprocessing.
    """
    def __init__(self, config: Blip2Config, num_classes: int) -> None:
        # Skip Blip2ForConditionalGeneration.__init__ so the frozen OPT LLM is never built.
        super(Blip2ForConditionalGeneration, self).__init__(config)
        self.num_classes = num_classes
        self.vision_model = VideoBlipVisionModel(config.vision_config)
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
        )
        self.qformer = Blip2QFormerModel(config.qformer_config)
        qformer_hidden = config.qformer_config.hidden_size
        self.classifier = nn.Sequential(
            nn.LayerNorm(qformer_hidden),
            nn.Linear(qformer_hidden, qformer_hidden),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(qformer_hidden, num_classes),
        )
        # The language model and its projection are not used for classification.
        self.language_projection = None
        self.language_model = None
        self.post_init()
    @classmethod
    def from_pretrained_videoblip(
        cls,
        pretrained_model_name_or_path: str,
        num_classes: int,
        freeze_vision: bool = True,
        freeze_qformer: bool = False,
        **kwargs,
    ) -> "VideoBlipForClassification":
        config = Blip2Config.from_pretrained(pretrained_model_name_or_path)
        model = cls(config, num_classes=num_classes)
        pretrained = Blip2ForConditionalGeneration.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        msg = model.vision_model.load_state_dict(
            pretrained.vision_model.state_dict(), strict=False
        )
        print(f"Vision model load info: {msg}")
        model.qformer.load_state_dict(pretrained.qformer.state_dict())
        model.query_tokens.data.copy_(pretrained.query_tokens.data)
        if freeze_vision:
            for param in model.vision_model.parameters():
                param.requires_grad = False
            print("Vision encoder frozen.")
        if freeze_qformer:
            for param in model.qformer.parameters():
                param.requires_grad = False
            model.query_tokens.requires_grad = False
            print("Q-Former frozen.")
        del pretrained
        return model
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
    ) -> VideoClassificationOutput:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        frame_embeds = vision_outputs.last_hidden_state
        frame_attention_mask = torch.ones(
            frame_embeds.size()[:-1], dtype=torch.long, device=frame_embeds.device
        )
        # Learned query tokens cross-attend to the flattened spatio-temporal features.
        query_tokens = self.query_tokens.expand(frame_embeds.size(0), -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=frame_embeds,
            encoder_attention_mask=frame_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state
        pooled_query_output = query_output.mean(dim=1)
        logits = self.classifier(pooled_query_output)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
        if not return_dict:
            output = (logits,)
            if output_hidden_states:
                output = output + (query_outputs.hidden_states,)
            if output_attentions:
                output = output + (query_outputs.attentions,)
            return ((loss,) + output) if loss is not None else output
        return VideoClassificationOutput(
            loss=loss,
            logits=logits,
            hidden_states=query_outputs.hidden_states,
            attentions=query_outputs.attentions,
        )
Data Preparation
def collate_fn(batch: list[dict]) -> dict:
    pixel_values = torch.stack([torch.tensor(item['pixel_values']) for item in batch])  # (B, C, T, H, W)
    labels = torch.tensor([item['labels'] for item in batch])  # (B,)
    return {'pixel_values': pixel_values, 'labels': labels}
processor = Blip2Processor.from_pretrained(MODEL_NAME)
def create_dataset(data: HFDataset, processor: Blip2Processor) -> HFDataset:
    def process_example(example: dict) -> dict:
        video_path = example['video']
        label = example['label']
        video_frames = read_video_pyav(video_path, num_frames=NUM_FRAMES)  # (T, H, W, C) numpy array
        # Pass the numpy array directly
        inputs = process(processor, video=video_frames, text=PROMPT)
        return {
            'pixel_values': inputs['pixel_values'],  # (C, T, H, W)
            'labels': label
        }
    num_cpus = os.cpu_count()
    return data.map(process_example, batched=False, num_proc=num_cpus)
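The calls below assume a DatasetDict named data with train/validation/test splits whose records look like {'video': <path>, 'label': <int>}. One way such splits might be assembled from a CSV of video paths and integer labels (the file name, column names, and split ratios are assumptions):
# Hypothetical CSV with columns "video" (file path) and "label" (integer class id).
df = pd.read_csv("video_labels.csv")
train_df, temp_df = train_test_split(
    df, test_size=0.3, stratify=df["label"], random_state=RANDOM_SEED
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.5, stratify=temp_df["label"], random_state=RANDOM_SEED
)
data = DatasetDict({
    "train": HFDataset.from_pandas(train_df.reset_index(drop=True)),
    "validation": HFDataset.from_pandas(val_df.reset_index(drop=True)),
    "test": HFDataset.from_pandas(test_df.reset_index(drop=True)),
})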
train_dataset = create_dataset(data['train'], processor)
val_dataset = create_dataset(data['validation'], processor)
test_dataset = create_dataset(data['test'], processor)
Training
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    pred_ids = np.argmax(predictions, axis=-1)
    return {
        "accuracy": accuracy_score(labels, pred_ids),
        "f1": f1_score(labels, pred_ids, average="weighted"),
    }
class VideoBlipTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels", None)
        # When return_outputs is True, the Trainer expects a tuple (loss, model_outputs...),
        # so we ask the model to return a tuple by setting return_dict=False.
        outputs = model(
            pixel_values=inputs["pixel_values"],
            labels=labels,
            return_dict=not return_outputs,  # dict for training, tuple for evaluation
        )
        if return_outputs:
            # For evaluation, outputs is a tuple (loss, logits, ...).
            loss = outputs[0]
            return (loss, outputs)
        else:
            # For training, outputs is the VideoClassificationOutput dataclass.
            loss = outputs.loss
            return loss
model = VideoBlipForClassification.from_pretrained_videoblip(
    MODEL_NAME,
    num_classes=NUM_CLASSES,
    freeze_vision=True,
    freeze_qformer=False,
)
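With the vision encoder frozen, only the Q-Former, query tokens, and classifier head remain trainable; a quick check:
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,} ({100 * trainable / total:.1f}%)")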
training_args = TrainingArguments(
    run_name="VideoBlip-Classification",
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=8,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    seed=RANDOM_SEED,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    save_total_limit=2,
    remove_unused_columns=False,
    fp16=torch.cuda.is_available(),
    logging_steps=10,
    report_to="wandb",
)
trainer = VideoBlipTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
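With the trainer assembled, training is launched and the test split prepared earlier is scored; classification_report and confusion_matrix from the imports above give a per-class view. Metric keys returned by evaluate() are prefixed with "eval_".
trainer.train()

test_metrics = trainer.evaluate(eval_dataset=test_dataset)
print(test_metrics)  # includes eval_accuracy and eval_f1

preds = trainer.predict(test_dataset)
predictions = preds.predictions[0] if isinstance(preds.predictions, tuple) else preds.predictions
pred_ids = np.argmax(predictions, axis=-1)
print(classification_report(preds.label_ids, pred_ids))
print(confusion_matrix(preds.label_ids, pred_ids))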
References
- https://huggingface.co/docs/transformers/index
- https://huggingface.co/papers/2301.12597
- https://github.com/salesforce/LAVIS/tree/main/projects/blip2
- https://github.com/huggingface/transformers/tree/main/src/transformers/models