From 1cbaaa5e11b68738540c196feb97ff36bd4636b4 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 3 Dec 2024 16:52:33 +0100 Subject: [PATCH 1/4] Added support for ColSmolVLM --- byaldi/colpali.py | 25 +++++++++++++++++++++++-- pyproject.toml | 2 +- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..e78e6b4 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -7,7 +7,7 @@ import srsly import torch -from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor +from colpali_engine.models import ColPali, ColPaliProcessor, ColQwen2, ColQwen2Processor, ColIdefics3, ColIdefics3Processor from pdf2image import convert_from_path from PIL import Image @@ -35,9 +35,10 @@ def __init__( if ( "colpali" not in pretrained_model_name_or_path.lower() and "colqwen2" not in pretrained_model_name_or_path.lower() + and "colsmolvlm" not in pretrained_model_name_or_path.lower() ): raise ValueError( - "This pre-release version of Byaldi only supports ColPali and ColQwen2 for now. Incorrect model name specified." + "This pre-release version of Byaldi only supports ColPali, ColQwen2 and ColSmolVLM for now. Incorrect model name specified." ) if verbose > 0: @@ -89,6 +90,18 @@ def __init__( ), token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ) + elif "colsmolvlm" in pretrained_model_name_or_path.lower(): + self.model = ColIdefics3.from_pretrained( + self.pretrained_model_name_or_path, + torch_dtype=torch.bfloat16, + device_map=( + "cuda" + if device == "cuda" + or (isinstance(device, torch.device) and device.type == "cuda") + else None + ), + token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + ) self.model = self.model.eval() if "colpali" in pretrained_model_name_or_path.lower(): @@ -107,6 +120,14 @@ def __init__( token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), ), ) + elif "colsmolvlm" in pretrained_model_name_or_path.lower(): + self.processor = cast( + ColIdefics3Processor, + ColIdefics3Processor.from_pretrained( + self.pretrained_model_name_or_path, + token=kwargs.get("hf_token", None) or os.environ.get("HF_TOKEN"), + ), + ) self.device = device if device != "cuda" and not ( diff --git a/pyproject.toml b/pyproject.toml index bf1d624..cfb3b5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ maintainers = [ ] dependencies = [ - "colpali-engine>=0.3.4,<0.4.0", + "colpali-engine>=0.3.5,<0.4.0", "ml-dtypes", "mteb==1.6.35", "ninja", From 027cb49e63d6af009e24c70fabd1e2fc647def6a Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Tue, 3 Dec 2024 16:58:18 +0100 Subject: [PATCH 2/4] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cfb3b5a..bf1d624 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ maintainers = [ ] dependencies = [ - "colpali-engine>=0.3.5,<0.4.0", + "colpali-engine>=0.3.4,<0.4.0", "ml-dtypes", "mteb==1.6.35", "ninja", From ae053fcbf50e3ae79c510f2f7f5f2b766355c03a Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Wed, 4 Dec 2024 17:01:20 +0100 Subject: [PATCH 3/4] Added new test --- tests/test_colsmolvlm.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 tests/test_colsmolvlm.py diff --git a/tests/test_colsmolvlm.py b/tests/test_colsmolvlm.py new file mode 100644 index 0000000..0627353 --- /dev/null +++ b/tests/test_colsmolvlm.py @@ -0,0 +1,23 @@ +from typing import Generator + +import pytest +from colpali_engine.models import ColIdefics3 +from colpali_engine.utils.torch_utils import get_torch_device, tear_down_torch + +from byaldi import RAGMultiModalModel +from byaldi.colpali import ColPaliModel + + +@pytest.fixture(scope="module") +def colsmolvlm_rag_model() -> Generator[RAGMultiModalModel, None, None]: + device = get_torch_device("auto") + print(f"Using device: {device}") + yield RAGMultiModalModel.from_pretrained("vidore/colsmolvlm-alpha", device=device) + tear_down_torch() + + +@pytest.mark.slow +def test_load_colsmolvlm_from_pretrained(colsmolvlm_rag_model: RAGMultiModalModel): + assert isinstance(colsmolvlm_rag_model, RAGMultiModalModel) + assert isinstance(colsmolvlm_rag_model.model, ColPaliModel) + assert isinstance(colsmolvlm_rag_model.model.model, ColIdefics3) From 9ad2a372cdd33ef73e6165046064af66444c3078 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Fri, 13 Dec 2024 16:23:48 +0100 Subject: [PATCH 4/4] Updated colpali version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bf1d624..cfb3b5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ maintainers = [ ] dependencies = [ - "colpali-engine>=0.3.4,<0.4.0", + "colpali-engine>=0.3.5,<0.4.0", "ml-dtypes", "mteb==1.6.35", "ninja",