diff --git a/models/model_loader.py b/models/model_loader.py
index f048bad..55797d7 100644
--- a/models/model_loader.py
+++ b/models/model_loader.py
@@ -3,11 +3,11 @@ import os
 import torch
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
-from transformers import MllamaForConditionalGeneration
-from vllm.sampling_params import SamplingParams
+from transformers import MllamaForConditionalGeneration, BitsAndBytesConfig
 from transformers import AutoModelForCausalLM
 import google.generativeai as genai
 from vllm import LLM
+from vllm.sampling_params import SamplingParams
 from groq import Groq
 from dotenv import load_dotenv
@@ -57,15 +57,26 @@ def load_model(model_choice):
     if model_choice == 'qwen':
         device = detect_device()
+
+        # Configure 4-bit quantization
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",  # NormalFloat4 quantization type
+            bnb_4bit_use_double_quant=True,  # Use double quantization for better efficiency
+            bnb_4bit_compute_dtype=torch.float16 if device != 'cpu' else torch.float32
+        )
+
+        # Load the model with 4-bit quantization
         model = Qwen2VLForConditionalGeneration.from_pretrained(
             "Qwen/Qwen2-VL-7B-Instruct",
-            torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
-            device_map="auto"
+            quantization_config=quantization_config,
+            device_map="auto",
+            trust_remote_code=True  # Ensure compatibility with the model
         )
+
         processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
-        model.to(device)
         _model_cache[model_choice] = (model, processor, device)
-        logger.info("Qwen model loaded and cached.")
+        logger.info("Qwen model (4-bit quantized) loaded and cached.")
         return _model_cache[model_choice]
 
     elif model_choice == 'gemini':
@@ -79,26 +90,37 @@ def load_model(model_choice):
     elif model_choice == 'llama-vision':
         device = detect_device()
         model_id = "alpindale/Llama-3.2-11B-Vision-Instruct"
+
+        # Configure 4-bit quantization
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",  # NormalFloat4 quantization type
+            bnb_4bit_use_double_quant=True,  # Use double quantization for better efficiency
+            bnb_4bit_compute_dtype=torch.float16 if device != 'cpu' else torch.float32
+        )
+
+        # Load the model with 4-bit quantization
         model = MllamaForConditionalGeneration.from_pretrained(
             model_id,
-            torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
-            device_map="auto"
+            quantization_config=quantization_config,
+            device_map="auto",
+            trust_remote_code=True
         )
+
         processor = AutoProcessor.from_pretrained(model_id)
-        model.to(device)
         _model_cache[model_choice] = (model, processor, device)
-        logger.info("Llama-Vision model loaded and cached.")
+        logger.info("Llama-Vision model (4-bit quantized) loaded and cached.")
         return _model_cache[model_choice]
-    
+
     elif model_choice == "pixtral":
         device = detect_device()
         mistral_models_path = os.path.join(os.getcwd(), 'mistral_models', 'Pixtral')
-        
+
         if not os.path.exists(mistral_models_path):
            os.makedirs(mistral_models_path, exist_ok=True)
            from huggingface_hub import snapshot_download
-           snapshot_download(repo_id="mistralai/Pixtral-12B-2409", 
-                             allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"], 
+           snapshot_download(repo_id="mistralai/Pixtral-12B-2409",
+                             allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
                              local_dir=mistral_models_path)
 
         from mistral_inference.transformer import Transformer
@@ -107,11 +129,11 @@ def load_model(model_choice):
         tokenizer = MistralTokenizer.from_file(os.path.join(mistral_models_path, "tekken.json"))
         model = Transformer.from_folder(mistral_models_path)
-        
+
         _model_cache[model_choice] = (model, tokenizer, generate, device)
         logger.info("Pixtral model loaded and cached.")
         return _model_cache[model_choice]
-    
+
     elif model_choice == "molmo":
         device = detect_device()
         processor = AutoProcessor.from_pretrained(
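A minimal usage sketch of the quantized loader follows; it is not part of the patch and assumes the repo exposes `load_model` as `models.model_loader.load_model`, that `bitsandbytes` is installed, and that a CUDA GPU is available.

```python
# Sketch: exercising the 4-bit quantized 'qwen' branch of load_model().
# Assumption: models.model_loader.load_model is importable from the project root.
# Note: 4-bit bitsandbytes models cannot be moved with .to(), which is why the
# diff drops model.to(device) and relies on device_map="auto" for placement.
from models.model_loader import load_model

model, processor, device = load_model('qwen')
print(type(model).__name__, device)  # repeated calls return the cached tuple
```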