@@ -773,6 +773,203 @@ def test_topological_sort(g):
773773 assert tpsort_test .tests_in_file .test_file .resolve () == test_file_path .resolve ()
774774 assert tpsort_test .tests_in_file .test_function == "test_topological_sort"
775775
776+ def test_unittest_discovery_with_pytest_class_fixture ():
777+ with tempfile .TemporaryDirectory () as tmpdirname :
778+ path_obj_tmpdirname = Path (tmpdirname )
779+
780+ # Create a simple code file
781+ code_file_path = path_obj_tmpdirname / "router_file.py"
782+ code_file_content = """
783+ from __future__ import annotations
784+
785+ import hashlib
786+ import json
787+ from typing import Any, Dict, List, Optional, Required, TypedDict, Union # noqa: UP035
788+
789+
790+ class LiteLLMParamsTypedDict(TypedDict, total=False):
791+ model: str
792+ custom_llm_provider: Optional[str]
793+ tpm: Optional[int]
794+ rpm: Optional[int]
795+ order: Optional[int]
796+ weight: Optional[int]
797+ max_parallel_requests: Optional[int]
798+ api_key: Optional[str]
799+ api_base: Optional[str]
800+ api_version: Optional[str]
801+ stream_timeout: Optional[Union[float, str]]
802+ max_retries: Optional[int]
803+ organization: Optional[Union[list, str]] # for openai orgs
804+ ## DROP PARAMS ##
805+ drop_params: Optional[bool]
806+ ## UNIFIED PROJECT/REGION ##
807+ region_name: Optional[str]
808+ ## VERTEX AI ##
809+ vertex_project: Optional[str]
810+ vertex_location: Optional[str]
811+ ## AWS BEDROCK / SAGEMAKER ##
812+ aws_access_key_id: Optional[str]
813+ aws_secret_access_key: Optional[str]
814+ aws_region_name: Optional[str]
815+ ## IBM WATSONX ##
816+ watsonx_region_name: Optional[str]
817+ ## CUSTOM PRICING ##
818+ input_cost_per_token: Optional[float]
819+ output_cost_per_token: Optional[float]
820+ input_cost_per_second: Optional[float]
821+ output_cost_per_second: Optional[float]
822+ num_retries: Optional[int]
823+ ## MOCK RESPONSES ##
824+
825+ # routing params
826+ # use this for tag-based routing
827+ tags: Optional[list[str]]
828+
829+ # deployment budgets
830+ max_budget: Optional[float]
831+ budget_duration: Optional[str]
832+
833+ class DeploymentTypedDict(TypedDict, total=False):
834+ model_name: Required[str]
835+ litellm_params: Required[LiteLLMParamsTypedDict]
836+ model_info: dict
837+
838+ class Router:
839+ model_names: set = set() # noqa: RUF012
840+ cache_responses: Optional[bool] = False
841+ default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
842+ tenacity = None
843+
844+ def __init__( # noqa: PLR0915
845+ self,
846+ model_list: Optional[
847+ Union[list[DeploymentTypedDict], list[dict[str, Any]]]
848+ ] = None,
849+ ) -> None:
850+ self.model_list = model_list # noqa: ARG002
851+ self.model_id_to_deployment_index_map: dict[str, int] = {}
852+ self.model_name_to_deployment_indices: dict[str, list[int]] = {}
853+ def _generate_model_id(self, model_group: str, litellm_params: dict): # noqa: ANN202
854+ # Optimized: Use list and join instead of string concatenation in loop
855+ # This avoids creating many temporary string objects (O(n) vs O(n²) complexity)
856+ parts = [model_group]
857+ for k, v in litellm_params.items():
858+ if isinstance(k, str):
859+ parts.append(k)
860+ elif isinstance(k, dict):
861+ parts.append(json.dumps(k))
862+ else:
863+ parts.append(str(k))
864+
865+ if isinstance(v, str):
866+ parts.append(v)
867+ elif isinstance(v, dict):
868+ parts.append(json.dumps(v))
869+ else:
870+ parts.append(str(v))
871+
872+ concat_str = "".join(parts)
873+ hash_object = hashlib.sha256(concat_str.encode())
874+
875+ return hash_object.hexdigest()
876+ def _add_model_to_list_and_index_map(
877+ self, model: dict, model_id: Optional[str] = None
878+ ) -> None:
879+ idx = len(self.model_list)
880+ self.model_list.append(model)
881+
882+ # Update model_id index for O(1) lookup
883+ if model_id is not None:
884+ self.model_id_to_deployment_index_map[model_id] = idx
885+ elif model.get("model_info", {}).get("id") is not None:
886+ self.model_id_to_deployment_index_map[model["model_info"]["id"]] = idx
887+
888+ # Update model_name index for O(1) lookup
889+ model_name = model.get("model_name")
890+ if model_name:
891+ if model_name not in self.model_name_to_deployment_indices:
892+ self.model_name_to_deployment_indices[model_name] = []
893+ self.model_name_to_deployment_indices[model_name].append(idx)
894+
895+ def _build_model_id_to_deployment_index_map(self, model_list: list) -> None:
896+ # First populate the model_list
897+ self.model_list = []
898+ for _, model in enumerate(model_list):
899+ # Extract model_info from the model dict
900+ model_info = model.get("model_info", {})
901+ model_id = model_info.get("id")
902+
903+ # If no ID exists, generate one using the same logic as set_model_list
904+ if model_id is None:
905+ model_name = model.get("model_name", "")
906+ litellm_params = model.get("litellm_params", {})
907+ model_id = self._generate_model_id(model_name, litellm_params)
908+ # Update the model_info in the original list
909+ if "model_info" not in model:
910+ model["model_info"] = {}
911+ model["model_info"]["id"] = model_id
912+
913+ self._add_model_to_list_and_index_map(model=model, model_id=model_id)
914+ """
915+ code_file_path .write_text (code_file_content )
916+
917+ # Create a unittest test file with parameterized tests
918+ test_file_path = path_obj_tmpdirname / "test_router_file.py"
919+ test_file_content = """
920+ import pytest
921+
922+ from router_file import Router
923+
924+
925+ class TestRouterIndexManagement:
926+ @pytest.fixture
927+ def router(self): # noqa: ANN201
928+ return Router(model_list=[])
929+ def test_build_model_id_to_deployment_index_map(self, router) -> None: # noqa: ANN001
930+ model_list = [
931+ {
932+ "model_name": "gpt-3.5-turbo",
933+ "litellm_params": {"model": "gpt-3.5-turbo"},
934+ "model_info": {"id": "model-1"},
935+ },
936+ {
937+ "model_name": "gpt-4",
938+ "litellm_params": {"model": "gpt-4"},
939+ "model_info": {"id": "model-2"},
940+ },
941+ ]
942+
943+ # Test: Build index from model list
944+ router._build_model_id_to_deployment_index_map(model_list) # noqa: SLF001
945+
946+ # Verify: model_list is populated
947+ assert len(router.model_list) == 2
948+ # Verify: model_id_to_deployment_index_map is correctly built
949+ assert router.model_id_to_deployment_index_map["model-1"] == 0
950+ assert router.model_id_to_deployment_index_map["model-2"] == 1
951+ """
952+ test_file_path .write_text (test_file_content )
953+
954+ # Configure test discovery
955+ test_config = TestConfig (
956+ tests_root = path_obj_tmpdirname ,
957+ project_root_path = path_obj_tmpdirname ,
958+ test_framework = "pytest" , # Using pytest framework to discover unittest tests
959+ tests_project_rootdir = path_obj_tmpdirname .parent ,
960+ )
961+ fto = FunctionToOptimize (function_name = "_build_model_id_to_deployment_index_map" , file_path = code_file_path , parents = [FunctionParent (name = "Router" , type = "ClassDef" )])
962+ # Discover tests
963+ discovered_tests , _ , _ = discover_unit_tests (test_config , file_to_funcs_to_optimize = {code_file_path : [fto ]})
964+
965+ # Verify the unittest was discovered
966+ assert len (discovered_tests ) == 1
967+ assert "router_file.Router._build_model_id_to_deployment_index_map" in discovered_tests
968+ assert len (discovered_tests ["router_file.Router._build_model_id_to_deployment_index_map" ]) == 1
969+ router_test = next (iter (discovered_tests ["router_file.Router._build_model_id_to_deployment_index_map" ]))
970+ assert router_test .tests_in_file .test_file .resolve () == test_file_path .resolve ()
971+ assert router_test .tests_in_file .test_function == "test_build_model_id_to_deployment_index_map"
972+
776973
777974def test_unittest_discovery_with_pytest_parameterized ():
778975 with tempfile .TemporaryDirectory () as tmpdirname :
@@ -1422,6 +1619,49 @@ def test_topological_sort(g):
14221619
14231620 assert should_process is True
14241621
1622+ def test_analyze_imports_class_fixture ():
1623+ with tempfile .TemporaryDirectory () as tmpdirname :
1624+ test_file = Path (tmpdirname ) / "test_example.py"
1625+ test_content = """
1626+ import pytest
1627+
1628+ from router_file import Router
1629+
1630+
1631+ class TestRouterIndexManagement:
1632+ @pytest.fixture
1633+ def router(self): # noqa: ANN201
1634+ return Router(model_list=[])
1635+ def test_build_model_id_to_deployment_index_map(self, router) -> None: # noqa: ANN001
1636+ model_list = [
1637+ {
1638+ "model_name": "gpt-3.5-turbo",
1639+ "litellm_params": {"model": "gpt-3.5-turbo"},
1640+ "model_info": {"id": "model-1"},
1641+ },
1642+ {
1643+ "model_name": "gpt-4",
1644+ "litellm_params": {"model": "gpt-4"},
1645+ "model_info": {"id": "model-2"},
1646+ },
1647+ ]
1648+
1649+ # Test: Build index from model list
1650+ router._build_model_id_to_deployment_index_map(model_list) # noqa: SLF001
1651+
1652+ # Verify: model_list is populated
1653+ assert len(router.model_list) == 2
1654+ # Verify: model_id_to_deployment_index_map is correctly built
1655+ assert router.model_id_to_deployment_index_map["model-1"] == 0
1656+ assert router.model_id_to_deployment_index_map["model-2"] == 1
1657+ """
1658+ test_file .write_text (test_content )
1659+
1660+ target_functions = {"Router._build_model_id_to_deployment_index_map" }
1661+ should_process = analyze_imports_in_test_file (test_file , target_functions )
1662+
1663+ assert should_process is True
1664+
14251665def test_analyze_imports_aliased_class_method_negative ():
14261666 with tempfile .TemporaryDirectory () as tmpdirname :
14271667 test_file = Path (tmpdirname ) / "test_example.py"
0 commit comments