Skip to content

Commit 7d18130

Browse files
authored
Merge pull request #934 from codeflash-ai/cf-820
Pytest discovery with fixtures for class instantiation
2 parents a7b9e85 + 69ef7cc commit 7d18130

File tree

2 files changed

+328
-1
lines changed

2 files changed

+328
-1
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,58 @@ def process_test_files(
868868
continue
869869
try:
870870
if not definition or definition[0].type != "function":
871+
# Fallback: Try to match against functions_to_optimize when Jedi can't resolve
872+
# This handles cases where Jedi fails with pytest fixtures
873+
if functions_to_optimize and name.name:
874+
for func_to_opt in functions_to_optimize:
875+
# Check if this unresolved name matches a function we're looking for
876+
if func_to_opt.function_name == name.name:
877+
# Resolve the fully qualified name used as the key in function_to_test_map (this branch matches on the bare function name only; no import check is performed here)
878+
qualified_name_with_modules = func_to_opt.qualified_name_with_modules_from_root(
879+
project_root_path
880+
)
881+
882+
# Only add if this test actually tests the function we're optimizing
883+
for test_func in test_functions_by_name[scope]:
884+
if test_func.parameters is not None:
885+
if test_framework == "pytest":
886+
scope_test_function = (
887+
f"{test_func.function_name}[{test_func.parameters}]"
888+
)
889+
else: # unittest
890+
scope_test_function = (
891+
f"{test_func.function_name}_{test_func.parameters}"
892+
)
893+
else:
894+
scope_test_function = test_func.function_name
895+
896+
function_to_test_map[qualified_name_with_modules].add(
897+
FunctionCalledInTest(
898+
tests_in_file=TestsInFile(
899+
test_file=test_file,
900+
test_class=test_func.test_class,
901+
test_function=scope_test_function,
902+
test_type=test_func.test_type,
903+
),
904+
position=CodePosition(line_no=name.line, col_no=name.column),
905+
)
906+
)
907+
tests_cache.insert_test(
908+
file_path=str(test_file),
909+
file_hash=file_hash,
910+
qualified_name_with_modules_from_root=qualified_name_with_modules,
911+
function_name=scope,
912+
test_class=test_func.test_class or "",
913+
test_function=scope_test_function,
914+
test_type=test_func.test_type,
915+
line_number=name.line,
916+
col_number=name.column,
917+
)
918+
919+
if test_func.test_type == TestType.REPLAY_TEST:
920+
num_discovered_replay_tests += 1
921+
922+
num_discovered_tests += 1
871923
continue
872924
definition_obj = definition[0]
873925
definition_path = str(definition_obj.module_path)

tests/test_unit_test_discovery.py

Lines changed: 276 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
filter_test_files_by_imports,
99
)
1010
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
11-
from codeflash.models.models import TestsInFile, TestType
11+
from codeflash.models.models import TestsInFile, TestType, FunctionParent
1212
from codeflash.verification.verification_utils import TestConfig
1313

1414

@@ -714,6 +714,210 @@ def test_add_with_parameters(self):
714714
assert calculator_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
715715
assert calculator_test.tests_in_file.test_function == "test_add_with_parameters"
716716

717+
def test_unittest_discovery_with_pytest_fixture():
    """Discovery maps a test to a method reached only through a pytest fixture.

    The generated test module receives its ``Graph`` instance via the
    module-level fixture ``g``. The test asserts that ``discover_unit_tests``
    still associates ``test_topological_sort`` with
    ``topological_sort.Graph.topologicalSort``.
    """
    module_source = """
import uuid
from collections import defaultdict


class Graph:
    def __init__(self, vertices: int):
        self.vertices=vertices

    def dummy_fn(self):
        return 1

    def topologicalSort(self):
        return self.vertices

"""
    fixture_test_source = """
from topological_sort import Graph
import pytest

@pytest.fixture
def g():
    return Graph(6)

def test_topological_sort(g):
    assert g.dummy_fn() == 1
    assert g.topologicalSort() == 6
"""
    target_key = "topological_sort.Graph.topologicalSort"

    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_root = Path(tmpdirname)

        # Module that holds the function to optimize.
        code_file_path = tmp_root / "topological_sort.py"
        code_file_path.write_text(module_source)

        # pytest test module whose subject is supplied by the ``g`` fixture.
        test_file_path = tmp_root / "test_topological_sort.py"
        test_file_path.write_text(fixture_test_source)

        # Discovery runs with the pytest framework, rooted at the temp directory.
        test_config = TestConfig(
            tests_root=tmp_root,
            project_root_path=tmp_root,
            test_framework="pytest",
            tests_project_rootdir=tmp_root.parent,
        )
        graph_parent = FunctionParent(name="Graph", type="ClassDef")
        fto = FunctionToOptimize(
            function_name="topologicalSort",
            file_path=code_file_path,
            parents=[graph_parent],
        )

        discovered_tests, _, _ = discover_unit_tests(
            test_config, file_to_funcs_to_optimize={code_file_path: [fto]}
        )

        # Two entries are expected in the mapping; only the target's entry is inspected below.
        assert len(discovered_tests) == 2
        assert target_key in discovered_tests
        assert len(discovered_tests[target_key]) == 1
        tpsort_test = next(iter(discovered_tests[target_key]))
        assert tpsort_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
        assert tpsort_test.tests_in_file.test_function == "test_topological_sort"
775+
776+
def test_unittest_discovery_with_pytest_class_fixture():
    """Discovery maps a class-based test to a method reached through a fixture method.

    The generated test class defines ``router`` as a ``@pytest.fixture`` method
    and the test method receives the ``Router`` instance as an argument. The
    test asserts that ``discover_unit_tests`` associates
    ``test_build_model_id_to_deployment_index_map`` with
    ``router_file.Router._build_model_id_to_deployment_index_map``.
    """
    router_source = """
from __future__ import annotations

import hashlib
import json

class Router:
    model_names: list
    cache_responses = False
    tenacity = None

    def __init__( # noqa: PLR0915
        self,
        model_list = None,
    ) -> None:
        self.model_list = model_list
        self.model_id_to_deployment_index_map = {}
        self.model_name_to_deployment_indices = {}
    def _generate_model_id(self, model_group, litellm_params):
        # Optimized: Use list and join instead of string concatenation in loop
        # This avoids creating many temporary string objects (O(n) vs O(n²) complexity)
        parts = [model_group]
        for k, v in litellm_params.items():
            if isinstance(k, str):
                parts.append(k)
            elif isinstance(k, dict):
                parts.append(json.dumps(k))
            else:
                parts.append(str(k))

            if isinstance(v, str):
                parts.append(v)
            elif isinstance(v, dict):
                parts.append(json.dumps(v))
            else:
                parts.append(str(v))

        concat_str = "".join(parts)
        hash_object = hashlib.sha256(concat_str.encode())

        return hash_object.hexdigest()
    def _add_model_to_list_and_index_map(
        self, model, model_id = None
    ) -> None:
        idx = len(self.model_list)
        self.model_list.append(model)

        # Update model_id index for O(1) lookup
        if model_id is not None:
            self.model_id_to_deployment_index_map[model_id] = idx
        elif model.get("model_info", {}).get("id") is not None:
            self.model_id_to_deployment_index_map[model["model_info"]["id"]] = idx

        # Update model_name index for O(1) lookup
        model_name = model.get("model_name")
        if model_name:
            if model_name not in self.model_name_to_deployment_indices:
                self.model_name_to_deployment_indices[model_name] = []
            self.model_name_to_deployment_indices[model_name].append(idx)

    def _build_model_id_to_deployment_index_map(self, model_list):
        # First populate the model_list
        self.model_list = []
        for _, model in enumerate(model_list):
            # Extract model_info from the model dict
            model_info = model.get("model_info", {})
            model_id = model_info.get("id")

            # If no ID exists, generate one using the same logic as set_model_list
            if model_id is None:
                model_name = model.get("model_name", "")
                litellm_params = model.get("litellm_params", {})
                model_id = self._generate_model_id(model_name, litellm_params)
                # Update the model_info in the original list
                if "model_info" not in model:
                    model["model_info"] = {}
                model["model_info"]["id"] = model_id

            self._add_model_to_list_and_index_map(model=model, model_id=model_id)

"""
    class_fixture_test_source = """
import pytest

from router_file import Router


class TestRouterIndexManagement:
    @pytest.fixture
    def router(self):
        return Router(model_list=[])
    def test_build_model_id_to_deployment_index_map(self, router):
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "model-1"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "model-2"},
            },
        ]

        # Test: Build index from model list
        router._build_model_id_to_deployment_index_map(model_list)

        # Verify: model_list is populated
        assert len(router.model_list) == 2
        # Verify: model_id_to_deployment_index_map is correctly built
        assert router.model_id_to_deployment_index_map["model-1"] == 0
        assert router.model_id_to_deployment_index_map["model-2"] == 1
"""
    target_key = "router_file.Router._build_model_id_to_deployment_index_map"

    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_root = Path(tmpdirname)

        # Module that holds the function to optimize.
        code_file_path = tmp_root / "router_file.py"
        code_file_path.write_text(router_source)

        # pytest test class whose subject is supplied by the ``router`` fixture method.
        test_file_path = tmp_root / "test_router_file.py"
        test_file_path.write_text(class_fixture_test_source)

        # Discovery runs with the pytest framework, rooted at the temp directory.
        test_config = TestConfig(
            tests_root=tmp_root,
            project_root_path=tmp_root,
            test_framework="pytest",
            tests_project_rootdir=tmp_root.parent,
        )
        router_parent = FunctionParent(name="Router", type="ClassDef")
        fto = FunctionToOptimize(
            function_name="_build_model_id_to_deployment_index_map",
            file_path=code_file_path,
            parents=[router_parent],
        )

        discovered_tests, _, _ = discover_unit_tests(
            test_config, file_to_funcs_to_optimize={code_file_path: [fto]}
        )

        # Exactly one function is expected in the mapping, with exactly one test attached.
        assert len(discovered_tests) == 1
        assert target_key in discovered_tests
        assert len(discovered_tests[target_key]) == 1
        router_test = next(iter(discovered_tests[target_key]))
        assert router_test.tests_in_file.test_file.resolve() == test_file_path.resolve()
        assert router_test.tests_in_file.test_function == "test_build_model_id_to_deployment_index_map"
