Add ColSmolVLM support to byaldi.

Summary of this patch:
  * byaldi/colpali.py: accept model names containing "colsmolvlm", load them
    with ColIdefics3 / ColIdefics3Processor, and mention ColSmolVLM in the
    unsupported-model error message.
  * pyproject.toml: raise the colpali-engine lower bound from 0.3.4 to 0.3.5.
  * tests/test_colsmolvlm.py: new slow test that loads
    "vidore/colsmolvlm-alpha" and checks the types of the wrapped model.

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Blank context lines and leading indentation were
reconstructed from the hunk line counts and the code's syntax; verify against
the target tree (e.g. `git apply --check`) before applying.

diff --git a/byaldi/colpali.py b/byaldi/colpali.py
index cc11dcb..e78e6b4 100644
--- a/byaldi/colpali.py
+++ b/byaldi/colpali.py
@@ -7,7 +7,7 @@
 
 import srsly
 import torch
-from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor
+from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor, ColIdefics3, ColIdefics3Processor
 from pdf2image import convert_from_path
 from PIL import Image
 
@@ -35,9 +35,10 @@ def __init__(
         if (
             "colpali" not in pretrained_model_name_or_path.lower()
             and "colqwen2" not in pretrained_model_name_or_path.lower()
+            and "colsmolvlm" not in pretrained_model_name_or_path.lower()
         ):
             raise ValueError(
-                "This pre-release version of Byaldi only supports ColPali and ColQwen2 for now. Incorrect model name specified."
+                "This pre-release version of Byaldi only supports ColPali, ColQwen2 and ColSmolVLM for now. Incorrect model name specified."
             )
 
         if verbose > 0:
@@ -89,6 +90,18 @@ def __init__(
                 ),
                 token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
             )
+        elif "colsmolvlm" in pretrained_model_name_or_path.lower():
+            self.model = ColIdefics3.from_pretrained(
+                self.pretrained_model_name_or_path,
+                torch_dtype=torch.bfloat16,
+                device_map=(
+                    "cuda"
+                    if device == "cuda"
+                    or (isinstance(device, torch.device) and device.type == "cuda")
+                    else None
+                ),
+                token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
+            )
         self.model = self.model.eval()
 
         if "colpali" in pretrained_model_name_or_path.lower():
@@ -107,6 +120,14 @@ def __init__(
                     token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
                 ),
             )
+        elif "colsmolvlm" in pretrained_model_name_or_path.lower():
+            self.processor = cast(
+                ColIdefics3Processor,
+                ColIdefics3Processor.from_pretrained(
+                    self.pretrained_model_name_or_path,
+                    token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"),
+                ),
+            )
 
         self.device = device
         if device != "cuda" and not (
diff --git a/pyproject.toml b/pyproject.toml
index bf1d624..cfb3b5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ maintainers = [
 ]
 
 dependencies = [
-    "colpali-engine>=0.3.4,<0.4.0",
+    "colpali-engine>=0.3.5,<0.4.0",
     "ml-dtypes",
     "mteb==1.6.35",
     "ninja",
diff --git a/tests/test_colsmolvlm.py b/tests/test_colsmolvlm.py
new file mode 100644
index 0000000..0627353
--- /dev/null
+++ b/tests/test_colsmolvlm.py
@@ -0,0 +1,23 @@
+from typing import Generator
+
+import pytest
+from colpali_engine.models import ColIdefics3
+from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch
+
+from byaldi import RAGMultiModalModel
+from byaldi.colpali import ColPaliModel
+
+
+@pytest.fixture(scope="module")
+def colsmolvlm_rag_model() -> Generator[RAGMultiModalModel, None, None]:
+    device = get_torch_device("auto")
+    print(f"Using device: {device}")
+    yield RAGMultiModalModel.from_pretrained("vidore/colsmolvlm-alpha", device=device)
+    tear_down_torch()
+
+
+@pytest.mark.slow
+def test_load_colsmolvlm_from_pretrained(colsmolvlm_rag_model: RAGMultiModalModel):
+    assert isinstance(colsmolvlm_rag_model, RAGMultiModalModel)
+    assert isinstance(colsmolvlm_rag_model.model, ColPaliModel)
+    assert isinstance(colsmolvlm_rag_model.model.model, ColIdefics3)