@@ -14,11 +14,18 @@ def test_init(self):
1414 # Check that it's a valid instance
1515 assert isinstance (tools , RetrieverTools )
1616
17+ @patch ("src.agents.retriever_tools.HuggingFaceCrossEncoder" )
18+ @patch ("src.agents.retriever_tools.RetrieverTools._create_embedding_model" )
1719 @patch ("src.agents.retriever_tools.HybridRetrieverChain" )
18- def test_initialize_success (self , mock_hybrid_chain ):
20+ def test_initialize_success (
21+ self , mock_hybrid_chain , mock_create_embed , mock_cross_encoder
22+ ):
1923 """Test successful initialization of all retrievers."""
2024 tools = RetrieverTools ()
2125
26+ mock_create_embed .return_value = Mock ()
27+ mock_cross_encoder .return_value = Mock ()
28+
2229 # Mock the HybridRetrieverChain instances
2330 mock_chains = []
2431 for i in range (
@@ -55,11 +62,18 @@ def test_initialize_success(self, mock_hybrid_chain):
5562 assert RetrieverTools .klayout_retriever == mock_chains [4 ].retriever
5663 assert RetrieverTools .errinfo_retriever == mock_chains [5 ].retriever
5764
65+ @patch ("src.agents.retriever_tools.HuggingFaceCrossEncoder" )
66+ @patch ("src.agents.retriever_tools.RetrieverTools._create_embedding_model" )
5867 @patch ("src.agents.retriever_tools.HybridRetrieverChain" )
59- def test_initialize_with_fast_mode (self , mock_hybrid_chain ):
68+ def test_initialize_with_fast_mode (
69+ self , mock_hybrid_chain , mock_create_embed , mock_cross_encoder
70+ ):
6071 """Test initialization with fast mode enabled."""
6172 tools = RetrieverTools ()
6273
74+ mock_create_embed .return_value = Mock ()
75+ mock_cross_encoder .return_value = Mock ()
76+
6377 # Mock the HybridRetrieverChain instances
6478 mock_chains = []
6579 for i in range (6 ):
@@ -250,11 +264,18 @@ def test_retrieve_klayout_docs_not_initialized(self):
250264 with pytest .raises (ValueError , match = "KLayout Retriever not initialized" ):
251265 RetrieverTools .retrieve_klayout_docs .invoke (input = "test query" )
252266
267+ @patch ("src.agents.retriever_tools.HuggingFaceCrossEncoder" )
268+ @patch ("src.agents.retriever_tools.RetrieverTools._create_embedding_model" )
253269 @patch ("src.agents.retriever_tools.HybridRetrieverChain" )
254- def test_initialize_verifies_configuration_parameters (self , mock_hybrid_chain ):
270+ def test_initialize_verifies_configuration_parameters (
271+ self , mock_hybrid_chain , mock_create_embed , mock_cross_encoder
272+ ):
255273 """Test that initialize passes correct configuration parameters."""
256274 tools = RetrieverTools ()
257275
276+ mock_create_embed .return_value = Mock ()
277+ mock_cross_encoder .return_value = Mock ()
278+
258279 # Mock the HybridRetrieverChain instances
259280 mock_chains = []
260281 for i in range (6 ):
@@ -283,11 +304,18 @@ def test_initialize_verifies_configuration_parameters(self, mock_hybrid_chain):
283304 assert kwargs ["weights" ] == [0.6 , 0.2 , 0.2 ]
284305 assert kwargs ["contextual_rerank" ] is True
285306
307+ @patch ("src.agents.retriever_tools.HuggingFaceCrossEncoder" )
308+ @patch ("src.agents.retriever_tools.RetrieverTools._create_embedding_model" )
286309 @patch ("src.agents.retriever_tools.HybridRetrieverChain" )
287- def test_initialize_with_environment_variables (self , mock_hybrid_chain ):
310+ def test_initialize_with_environment_variables (
311+ self , mock_hybrid_chain , mock_create_embed , mock_cross_encoder
312+ ):
288313 """Test initialization respects environment variables."""
289314 tools = RetrieverTools ()
290315
316+ mock_create_embed .return_value = Mock ()
317+ mock_cross_encoder .return_value = Mock ()
318+
291319 # Mock the HybridRetrieverChain instances
292320 mock_chains = []
293321 for i in range (6 ):
@@ -323,11 +351,18 @@ def test_tool_decorators_applied(self):
323351 assert hasattr (RetrieverTools .retrieve_yosys_rtdocs , "name" )
324352 assert hasattr (RetrieverTools .retrieve_klayout_docs , "name" )
325353
354+ @patch ("src.agents.retriever_tools.HuggingFaceCrossEncoder" )
355+ @patch ("src.agents.retriever_tools.RetrieverTools._create_embedding_model" )
326356 @patch ("src.agents.retriever_tools.HybridRetrieverChain" )
327- def test_different_docs_paths_for_retrievers (self , mock_hybrid_chain ):
357+ def test_different_docs_paths_for_retrievers (
358+ self , mock_hybrid_chain , mock_create_embed , mock_cross_encoder
359+ ):
328360 """Test that different retrievers use different document paths."""
329361 tools = RetrieverTools ()
330362
363+ mock_create_embed .return_value = Mock ()
364+ mock_cross_encoder .return_value = Mock ()
365+
331366 # Mock the HybridRetrieverChain instances
332367 mock_chains = []
333368 for i in range (6 ):
@@ -369,11 +404,18 @@ def test_different_docs_paths_for_retrievers(self, mock_hybrid_chain):
369404 # Errinfo should have error-specific paths
370405 assert any ("man3" in path for path in errinfo_paths )
371406
407+ @patch ("src.agents.retriever_tools.HuggingFaceCrossEncoder" )
408+ @patch ("src.agents.retriever_tools.RetrieverTools._create_embedding_model" )
372409 @patch ("src.agents.retriever_tools.HybridRetrieverChain" )
373- def test_html_docs_configuration (self , mock_hybrid_chain ):
410+ def test_html_docs_configuration (
411+ self , mock_hybrid_chain , mock_create_embed , mock_cross_encoder
412+ ):
374413 """Test HTML docs configuration for specific retrievers."""
375414 tools = RetrieverTools ()
376415
416+ mock_create_embed .return_value = Mock ()
417+ mock_cross_encoder .return_value = Mock ()
418+
377419 # Mock the HybridRetrieverChain instances
378420 mock_chains = []
379421 for i in range (6 ):
@@ -426,11 +468,18 @@ def test_staticmethod_decorators(self):
426468 result = RetrieverTools .retrieve_general .invoke (input = "test" )
427469 assert result == ("" , [], [], [])
428470
471+ @patch ("src.agents.retriever_tools.HuggingFaceCrossEncoder" )
472+ @patch ("src.agents.retriever_tools.RetrieverTools._create_embedding_model" )
429473 @patch ("src.agents.retriever_tools.HybridRetrieverChain" )
430- def test_retriever_chain_create_hybrid_retriever_called (self , mock_hybrid_chain ):
474+ def test_retriever_chain_create_hybrid_retriever_called (
475+ self , mock_hybrid_chain , mock_create_embed , mock_cross_encoder
476+ ):
431477 """Test that create_hybrid_retriever is called on all chains."""
432478 tools = RetrieverTools ()
433479
480+ mock_create_embed .return_value = Mock ()
481+ mock_cross_encoder .return_value = Mock ()
482+
434483 # Mock the HybridRetrieverChain instances
435484 mock_chains = []
436485 for i in range (6 ):
0 commit comments