From d1848be63a01a4f5475131ca6d3d58cb9b42c72a Mon Sep 17 00:00:00 2001
From: Amit Moryossef
Date: Wed, 27 Aug 2025 10:32:04 +0300
Subject: [PATCH 1/2] Refactor video processing and model loading functions

---
 CLI.py | 145 ++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 81 insertions(+), 64 deletions(-)

diff --git a/CLI.py b/CLI.py
index da96c39..bf89674 100644
--- a/CLI.py
+++ b/CLI.py
@@ -57,16 +57,13 @@ def get_video_tower(self):
         video_tower = video_tower[0]
         return video_tower
 
-
     def get_all_tower(self, keys):
         tower = {key: getattr(self, f'get_{key}_tower') for key in keys}
         return tower
 
-
     def load_video_tower_pretrained(self, pretrained_checkpoint):
         self.mm_projector.load_state_dict(pretrained_checkpoint, strict=True)
 
-
     def initialize_image_modules(self, model_args, fsdp=None):
         image_tower = model_args.image_tower
         mm_vision_select_layer = model_args.mm_vision_select_layer
@@ -92,6 +89,7 @@ def initialize_image_modules(self, model_args, fsdp=None):
 
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+
             def get_w(weights, keyword):
                 return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
 
@@ -122,6 +120,7 @@ def initialize_video_modules(self, model_args, fsdp=None):
 
         if pretrain_mm_mlp_adapter is not None:
             mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
+
             def get_w(weights, keyword):
                 return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
 
@@ -133,27 +132,26 @@ def encode_images(self, images):
         return image_features
 
     def encode_videos(self, videos):
-        video_features = self.get_video_tower()(videos) # torch.Size([1, 2048, 1024])
-        video_features = self.mm_projector(video_features.float()) # torch.Size([1, 2048, 4096])
+        video_features = self.get_video_tower()(videos)  # torch.Size([1, 2048, 1024])
+        video_features = self.mm_projector(video_features.float())  # torch.Size([1, 2048, 4096])
         return video_features
-    
+
     def get_multimodal_embeddings(self, X_modalities):
-        Xs, keys= X_modalities
+        Xs, keys = X_modalities
         X_features = getattr(self, f'encode_{keys[0]}s')(Xs)  # expand to get batchsize
         return X_features
 
-
-def get_processor(X, config, device, pretrained_checkpoint_tower, model_path = 'LanguageBind/Video-LLaVA-7B'):
+def get_processor(X, config, device, pretrained_checkpoint_tower, model_path='LanguageBind/Video-LLaVA-7B'):
     processor = {}
     mm_backbone_mlp_model = LlavaMetaModel(config, pretrained_checkpoint_tower)
-    print(X) 
+    print(X)
     if 'Image' in X:
-        image_tower = mm_backbone_mlp_model.get_image_tower() # LanguageBindImageTower()
+        image_tower = mm_backbone_mlp_model.get_image_tower()  # LanguageBindImageTower()
         if not image_tower.is_loaded:
             image_tower.load_model()
         image_tower.to(device=device, dtype=torch.float16)
@@ -175,6 +173,7 @@ class Projection(nn.Module):
     def __init__(self, ):
         super().__init__()
         self.linear_proj = nn.Linear(512, 4096)
+
     def forward(self, x):
         return self.linear_proj(x)
 
@@ -187,24 +186,22 @@ def __init__(self, ):
             nn.GELU(),
             nn.Linear(4096, 4096)
         )
+
     def forward(self, x):
         return self.proj(x)
 
-def main(
+def load_model(
     quantize: Optional[str] = None,
     dtype: str = "float32",
-    max_new_tokens: int = 200,
-    top_k: int = 200,
-    temperature: float = 0.8,
     accelerator: str = "auto",
-) -> None:
-
+    checkpoint_dir: str = "./checkpoints",
+) -> dict:
     # import pdb; pdb.set_trace()
     lora_path = Path(args.lora_path)
-    pretrained_llm_path = Path(f"./checkpoints/vicuna-7b-v1.5/lit_model.pth")
-    tokenizer_llm_path = Path("./checkpoints/vicuna-7b-v1.5/tokenizer.model")
= Path("./checkpoints/vicuna-7b-v1.5/tokenizer.model") - + pretrained_llm_path = Path(f"{checkpoint_dir}/vicuna-7b-v1.5/lit_model.pth") + tokenizer_llm_path = Path(f"{checkpoint_dir}/vicuna-7b-v1.5/tokenizer.model") + # assert lora_path.is_file() assert pretrained_llm_path.is_file() assert tokenizer_llm_path.is_file() @@ -223,7 +220,7 @@ def main( with EmptyInitOnDevice( device=fabric.device, dtype=dtype, quantization_mode=quantize ), lora(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout, enabled=True): - checkpoint_dir = Path("checkpoints/vicuna-7b-v1.5") + checkpoint_dir = Path(f"{checkpoint_dir}/vicuna-7b-v1.5") lora_query = True lora_key = False lora_value = True @@ -243,13 +240,14 @@ def main( to_head=lora_head, ) model = GPT(config).bfloat16() - + mlp_path = args.mlp_path pretrained_checkpoint_mlp = torch.load(mlp_path) X = ['Video'] - mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp, model_path = 'LanguageBind/Video-LLaVA-7B') + mm_backbone_mlp_model, processor = get_processor(X, args, 'cuda', pretrained_checkpoint_mlp, + model_path='LanguageBind/Video-LLaVA-7B') video_processor = processor['video'] linear_proj = mm_backbone_mlp_model.mm_projector @@ -264,66 +262,85 @@ def main( print('Load llm base model from', pretrained_llm_path) print('Load lora model from', lora_path) - # load mlp again, to en sure, not neccessary actually + # load mlp again, to en sure, not neccessary actually linear_proj.load_state_dict(pretrained_checkpoint_mlp) linear_proj = linear_proj.cuda() print('Load mlp model again from', mlp_path) - print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) model.eval() model = fabric.setup_module(model) linear_proj.eval() - tokenizer = Tokenizer(tokenizer_llm_path) print('Load tokenizer from', tokenizer_llm_path) - - + return { + "tokenizer": tokenizer, + "model": model, + "mm_backbone_mlp_model": mm_backbone_mlp_model, + "video_processor": video_processor, + } + +def predict(tokenizer: Tokenizer, + model: GPT, + mm_backbone_mlp_model: LlavaMetaModel, + video_processor: any, + input_video_path: str, + prompt: str, + + max_new_tokens: int = 200, + top_k: int = 200, + temperature: float = 0.8) -> str: + + video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values'] + + if type(video_tensor) is list: + tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor] + else: + tensor = video_tensor.to('cuda', dtype=torch.float16) # (1,3,8,224,224) + + X_modalities = [tensor, ['video']] + + video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities) + + sample = {"instruction": prompt, "input": input_video_path} + + prefix = generate_prompt_mlp(sample) + pre = torch.cat( + (tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1, + -1), + tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1)), dim=1) + + prompt = (pre, ". 
ASSISTANT: ") + encoded = (prompt[0], video_feature[0], + tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1)) + + output_seq = generate( + model, + idx=encoded, + max_seq_length=4096, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + tokenizer=tokenizer, + ) + outputfull = tokenizer.decode(output_seq) + return outputfull.split("ASSISTANT:")[-1].strip() + +def main(): + torch.set_float32_matmul_precision("high") + model_components = load_model() while True: - input_video_path = input("\033[0;34;40m Input video path: \033[0m") - video_tensor = video_processor(input_video_path, return_tensors='pt')['pixel_values'] - - if type(video_tensor) is list: - tensor = [video.to('cuda', dtype=torch.float16) for video in video_tensor] - else: - tensor = video_tensor.to('cuda', dtype=torch.float16) # (1,3,8,224,224) - - X_modalities = [tensor,['video']] + prompt = input("\033[0;34;40m Your question: \033[0m") - video_feature = mm_backbone_mlp_model.get_multimodal_embeddings(X_modalities) + output = predict(prompt=prompt, input_video_path=input_video_path, **model_components) - prompt = input("\033[0;34;40m Your question: \033[0m") - sample = {"instruction": prompt, "input": input_video_path} - - prefix = generate_prompt_mlp(sample) - pre = torch.cat((tokenizer.encode(prefix.split('INPUT_VIDEO: ')[0] + "\n", bos=True, eos=False, device=model.device).view(1, -1), tokenizer.encode("INPUT_VIDEO: ", bos=False, eos=False, device=model.device).view(1, -1)), dim=1) - - prompt = (pre, ". ASSISTANT: ") - encoded = (prompt[0], video_feature[0], tokenizer.encode(prompt[1], bos=False, eos=False, device=model.device).view(1, -1)) - - - t0 = time.perf_counter() - - output_seq = generate( - model, - idx=encoded, - max_seq_length=4096, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_k=top_k, - eos_id=tokenizer.eos_id, - tokenizer = tokenizer, - ) - outputfull = tokenizer.decode(output_seq) - output = outputfull.split("ASSISTANT:")[-1].strip() print("================================") print("Model output", output) print("================================") - if __name__ == "__main__": - torch.set_float32_matmul_precision("high") main() From 203bd6be2dcaea4b09b77cbfcae193ccfc910ff3 Mon Sep 17 00:00:00 2001 From: Amit Moryossef Date: Wed, 27 Aug 2025 14:29:44 +0300 Subject: [PATCH 2/2] Update CLI.py --- CLI.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/CLI.py b/CLI.py index bf89674..5b3e4dd 100644 --- a/CLI.py +++ b/CLI.py @@ -1,19 +1,14 @@ -import os import sys import time import warnings from pathlib import Path from typing import Optional -from typing import Dict, List, Literal, Optional, Tuple from lit_gpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable import lightning as L import torch -import numpy as np import torch.nn as nn -import torch.nn.functional as F -import models.vqvae as vqvae from generate import generate from lit_llama import Tokenizer, LLaMA, LLaMAConfig from lit_llama.lora import lora @@ -21,15 +16,11 @@ from lit_gpt.utils import lazy_load from scripts.video_dataset.prepare_video_dataset_video_llava import generate_prompt_mlp from options import option -import imageio -from tqdm import tqdm from models.multimodal_encoder.builder import build_image_tower, build_video_tower from models.multimodal_projector.builder import build_vision_projector warnings.filterwarnings('ignore') -args = option.get_args_parser() - class LlavaMetaModel: @@ 
 
 
 def load_model(
+    args,
     quantize: Optional[str] = None,
     dtype: str = "float32",
     accelerator: str = "auto",
@@ -330,8 +322,10 @@ def predict(tokenizer: Tokenizer,
     return outputfull.split("ASSISTANT:")[-1].strip()
 
 def main():
+    args = option.get_args_parser()
+
     torch.set_float32_matmul_precision("high")
-    model_components = load_model()
+    model_components = load_model(args)
     while True:
         input_video_path = input("\033[0;34;40m Input video path: \033[0m")
         prompt = input("\033[0;34;40m Your question: \033[0m")
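
With both patches applied, one-time setup (load_model) is separated from per-request
inference (predict), so the refactored CLI can also be driven programmatically rather
than through the interactive loop. The sketch below is a hypothetical usage example,
not part of the patch: it assumes CLI.py is importable as a module named CLI, that
option.get_args_parser() supplies lora_path, mlp_path, and the LoRA hyperparameters,
and the file name demo.mp4 is purely illustrative.

    # Hypothetical driver for the refactored API (assumes CLI.py is importable).
    from options import option
    from CLI import load_model, predict

    args = option.get_args_parser()   # must carry lora_path, mlp_path, lora_r, ...
    components = load_model(args)     # keys: tokenizer, model, mm_backbone_mlp_model, video_processor
    answer = predict(prompt="What is happening in this video?",
                     input_video_path="demo.mp4",  # illustrative path
                     **components)
    print(answer)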