11 changes: 9 additions & 2 deletions libs/langchain_v1/langchain/chat_models/base.py
@@ -403,9 +403,16 @@ def _init_chat_model_helper(
         return ChatMistralAI(model=model, **kwargs)  # type: ignore[call-arg,unused-ignore]
     if model_provider == "huggingface":
         _check_pkg("langchain_huggingface")
-        from langchain_huggingface import ChatHuggingFace
+        from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 
-        return ChatHuggingFace(model_id=model, **kwargs)
+        # Build a HuggingFaceEndpoint from the model id and wrap it in ChatHuggingFace.
+        # ChatHuggingFace expects an underlying HF LLM in `llm`, not a raw model_id.
+        llm = HuggingFaceEndpoint(
+            repo_id=model,
+            task="text-generation",
+            **kwargs,
+        )
+        return ChatHuggingFace(llm=llm)
     if model_provider == "groq":
         _check_pkg("langchain_groq")
         from langchain_groq import ChatGroq
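For reference, a minimal usage sketch of the updated huggingface branch (assuming init_chat_model is imported from langchain.chat_models, langchain_huggingface is installed, and Hugging Face credentials are configured; the model id is only an example taken from the new test below):

    from langchain.chat_models import init_chat_model

    # With this change, init_chat_model builds HuggingFaceEndpoint(repo_id=model,
    # task="text-generation", **kwargs) and wraps it in ChatHuggingFace(llm=...).
    chat = init_chat_model(
        model="microsoft/Phi-3-mini-4k-instruct",
        model_provider="huggingface",
        temperature=0,
    )
    # chat is a ChatHuggingFace instance backed by the endpoint built above.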
32 changes: 30 additions & 2 deletions libs/langchain_v1/tests/unit_tests/chat_models/test_chat_models.py
@@ -1,5 +1,5 @@
 import os
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from unittest import mock
 
 import pytest
@@ -57,7 +57,7 @@ def test_init_missing_dep() -> None:


 def test_init_unknown_provider() -> None:
-    with pytest.raises(ValueError, match="Unsupported model_provider='bar'."):
+    with pytest.raises(ValueError, match=r"Unsupported model_provider='bar'\."):
         init_chat_model("foo", model_provider="bar")


@@ -285,3 +285,31 @@ def test_configurable_with_default() -> None:
     prompt = ChatPromptTemplate.from_messages([("system", "foo")])
     chain = prompt | model_with_config
     assert isinstance(chain, RunnableSequence)
+
+
+def test_init_chat_model_huggingface(monkeypatch: Any) -> None:
+    lhf = pytest.importorskip("langchain_huggingface")
+
+    chat_huggingface_cls = lhf.ChatHuggingFace
+    huggingface_endpoint_cls = lhf.HuggingFaceEndpoint
+
+    created: dict[str, Any] = {}
+
+    class DummyHFEndpoint(huggingface_endpoint_cls):  # type: ignore[misc, valid-type]
+        def __init__(self, *args: Any, **kwargs: Any) -> None:
+            super().__init__(**kwargs)
+            created["args"] = args
+            created["kwargs"] = kwargs
+
+    monkeypatch.setattr(lhf, "HuggingFaceEndpoint", DummyHFEndpoint)
+
+    model = init_chat_model(
+        model="microsoft/Phi-3-mini-4k-instruct",
+        model_provider="huggingface",
+        temperature=0,
+    )
+
+    assert isinstance(model, chat_huggingface_cls)
+    assert created["kwargs"]["repo_id"] == "microsoft/Phi-3-mini-4k-instruct"
+    assert created["kwargs"]["task"] == "text-generation"
+    assert created["kwargs"]["temperature"] == 0
5 changes: 4 additions & 1 deletion libs/langchain_v1/tests/unit_tests/embeddings/test_base.py
@@ -88,7 +88,10 @@ def test_infer_model_and_provider_errors() -> None:
         _infer_model_and_provider("model", provider="")
 
     # Test invalid provider
-    with pytest.raises(ValueError, match="Provider 'invalid' is not supported.") as exc:
+    with pytest.raises(
+        ValueError,
+        match=r"Provider 'invalid' is not supported\.",
+    ) as exc:
         _infer_model_and_provider("model", provider="invalid")
     # Test provider list is in error
     for provider in _SUPPORTED_PROVIDERS:
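The match= edits in both test files address the same detail: pytest.raises(match=...) applies the pattern with re.search, so an unescaped '.' matches any character rather than a literal period. A minimal sketch of the idea outside the suite (illustrative only, not part of the diff):

    import re

    import pytest

    def test_literal_period_in_error_message() -> None:
        # Escape regex metacharacters (here via re.escape) so the
        # trailing '.' in the message is matched literally.
        msg = "Provider 'invalid' is not supported."
        with pytest.raises(ValueError, match=re.escape(msg)):
            raise ValueError(msg)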