models/model_loader.py: 54 changes (38 additions, 16 deletions)
@@ -3,11 +3,11 @@
import os
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from transformers import MllamaForConditionalGeneration
from vllm.sampling_params import SamplingParams
from transformers import MllamaForConditionalGeneration, BitsAndBytesConfig
from transformers import AutoModelForCausalLM
import google.generativeai as genai
from vllm import LLM
from vllm.sampling_params import SamplingParams
from groq import Groq
from dotenv import load_dotenv

@@ -57,15 +57,26 @@ def load_model(model_choice):

    if model_choice == 'qwen':
        device = detect_device()

        # Configure 4-bit quantization
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4", # NormalFloat4 quantization type
            bnb_4bit_use_double_quant=True, # Use double quantization for better efficiency
            bnb_4bit_compute_dtype=torch.float16 if device != 'cpu' else torch.float32
        )

        # Load the model with 4-bit quantization
        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "Qwen/Qwen2-VL-7B-Instruct",
            torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
            device_map="auto"
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=True # Ensure compatibility with the model
        )

        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        model.to(device)
        _model_cache[model_choice] = (model, processor, device)
        logger.info("Qwen model loaded and cached.")
        logger.info("Qwen model (4-bit quantized) loaded and cached.")
        return _model_cache[model_choice]
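A rough usage sketch (not part of the diff; it assumes the file is importable as models.model_loader and that a CUDA device is available for the bitsandbytes 4-bit path): the cached tuple returned above can be exercised, and the effect of quantization inspected, via get_memory_footprint():

# Illustrative only -- the module path and CUDA availability are assumptions.
from models.model_loader import load_model

model, processor, device = load_model('qwen')

# With load_in_4bit=True most weights are stored as NF4 blocks;
# get_memory_footprint() reports the approximate size in bytes.
print(f"{device}: ~{model.get_memory_footprint() / 1e9:.1f} GB of weights resident")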

    elif model_choice == 'gemini':
@@ -79,26 +90,37 @@ def load_model(model_choice):
    elif model_choice == 'llama-vision':
        device = detect_device()
        model_id = "alpindale/Llama-3.2-11B-Vision-Instruct"

        # Configure 4-bit quantization
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4", # NormalFloat4 quantization type
            bnb_4bit_use_double_quant=True, # Use double quantization for better efficiency
            bnb_4bit_compute_dtype=torch.float16 if device != 'cpu' else torch.float32
        )

        # Load the model with 4-bit quantization
        model = MllamaForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device != 'cpu' else torch.float32,
            device_map="auto"
            quantization_config=quantization_config,
            device_map="auto",
            trust_remote_code=True
        )

        processor = AutoProcessor.from_pretrained(model_id)
        model.to(device)
        _model_cache[model_choice] = (model, processor, device)
        logger.info("Llama-Vision model loaded and cached.")
        logger.info("Llama-Vision model (4-bit quantized) loaded and cached.")
        return _model_cache[model_choice]
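One caveat worth sketching: bitsandbytes NF4 kernels generally require a CUDA GPU, so the bnb_4bit_compute_dtype fallback to float32 does not by itself make these branches CPU-safe. A minimal, hypothetical helper (not in this PR) that skips quantization entirely on CPU might look like:

import torch
from transformers import BitsAndBytesConfig

def make_quantization_config(device: str):
    # Hypothetical helper: return None on CPU so the caller can fall back to a
    # plain from_pretrained() call instead of requesting 4-bit weights.
    if device == 'cpu':
        return None
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",            # NormalFloat4 storage type
        bnb_4bit_use_double_quant=True,       # also quantize the quantization constants
        bnb_4bit_compute_dtype=torch.float16  # dtype used for matmuls at runtime
    )

The caller would then pass quantization_config=... only when the helper returns a config.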

    elif model_choice == "pixtral":
        device = detect_device()
        mistral_models_path = os.path.join(os.getcwd(), 'mistral_models', 'Pixtral')

        if not os.path.exists(mistral_models_path):
            os.makedirs(mistral_models_path, exist_ok=True)
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id="mistralai/Pixtral-12B-2409",
                              allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
            snapshot_download(repo_id="mistralai/Pixtral-12B-2409",
                              allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
                              local_dir=mistral_models_path)

        from mistral_inference.transformer import Transformer
@@ -107,11 +129,11 @@ def load_model(model_choice):

        tokenizer = MistralTokenizer.from_file(os.path.join(mistral_models_path, "tekken.json"))
        model = Transformer.from_folder(mistral_models_path)

        _model_cache[model_choice] = (model, tokenizer, generate, device)
        logger.info("Pixtral model loaded and cached.")
        return _model_cache[model_choice]
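Since the Pixtral branch fetches only three files, a small pre-flight check (a sketch reusing the mistral_models_path layout above; the helper name is made up) can surface a partial download before Transformer.from_folder() fails less legibly:

import os

def check_pixtral_checkpoint(mistral_models_path):
    # Hypothetical guard: verify the files requested by snapshot_download above.
    required = ["params.json", "consolidated.safetensors", "tekken.json"]
    missing = [f for f in required
               if not os.path.exists(os.path.join(mistral_models_path, f))]
    if missing:
        raise FileNotFoundError(f"Pixtral checkpoint incomplete, missing: {missing}")

# e.g. check_pixtral_checkpoint(os.path.join(os.getcwd(), 'mistral_models', 'Pixtral'))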

    elif model_choice == "molmo":
        device = detect_device()
        processor = AutoProcessor.from_pretrained(