920+
717921

718922
def test_unittest_discovery_with_pytest_parameterized():
719923
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1335,6 +1539,77 @@ def test_topological_sort():
13351539

13361540
assert should_process is True
13371541

1542+
def test_analyze_imports_fixture():
    """``analyze_imports_in_test_file`` accepts a file that uses the target via a fixture.

    The test file imports ``Graph`` and only calls ``topologicalSort`` on the
    object provided by the ``g`` fixture; the analysis is expected to return
    ``True`` for the target ``Graph.topologicalSort``.
    """
    fixture_test_source = """
from code_to_optimize.topological_sort import Graph
import pytest

@pytest.fixture
def g():
    return Graph(6)

def test_topological_sort(g):
    g.addEdge(5, 2)
    g.addEdge(5, 0)
    g.addEdge(4, 0)
    g.addEdge(4, 1)
    g.addEdge(2, 3)
    g.addEdge(3, 1)

    assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]
"""
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_file = Path(tmpdirname) / "test_example.py"
        test_file.write_text(fixture_test_source)

        wanted = {"Graph.topologicalSort"}
        verdict = analyze_imports_in_test_file(test_file, wanted)

        assert verdict is True
1569+
1570+
def test_analyze_imports_class_fixture():
    """``analyze_imports_in_test_file`` accepts a test class that uses a fixture method.

    The test file imports ``Router`` and calls the target method on the object
    provided by the ``router`` fixture method; the analysis is expected to
    return ``True`` for ``Router._build_model_id_to_deployment_index_map``.
    """
    class_fixture_test_source = """
import pytest

from router_file import Router


class TestRouterIndexManagement:
    @pytest.fixture
    def router(self):
        return Router(model_list=[])
    def test_build_model_id_to_deployment_index_map(self, router):
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {"model": "gpt-3.5-turbo"},
                "model_info": {"id": "model-1"},
            },
            {
                "model_name": "gpt-4",
                "litellm_params": {"model": "gpt-4"},
                "model_info": {"id": "model-2"},
            },
        ]

        # Test: Build index from model list
        router._build_model_id_to_deployment_index_map(model_list)

        # Verify: model_list is populated
        assert len(router.model_list) == 2
        # Verify: model_id_to_deployment_index_map is correctly built
        assert router.model_id_to_deployment_index_map["model-1"] == 0
        assert router.model_id_to_deployment_index_map["model-2"] == 1
"""
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_file = Path(tmpdirname) / "test_example.py"
        test_file.write_text(class_fixture_test_source)

        wanted = {"Router._build_model_id_to_deployment_index_map"}
        verdict = analyze_imports_in_test_file(test_file, wanted)

        assert verdict is True
1612+
13381613
def test_analyze_imports_aliased_class_method_negative():
13391614
with tempfile.TemporaryDirectory() as tmpdirname:
13401615
test_file = Path(tmpdirname) / "test_example.py"

0 commit comments

Comments
 (0)