Skip to content

Commit a6f7d5a

Browse files
author
Codeflash Bot
committed
more tests
1 parent 86a8d20 commit a6f7d5a

File tree

1 file changed

+240
-0
lines changed

1 file changed

+240
-0
lines changed

tests/test_unit_test_discovery.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,203 @@ def test_topological_sort(g):
773773
assert tpsort_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
774774
assert tpsort_test.tests_in_file.test_function == "test_topological_sort"
775775

776+
def test_unittest_discovery_with_pytest_class_fixture():
    """Discovery should find a pytest-style test that lives inside a test class
    and receives the object under test through a pytest fixture.

    Writes a source module (``router_file.py``) and a test module into a
    temporary directory, runs ``discover_unit_tests`` for the target method
    ``Router._build_model_id_to_deployment_index_map``, and asserts the
    fixture-based class test is attributed to that method.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        path_obj_tmpdirname = Path(tmpdirname)

        # Create a simple code file
        code_file_path = path_obj_tmpdirname / "router_file.py"
        code_file_content = """
from __future__ import annotations

import hashlib
import json
from typing import Any, Dict, List, Optional, Required, TypedDict, Union  # noqa: UP035


class LiteLLMParamsTypedDict(TypedDict, total=False):
    model: str
    custom_llm_provider: Optional[str]
    tpm: Optional[int]
    rpm: Optional[int]
    order: Optional[int]
    weight: Optional[int]
    max_parallel_requests: Optional[int]
    api_key: Optional[str]
    api_base: Optional[str]
    api_version: Optional[str]
    stream_timeout: Optional[Union[float, str]]
    max_retries: Optional[int]
    organization: Optional[Union[list, str]]  # for openai orgs
    ## DROP PARAMS ##
    drop_params: Optional[bool]
    ## UNIFIED PROJECT/REGION ##
    region_name: Optional[str]
    ## VERTEX AI ##
    vertex_project: Optional[str]
    vertex_location: Optional[str]
    ## AWS BEDROCK / SAGEMAKER ##
    aws_access_key_id: Optional[str]
    aws_secret_access_key: Optional[str]
    aws_region_name: Optional[str]
    ## IBM WATSONX ##
    watsonx_region_name: Optional[str]
    ## CUSTOM PRICING ##
    input_cost_per_token: Optional[float]
    output_cost_per_token: Optional[float]
    input_cost_per_second: Optional[float]
    output_cost_per_second: Optional[float]
    num_retries: Optional[int]
    ## MOCK RESPONSES ##

    # routing params
    # use this for tag-based routing
    tags: Optional[list[str]]

    # deployment budgets
    max_budget: Optional[float]
    budget_duration: Optional[str]


class DeploymentTypedDict(TypedDict, total=False):
    model_name: Required[str]
    litellm_params: Required[LiteLLMParamsTypedDict]
    model_info: dict


class Router:
    model_names: set = set()  # noqa: RUF012
    cache_responses: Optional[bool] = False
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour
    tenacity = None

    def __init__(  # noqa: PLR0915
        self,
        model_list: Optional[
            Union[list[DeploymentTypedDict], list[dict[str, Any]]]
        ] = None,
    ) -> None:
        self.model_list = model_list  # noqa: ARG002
        self.model_id_to_deployment_index_map: dict[str, int] = {}
        self.model_name_to_deployment_indices: dict[str, list[int]] = {}

    def _generate_model_id(self, model_group: str, litellm_params: dict):  # noqa: ANN202
        # Optimized: Use list and join instead of string concatenation in loop
        # This avoids creating many temporary string objects (O(n) vs O(n^2) complexity)
        parts = [model_group]
        for k, v in litellm_params.items():
            if isinstance(k, str):
                parts.append(k)
            elif isinstance(k, dict):
                parts.append(json.dumps(k))
            else:
                parts.append(str(k))

            if isinstance(v, str):
                parts.append(v)
            elif isinstance(v, dict):
                parts.append(json.dumps(v))
            else:
                parts.append(str(v))

        concat_str = "".join(parts)
        hash_object = hashlib.sha256(concat_str.encode())

        return hash_object.hexdigest()

    def _add_model_to_list_and_index_map(
        self, model: dict, model_id: Optional[str] = None
    ) -> None:
        idx = len(self.model_list)
        self.model_list.append(model)

        # Update model_id index for O(1) lookup
        if model_id is not None:
            self.model_id_to_deployment_index_map[model_id] = idx
        elif model.get("model_info", {}).get("id") is not None:
            self.model_id_to_deployment_index_map[model["model_info"]["id"]] = idx

        # Update model_name index for O(1) lookup
        model_name = model.get("model_name")
        if model_name:
            if model_name not in self.model_name_to_deployment_indices:
                self.model_name_to_deployment_indices[model_name] = []
            self.model_name_to_deployment_indices[model_name].append(idx)

    def _build_model_id_to_deployment_index_map(self, model_list: list) -> None:
        # First populate the model_list
        self.model_list = []
        for _, model in enumerate(model_list):
            # Extract model_info from the model dict
            model_info = model.get("model_info", {})
            model_id = model_info.get("id")

            # If no ID exists, generate one using the same logic as set_model_list
            if model_id is None:
                model_name = model.get("model_name", "")
                litellm_params = model.get("litellm_params", {})
                model_id = self._generate_model_id(model_name, litellm_params)
                # Update the model_info in the original list
                if "model_info" not in model:
                    model["model_info"] = {}
                model["model_info"]["id"] = model_id

            self._add_model_to_list_and_index_map(model=model, model_id=model_id)
"""
        code_file_path.write_text(code_file_content)

        # Create a pytest test file whose test class uses a fixture
        # (NOT a unittest/parameterized file — the original comment was a
        # copy-paste leftover).
        test_file_path = path_obj_tmpdirname / "test_router_file.py"
        test_file_content = """
import pytest

from router_file import Router


class TestRouterIndexManagement:
    @pytest.fixture
    def router(self):  # noqa: ANN201
        return Router(model_list=[])

    def test_build_model_id_to_deployment_index_map(self, router) -> None:  # noqa: ANN001
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "model-1"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "model-2"},
            },
        ]

        # Test: Build index from model list
        router._build_model_id_to_deployment_index_map(model_list)  # noqa: SLF001

        # Verify: model_list is populated
        assert len(router.model_list) == 2
        # Verify: model_id_to_deployment_index_map is correctly built
        assert router.model_id_to_deployment_index_map["model-1"] == 0
        assert router.model_id_to_deployment_index_map["model-2"] == 1
"""
        test_file_path.write_text(test_file_content)

        # Configure test discovery
        test_config = TestConfig(
            tests_root=path_obj_tmpdirname,
            project_root_path=path_obj_tmpdirname,
            test_framework="pytest",  # pytest discovers fixture-based class tests
            tests_project_rootdir=path_obj_tmpdirname.parent,
        )
        fto = FunctionToOptimize(
            function_name="_build_model_id_to_deployment_index_map",
            file_path=code_file_path,
            parents=[FunctionParent(name="Router", type="ClassDef")],
        )
        # Discover tests
        discovered_tests, _, _ = discover_unit_tests(
            test_config, file_to_funcs_to_optimize={code_file_path: [fto]}
        )

        # Verify the fixture-based class test was discovered and attributed
        # to the qualified target method.
        qualified_name = "router_file.Router._build_model_id_to_deployment_index_map"
        assert len(discovered_tests) == 1
        assert qualified_name in discovered_tests
        assert len(discovered_tests[qualified_name]) == 1
        router_test = next(iter(discovered_tests[qualified_name]))
        assert router_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
        assert router_test.tests_in_file.test_function == "test_build_model_id_to_deployment_index_map"
776973

777974
def test_unittest_discovery_with_pytest_parameterized():
778975
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1422,6 +1619,49 @@ def test_topological_sort(g):
14221619

14231620
assert should_process is True
14241621

1622+
def test_analyze_imports_class_fixture():
    """Import analysis should select a test module whose target-class import is
    used inside a test class that relies on a pytest fixture."""
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_path = Path(tmpdirname) / "test_example.py"
        source = """
import pytest

from router_file import Router


class TestRouterIndexManagement:
    @pytest.fixture
    def router(self):  # noqa: ANN201
        return Router(model_list=[])

    def test_build_model_id_to_deployment_index_map(self, router) -> None:  # noqa: ANN001
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "model-1"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "model-2"},
            },
        ]

        # Test: Build index from model list
        router._build_model_id_to_deployment_index_map(model_list)  # noqa: SLF001

        # Verify: model_list is populated
        assert len(router.model_list) == 2
        # Verify: model_id_to_deployment_index_map is correctly built
        assert router.model_id_to_deployment_index_map["model-1"] == 0
        assert router.model_id_to_deployment_index_map["model-2"] == 1
"""
        test_path.write_text(source)

        # The target is a method of the imported Router class, so the file
        # must be flagged for processing.
        targets = {"Router._build_model_id_to_deployment_index_map"}
        assert analyze_imports_in_test_file(test_path, targets) is True
1664+
14251665
def test_analyze_imports_aliased_class_method_negative():
14261666
with tempfile.TemporaryDirectory() as tmpdirname:
14271667
test_file = Path(tmpdirname) / "test_example.py"

Comments (0)