diff --git a/libs/langchain_v1/langchain/chat_models/base.py b/libs/langchain_v1/langchain/chat_models/base.py index c731c15f63252..0908f7936a23e 100644 --- a/libs/langchain_v1/langchain/chat_models/base.py +++ b/libs/langchain_v1/langchain/chat_models/base.py @@ -403,9 +403,16 @@ def _init_chat_model_helper( return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore] if model_provider == "huggingface": _check_pkg("langchain_huggingface") - from langchain_huggingface import ChatHuggingFace + from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint - return ChatHuggingFace(model_id=model, **kwargs) + # Build a HuggingFaceEndpoint from the model id and wrap it in ChatHuggingFace. + # ChatHuggingFace expects an underlying HF LLM in `llm`, not a raw model_id. + llm = HuggingFaceEndpoint( + repo_id=model, + task="text-generation", + **kwargs, + ) + return ChatHuggingFace(llm=llm) if model_provider == "groq": _check_pkg("langchain_groq") from langchain_groq import ChatGroq diff --git a/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py b/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py index c4a95c73d44fa..5ecc34de9fec6 100644 --- a/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py +++ b/libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest import mock import pytest @@ -57,7 +57,7 @@ def test_init_missing_dep() -> None: def test_init_unknown_provider() -> None: - with pytest.raises(ValueError, match="Unsupported model_provider='bar'."): + with pytest.raises(ValueError, match=r"Unsupported model_provider='bar'\."): init_chat_model("foo", model_provider="bar") @@ -285,3 +285,31 @@ def test_configurable_with_default() -> None: prompt = ChatPromptTemplate.from_messages([("system", "foo")]) chain = prompt | model_with_config assert isinstance(chain, RunnableSequence) + + +def test_init_chat_model_huggingface(monkeypatch: Any) -> None: + lhf = pytest.importorskip("langchain_huggingface") + + chat_huggingface_cls = lhf.ChatHuggingFace + huggingface_endpoint_cls = lhf.HuggingFaceEndpoint + + created: dict[str, Any] = {} + + class DummyHFEndpoint(huggingface_endpoint_cls): # type: ignore[misc, valid-type] + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(**kwargs) + created["args"] = args + created["kwargs"] = kwargs + + monkeypatch.setattr(lhf, "HuggingFaceEndpoint", DummyHFEndpoint) + + model = init_chat_model( + model="microsoft/Phi-3-mini-4k-instruct", + model_provider="huggingface", + temperature=0, + ) + + assert isinstance(model, chat_huggingface_cls) + assert created["kwargs"]["repo_id"] == "microsoft/Phi-3-mini-4k-instruct" + assert created["kwargs"]["task"] == "text-generation" + assert created["kwargs"]["temperature"] == 0 diff --git a/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py b/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py index 30bfaeb67773f..85c16aaab905c 100644 --- a/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py +++ b/libs/langchain_v1/tests/unit_tests/embeddings/test_base.py @@ -88,7 +88,10 @@ def test_infer_model_and_provider_errors() -> None: _infer_model_and_provider("model", provider="") # Test invalid provider - with pytest.raises(ValueError, match="Provider 'invalid' is not supported.") as exc: + with pytest.raises( + ValueError, + match=r"Provider 'invalid' is not supported\.", + ) as exc: _infer_model_and_provider("model", provider="invalid") # Test provider list is in error for provider in _SUPPORTED_PROVIDERS: