openadapt_ml/training/trl_trainer.py (64 changes: 54 additions & 10 deletions)

@@ -102,26 +102,70 @@ def _load_unsloth_model(config: TRLTrainingConfig):


 def _load_standard_model(config: TRLTrainingConfig):
-    """Fallback: Load model with standard transformers + peft."""
-    from transformers import AutoModelForCausalLM, AutoProcessor
+    """Fallback: Load model with standard transformers + peft.
+
+    Automatically detects vision-language models and uses the appropriate
+    model class (Qwen2VLForConditionalGeneration for VL models,
+    AutoModelForCausalLM for text-only models).
+    """
+    from transformers import AutoConfig, AutoProcessor
     from peft import LoraConfig, get_peft_model
     import torch

-    model = AutoModelForCausalLM.from_pretrained(
-        config.model_name,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        trust_remote_code=True,
+    # Check if this is a vision-language model
+    model_config = AutoConfig.from_pretrained(
+        config.model_name, trust_remote_code=True
     )
+    is_vl_model = (
+        "VL" in config.model_name.upper()
+        or "vision" in config.model_name.lower()
+        or hasattr(model_config, "vision_config")
+    )
+
+    if is_vl_model:
+        # Vision-language model - use Qwen2VLForConditionalGeneration or AutoModelForVision2Seq
+        try:
+            from transformers import Qwen2VLForConditionalGeneration
+
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
+            print(" Using Qwen2VLForConditionalGeneration for VL model")
+        except (ImportError, ValueError, RuntimeError, TypeError):
+            # Fallback to AutoModelForVision2Seq for other VL models
+            from transformers import AutoModelForVision2Seq
+
+            model = AutoModelForVision2Seq.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
+            print(" Using AutoModelForVision2Seq for VL model")
+    else:
+        # Text-only model
+        from transformers import AutoModelForCausalLM
+
+        model = AutoModelForCausalLM.from_pretrained(
+            config.model_name,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        print(" Using AutoModelForCausalLM for text-only model")

     processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)

-    # Apply LoRA
+    # Apply LoRA - use SEQ_2_SEQ_LM for VL models, CAUSAL_LM for text-only
     peft_config = LoraConfig(
         r=config.lora_r,
         lora_alpha=config.lora_alpha,
         lora_dropout=config.lora_dropout,
         target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
-        task_type="CAUSAL_LM",
+        task_type="SEQ_2_SEQ_LM" if is_vl_model else "CAUSAL_LM",
     )
     model = get_peft_model(model, peft_config)
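
For review convenience, the added branching can be read in isolation. The sketch below mirrors the detection and class-selection logic from the hunk above, assuming only that transformers is installed; pick_model_class is a hypothetical helper for illustration, not part of this PR, and it only retries on ImportError (the hunk also catches failures from from_pretrained itself).

from transformers import AutoConfig


def pick_model_class(model_name: str):
    """Hypothetical helper mirroring the VL detection in _load_standard_model."""
    model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    is_vl = (
        "VL" in model_name.upper()
        or "vision" in model_name.lower()
        or hasattr(model_config, "vision_config")
    )
    if not is_vl:
        # Text-only checkpoints keep the original causal-LM path
        from transformers import AutoModelForCausalLM

        return AutoModelForCausalLM, is_vl
    try:
        # Preferred class for Qwen2-VL-style checkpoints
        from transformers import Qwen2VLForConditionalGeneration

        return Qwen2VLForConditionalGeneration, is_vl
    except ImportError:
        # Generic fallback for other vision-to-text architectures
        from transformers import AutoModelForVision2Seq

        return AutoModelForVision2Seq, is_vl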

@@ -265,7 +309,7 @@ def train_with_trl(
         logging_steps=config.logging_steps,
         save_strategy=config.save_strategy,
         max_length=None,  # Critical for VLMs
-        assistant_only_loss=True,
+        assistant_only_loss=False,  # Not supported for VL models yet
     )

     trainer = SFTTrainer(
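For context on the assistant_only_loss flip in the hunk above, here is a minimal sketch of how the touched SFTConfig fields fit together, assuming a trl release whose SFTConfig exposes max_length and assistant_only_loss (as the hunk itself does); the output_dir and logging_steps values are hypothetical placeholders.

from trl import SFTConfig

# Only the fields touched or referenced by the hunk above; everything else omitted.
training_args = SFTConfig(
    output_dir="outputs",       # hypothetical placeholder
    logging_steps=10,           # hypothetical placeholder
    max_length=None,            # critical for VLMs: avoid truncating multimodal sequences
    assistant_only_loss=False,  # not yet supported for VL models, per this PR
)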
tests/test_trl_trainer.py (116 changes: 116 additions & 0 deletions)

@@ -615,3 +615,119 @@ def test_argparse_setup(self) -> None:
assert "--model" in source
assert "--epochs" in source
assert "--use-som" in source


# -----------------------------------------------------------------------------
# Test VL Model Detection
# -----------------------------------------------------------------------------


class TestVLModelDetection:
"""Test vision-language model detection logic in _load_standard_model.

The detection uses three criteria:
1. "VL" in model name (case-insensitive)
2. "vision" in model name (case-insensitive)
3. vision_config attribute in model config
"""

def test_vl_detection_by_name_vl_suffix(self) -> None:
"""Test VL detection for models with VL in name."""
from openadapt_ml.training.trl_trainer import TRLTrainingConfig

# These model names should be detected as VL models
vl_model_names = [
"Qwen/Qwen2-VL-7B-Instruct",
"Qwen/Qwen2.5-VL-7B-Instruct",
"unsloth/Qwen2.5-VL-7B-Instruct",
"some-model-vl-base", # lowercase vl
"Model-VL-2B",
]

for model_name in vl_model_names:
is_vl = "VL" in model_name.upper()
assert is_vl, f"Expected '{model_name}' to be detected as VL model"

def test_vl_detection_by_name_vision(self) -> None:
"""Test VL detection for models with 'vision' in name."""
vision_model_names = [
"llava-vision-7b",
"some-vision-model",
"VisionTransformer-base",
]

for model_name in vision_model_names:
is_vision = "vision" in model_name.lower()
assert is_vision, f"Expected '{model_name}' to be detected via 'vision'"

def test_text_only_detection(self) -> None:
"""Test that text-only models are NOT detected as VL."""
text_only_models = [
"meta-llama/Llama-2-7b-hf",
"Qwen/Qwen2-7B-Instruct", # Note: Qwen2, not Qwen2-VL
"mistralai/Mistral-7B-v0.1",
"google/gemma-7b",
"unsloth/gemma-2-9b-it",
]

for model_name in text_only_models:
is_vl_by_name = "VL" in model_name.upper() or "vision" in model_name.lower()
assert not is_vl_by_name, f"Expected '{model_name}' to NOT be detected as VL"

def test_vl_detection_by_config_attribute(self) -> None:
"""Test VL detection via vision_config attribute."""
# Mock a config object with vision_config
mock_config_vl = MagicMock()
mock_config_vl.vision_config = {"hidden_size": 1024}

assert hasattr(mock_config_vl, "vision_config")

# Mock a config object without vision_config
mock_config_text = MagicMock(spec=["model_type", "hidden_size"])

assert not hasattr(mock_config_text, "vision_config")

def test_vl_detection_logic_comprehensive(self) -> None:
"""Test the complete VL detection logic used in _load_standard_model.

This replicates the exact detection logic from the function to ensure
it correctly identifies VL vs text-only models.
"""
def is_vl_model(model_name: str, has_vision_config: bool) -> bool:
"""Replicate the detection logic from _load_standard_model."""
return (
"VL" in model_name.upper()
or "vision" in model_name.lower()
or has_vision_config
)

# VL models detected by name
assert is_vl_model("Qwen/Qwen2-VL-7B-Instruct", False)
assert is_vl_model("Qwen/Qwen2.5-VL-7B-Instruct", False)
assert is_vl_model("unsloth/Qwen2.5-VL-7B-Instruct", False)
assert is_vl_model("some-model-vl-base", False)

# VL models detected by "vision" in name
assert is_vl_model("llava-vision-7b", False)
assert is_vl_model("VisionTransformer-base", False)

# VL models detected by config attribute
assert is_vl_model("some-random-model", True) # has vision_config

# Text-only models (not detected as VL)
assert not is_vl_model("meta-llama/Llama-2-7b-hf", False)
assert not is_vl_model("Qwen/Qwen2-7B-Instruct", False)
assert not is_vl_model("mistralai/Mistral-7B-v0.1", False)
assert not is_vl_model("google/gemma-7b", False)

def test_lora_task_type_selection(self) -> None:
"""Test that correct LoRA task type is selected based on model type.

VL models should use SEQ_2_SEQ_LM, text-only should use CAUSAL_LM.
"""
def get_task_type(is_vl: bool) -> str:
"""Replicate the task type selection from _load_standard_model."""
return "SEQ_2_SEQ_LM" if is_vl else "CAUSAL_LM"

assert get_task_type(True) == "SEQ_2_SEQ_LM"
assert get_task_type(False) == "CAUSAL_LM"
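
A possible follow-up, not part of this PR: the loop-and-assert tests above stop at the first failing name, while a pytest-parametrized variant reports each model name as a separate test case. A sketch, assuming pytest is already a test dependency:

import pytest

# Hypothetical parametrized rewrite of the name-based detection tests.
VL_NAMES = ["Qwen/Qwen2-VL-7B-Instruct", "some-model-vl-base", "llava-vision-7b"]
TEXT_ONLY_NAMES = ["meta-llama/Llama-2-7b-hf", "Qwen/Qwen2-7B-Instruct"]


def _is_vl_by_name(name: str) -> bool:
    """Name-based half of the detection logic in _load_standard_model."""
    return "VL" in name.upper() or "vision" in name.lower()


@pytest.mark.parametrize("name", VL_NAMES)
def test_name_detected_as_vl(name: str) -> None:
    assert _is_vl_by_name(name)


@pytest.mark.parametrize("name", TEXT_ONLY_NAMES)
def test_name_not_detected_as_vl(name: str) -> None:
    assert not _is_vl_by_name(name)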