
Commit 60a6db2

Merge branch 'main' into feat/feedback-loop-for-unmatched-test-results
2 parents: 5063fba + 09fa96c

File tree

14 files changed: +474 −266 lines

.gitignore

Lines changed: 6 additions & 1 deletion
@@ -254,4 +254,9 @@ fabric.properties

 # Mac
 .DS_Store
-WARP.MD
+WARP.MD
+
+.mcp.json
+.tessl/
+CLAUDE.md
+tessl.json

AGENTS.md

Lines changed: 5 additions & 1 deletion
@@ -315,4 +315,8 @@ Language Server Protocol support in `codeflash/lsp/` enables IDE integration dur
 ### Performance Optimization
 - Profile before and after changes
 - Use benchmarks to validate improvements
-- Generate detailed performance reports
+- Generate detailed performance reports
+
+# Agent Rules <!-- tessl-managed -->
+
+@.tessl/RULES.md follow the [instructions](.tessl/RULES.md)

codeflash/benchmarking/function_ranker.py

Lines changed: 132 additions & 65 deletions
@@ -2,7 +2,7 @@

 from typing import TYPE_CHECKING

-from codeflash.cli_cmds.console import console, logger
+from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.tracing.profile_stats import ProfileStats
@@ -12,29 +12,63 @@

     from codeflash.discovery.functions_to_optimize import FunctionToOptimize

+pytest_patterns = {
+    "<frozen",  # Frozen modules like runpy
+    "<string>",  # Dynamically evaluated code
+    "_pytest/",  # Pytest internals
+    "pytest",  # Pytest files
+    "pluggy/",  # Plugin system
+    "_pydev",  # PyDev debugger
+    "runpy.py",  # Python module runner
+}
+pytest_func_patterns = {"pytest_", "_pytest", "runtest"}
+
+
+def is_pytest_infrastructure(filename: str, function_name: str) -> bool:
+    """Check if a function is part of pytest infrastructure that should be excluded from ranking.
+
+    This filters out pytest internal functions, hooks, and test framework code that
+    would otherwise dominate the ranking but aren't candidates for optimization.
+    """
+    # Check filename patterns
+    for pattern in pytest_patterns:
+        if pattern in filename:
+            return True
+
+    return any(pattern in function_name.lower() for pattern in pytest_func_patterns)
+

 class FunctionRanker:
-    """Ranks and filters functions based on a ttX score derived from profiling data.
+    """Ranks and filters functions based on % of addressable time derived from profiling data.

-    The ttX score is calculated as:
-    ttX = own_time + (time_spent_in_callees / call_count)
+    The % of addressable time is calculated as:
+    addressable_time = own_time + (time_spent_in_callees / call_count)

-    This score prioritizes functions that are computationally heavy themselves (high `own_time`)
-    or that make expensive calls to other functions (high average `time_spent_in_callees`).
+    This represents the runtime of a function plus the runtime of its immediate dependent functions,
+    as a fraction of overall runtime. It prioritizes functions that are computationally heavy themselves
+    (high `own_time`) or that make expensive calls to other functions (high average `time_spent_in_callees`).

     Functions are first filtered by an importance threshold based on their `own_time` as a
-    fraction of the total runtime. The remaining functions are then ranked by their ttX score
+    fraction of the total runtime. The remaining functions are then ranked by their % of addressable time
     to identify the best candidates for optimization.
     """

     def __init__(self, trace_file_path: Path) -> None:
         self.trace_file_path = trace_file_path
         self._profile_stats = ProfileStats(trace_file_path.as_posix())
         self._function_stats: dict[str, dict] = {}
+        self._function_stats_by_name: dict[str, list[tuple[str, dict]]] = {}
         self.load_function_stats()

+        # Build index for faster lookups: map function_name to list of (key, stats)
+        for key, stats in self._function_stats.items():
+            func_name = stats.get("function_name")
+            if func_name:
+                self._function_stats_by_name.setdefault(func_name, []).append((key, stats))
+
     def load_function_stats(self) -> None:
         try:
+            pytest_filtered_count = 0
             for (filename, line_number, func_name), (
                 call_count,
                 _num_callers,
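Note: the two ideas introduced above are the pytest-infrastructure filter and the renamed addressable-time metric. A minimal standalone sketch of both, using invented timings rather than anything read from a ProfileStats trace (the pattern sets and the formula mirror the diff; the numbers and names are hypothetical):

# Illustrative only: mirrors is_pytest_infrastructure() and the addressable-time
# formula from the diff above, with made-up values.
pytest_patterns = {"<frozen", "<string>", "_pytest/", "pytest", "pluggy/", "_pydev", "runpy.py"}
pytest_func_patterns = {"pytest_", "_pytest", "runtest"}

def is_pytest_infrastructure(filename: str, function_name: str) -> bool:
    if any(p in filename for p in pytest_patterns):
        return True
    return any(p in function_name.lower() for p in pytest_func_patterns)

# A pytest hook is filtered out before ranking ...
assert is_pytest_infrastructure("/site-packages/_pytest/python.py", "pytest_pyfunc_call")
# ... while an application function is kept.
assert not is_pytest_infrastructure("myapp/service.py", "compute_totals")

# Addressable time = own_time + time_in_callees / call_count (all in ns).
own_time_ns = 4_000_000          # time spent in the function body itself
time_in_callees_ns = 6_000_000   # cumulative time spent in functions it calls
call_count = 3
addressable_time_ns = own_time_ns + (time_in_callees_ns / call_count)
print(addressable_time_ns)       # 6000000.0 -> 4 ms own + 2 ms average callee time per call

With these example numbers the function's addressable time is 6 ms: 4 ms of its own work plus an average of 2 ms spent in its callees on each call.
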
@@ -45,6 +79,10 @@ def load_function_stats(self) -> None:
             if call_count <= 0:
                 continue

+            if is_pytest_infrastructure(filename, func_name):
+                pytest_filtered_count += 1
+                continue
+
             # Parse function name to handle methods within classes
             class_name, qualified_name, base_function_name = (None, func_name, func_name)
             if "." in func_name and not func_name.startswith("<"):
@@ -56,8 +94,8 @@ def load_function_stats(self) -> None:
             own_time_ns = total_time_ns
             time_in_callees_ns = cumulative_time_ns - total_time_ns

-            # Calculate ttX score
-            ttx_score = own_time_ns + (time_in_callees_ns / call_count)
+            # Calculate addressable time (own time + avg time in immediate callees)
+            addressable_time_ns = own_time_ns + (time_in_callees_ns / call_count)

             function_key = f"{filename}:{qualified_name}"
             self._function_stats[function_key] = {
@@ -70,89 +108,118 @@ def load_function_stats(self) -> None:
                 "own_time_ns": own_time_ns,
                 "cumulative_time_ns": cumulative_time_ns,
                 "time_in_callees_ns": time_in_callees_ns,
-                "ttx_score": ttx_score,
+                "addressable_time_ns": addressable_time_ns,
             }

-            logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats")
+            logger.debug(
+                f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats "
+                f"(filtered {pytest_filtered_count} pytest infrastructure functions)"
+            )

         except Exception as e:
             logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
             self._function_stats = {}

-    def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
+    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
         target_filename = function_to_optimize.file_path.name
-        for key, stats in self._function_stats.items():
-            if stats.get("function_name") == function_to_optimize.function_name and (
-                key.endswith(f"/{target_filename}") or target_filename in key
-            ):
+        candidates = self._function_stats_by_name.get(function_to_optimize.function_name)
+        if not candidates:
+            logger.debug(
+                f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
+            )
+            return None
+
+        for key, stats in candidates:
+            # The check preserves exact logic: "key.endswith(f"/{target_filename}") or target_filename in key"
+            if key.endswith(f"/{target_filename}") or target_filename in key:
                 return stats

         logger.debug(
             f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}"
         )
         return None

-    def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
-        stats = self._get_function_stats(function_to_optimize)
-        return stats["ttx_score"] if stats else 0.0
+    def get_function_addressable_time(self, function_to_optimize: FunctionToOptimize) -> float:
+        """Get the addressable time in nanoseconds for a function.

-    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
-        ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
-        logger.debug(
-            f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}"
-        )
-        return ranked
+        Addressable time = own_time + (time_in_callees / call_count)
+        This represents the runtime of the function plus runtime of immediate dependent functions.
+        """
+        stats = self.get_function_stats_summary(function_to_optimize)
+        return stats["addressable_time_ns"] if stats else 0.0

-    def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
-        return self._get_function_stats(function_to_optimize)
+    def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
+        """Ranks and filters functions based on their % of addressable time and importance.

-    def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
-        """Ranks functions based on their ttX score.
+        Filters out functions whose own_time is less than DEFAULT_IMPORTANCE_THRESHOLD
+        of file-relative runtime, then ranks the remaining functions by addressable time.

-        This method calculates the ttX score for each function and returns
-        the functions sorted in descending order of their ttX score.
-        """
-        if not self._function_stats:
-            logger.warning("No function stats available to rank functions.")
-            return []
+        Importance is calculated relative to functions in the same file(s) rather than
+        total program time. This avoids filtering out functions due to test infrastructure
+        overhead.

-        return self.rank_functions(functions_to_optimize)
+        The addressable time metric (own_time + avg time in immediate callees) prioritizes
+        functions that are computationally heavy themselves or that make expensive calls
+        to other functions.

-    def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
-        """Reranks and filters functions based on their impact on total runtime.
+        Args:
+            functions_to_optimize: List of functions to rank.

-        This method first calculates the total runtime of all profiled functions.
-        It then filters out functions whose own_time is less than a specified
-        percentage of the total runtime (importance_threshold).
+        Returns:
+            Important functions sorted in descending order of their addressable time.

-        The remaining 'important' functions are then ranked by their ttX score.
         """
-        stats_map = self._function_stats
-        if not stats_map:
+        if not self._function_stats:
+            logger.warning("No function stats available to rank functions.")
             return []

-        total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)
+        # Calculate total time from functions in the same file(s) as functions to optimize
+        if functions_to_optimize:
+            # Get unique files from functions to optimize
+            target_files = {func.file_path.name for func in functions_to_optimize}
+            # Calculate total time only from functions in these files
+            total_program_time = sum(
+                s["own_time_ns"]
+                for s in self._function_stats.values()
+                if s.get("own_time_ns", 0) > 0
+                and any(
+                    str(s.get("filename", "")).endswith("/" + target_file) or s.get("filename") == target_file
+                    for target_file in target_files
+                )
+            )
+            logger.debug(
+                f"Using file-relative importance for {len(target_files)} file(s): {target_files}. "
+                f"Total file time: {total_program_time:,} ns"
+            )
+        else:
+            total_program_time = sum(
+                s["own_time_ns"] for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0
+            )

         if total_program_time == 0:
             logger.warning("Total program time is zero, cannot determine function importance.")
-            return self.rank_functions(functions_to_optimize)
-
-        important_functions = []
-        for func in functions_to_optimize:
-            func_stats = self._get_function_stats(func)
-            if func_stats and func_stats.get("own_time_ns", 0) > 0:
-                importance = func_stats["own_time_ns"] / total_program_time
-                if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
-                    important_functions.append(func)
-                else:
-                    logger.debug(
-                        f"Filtering out function {func.qualified_name} with importance "
-                        f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
-                    )
-
-        logger.info(
-            f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
+            functions_to_rank = functions_to_optimize
+        else:
+            functions_to_rank = []
+            for func in functions_to_optimize:
+                func_stats = self.get_function_stats_summary(func)
+                if func_stats and func_stats.get("addressable_time_ns", 0) > 0:
+                    importance = func_stats["addressable_time_ns"] / total_program_time
+                    if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
+                        functions_to_rank.append(func)
+                    else:
+                        logger.debug(
+                            f"Filtering out function {func.qualified_name} with importance "
+                            f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
+                        )
+
+            logger.info(
+                f"Filtered down to {len(functions_to_rank)} important functions "
+                f"from {len(functions_to_optimize)} total functions"
+            )
+
+        ranked = sorted(functions_to_rank, key=self.get_function_addressable_time, reverse=True)
+        logger.debug(
+            f"Function ranking order: {[f'{func.function_name} (addressable_time={self.get_function_addressable_time(func):.2f}ns)' for func in ranked]}"
         )
-        console.rule()
-
-        return self.rank_functions(important_functions)
+        return ranked
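
Note: the key behavioral change in rank_functions is that importance is now measured against time spent in the target file(s) rather than total program time, so test-harness overhead no longer crowds out real candidates. A self-contained toy sketch of that filtering step (the stats table, file names, and the 5% threshold are invented for illustration; the real code uses DEFAULT_IMPORTANCE_THRESHOLD and stats loaded from the trace):

# Toy model of the file-relative importance filter in rank_functions().
# Values and the 0.05 threshold are illustrative, not taken from codeflash config.
IMPORTANCE_THRESHOLD = 0.05

function_stats = {
    "myapp/service.py:compute_totals": {"filename": "myapp/service.py", "own_time_ns": 8_000_000, "addressable_time_ns": 9_000_000},
    "myapp/service.py:format_row": {"filename": "myapp/service.py", "own_time_ns": 200_000, "addressable_time_ns": 250_000},
    "site-packages/requests/api.py:get": {"filename": "site-packages/requests/api.py", "own_time_ns": 50_000_000, "addressable_time_ns": 60_000_000},
}

target_files = {"service.py"}  # files containing the functions we want to optimize

# Total time is summed only over functions in the target file(s) ...
total_file_time = sum(
    s["own_time_ns"]
    for s in function_stats.values()
    if any(s["filename"].endswith("/" + f) or s["filename"] == f for f in target_files)
)

# ... so the 50 ms spent inside requests never enters the denominator.
for key, s in function_stats.items():
    if "service.py" not in s["filename"]:
        continue
    importance = s["addressable_time_ns"] / total_file_time
    verdict = "keep" if importance >= IMPORTANCE_THRESHOLD else "filter out"
    print(f"{key}: importance={importance:.1%} -> {verdict}")

Here compute_totals is kept (its addressable time exceeds the file's own-time total), while format_row falls below the threshold and is filtered out before ranking.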

codeflash/benchmarking/replay_test.py

Lines changed: 8 additions & 28 deletions
@@ -64,30 +64,23 @@ def get_unique_test_name(module: str, function_name: str, benchmark_name: str, c


 def create_trace_replay_test_code(
-    trace_file: str,
-    functions_data: list[dict[str, Any]],
-    test_framework: str = "pytest",
-    max_run_count=256,  # noqa: ANN001
+    trace_file: str, functions_data: list[dict[str, Any]], max_run_count: int = 256
 ) -> str:
     """Create a replay test for functions based on trace data.

     Args:
     ----
         trace_file: Path to the SQLite database file
         functions_data: List of dictionaries with function info extracted from DB
-        test_framework: 'pytest' or 'unittest'
         max_run_count: Maximum number of runs to include in the test

     Returns:
     -------
         A string containing the test code

     """
-    assert test_framework in ["pytest", "unittest"]
-
     # Create Imports
-    imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
-{"import unittest" if test_framework == "unittest" else ""}
+    imports = """from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
 from codeflash.benchmarking.replay_test import get_next_arg_and_return
 """

@@ -158,13 +151,7 @@ def create_trace_replay_test_code(
     )

     # Create main body
-
-    if test_framework == "unittest":
-        self = "self"
-        test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
-    else:
-        test_template = ""
-        self = ""
+    test_template = ""

     for func in functions_data:
         module_name = func.get("module_name")
@@ -223,30 +210,26 @@ def create_trace_replay_test_code(
             filter_variables=filter_variables,
         )

-        formatted_test_body = textwrap.indent(test_body, "        " if test_framework == "unittest" else "    ")
+        formatted_test_body = textwrap.indent(test_body, "    ")

-        test_template += "    " if test_framework == "unittest" else ""
         unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name)
-        test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n"
+        test_template += f"def test_{unique_test_name}():\n{formatted_test_body}\n"

     return imports + "\n" + metadata + "\n" + test_template


-def generate_replay_test(
-    trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100
-) -> int:
+def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count: int = 100) -> int:
     """Generate multiple replay tests from the traced function calls, grouped by benchmark.

     Args:
     ----
         trace_file_path: Path to the SQLite database file
         output_dir: Directory to write the generated tests (if None, only returns the code)
-        test_framework: 'pytest' or 'unittest'
         max_run_count: Maximum number of runs to include per function

     Returns:
     -------
-        Dictionary mapping benchmark names to generated test code
+        The number of replay tests generated

     """
     count = 0
@@ -293,10 +276,7 @@ def generate_replay_test(
             continue
         # Generate the test code for this benchmark
         test_code = create_trace_replay_test_code(
-            trace_file=trace_file_path.as_posix(),
-            functions_data=functions_data,
-            test_framework=test_framework,
-            max_run_count=max_run_count,
+            trace_file=trace_file_path.as_posix(), functions_data=functions_data, max_run_count=max_run_count
         )
         test_code = sort_imports(code=test_code)
         output_file = get_test_file_path(
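
Note: with test_framework removed, callers of generate_replay_test only choose a trace file, an output directory, and a run cap, and the generated files contain plain def test_...() functions rather than unittest.TestCase methods. A hedged usage sketch of the simplified signature (the paths are placeholders, and this assumes the codeflash package is importable):

# Usage sketch only; the trace and output paths below are hypothetical.
from pathlib import Path

from codeflash.benchmarking.replay_test import generate_replay_test

# test_framework is gone from the signature; replay tests are always plain pytest now.
num_tests = generate_replay_test(
    trace_file_path=Path("codeflash_benchmarks.trace"),  # hypothetical trace DB
    output_dir=Path("tests/codeflash_replay"),           # hypothetical output directory
    max_run_count=100,
)
print(f"Generated {num_tests} replay test files")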
