Skip to content

Commit 03de4db

Browse files
committed
Consolidate FunctionRanker: merge rank/rerank/filter methods into single rank_functions
1 parent 45cdc62 commit 03de4db

File tree

2 files changed

+48
-61
lines changed

2 files changed

+48
-61
lines changed

codeflash/benchmarking/function_ranker.py

Lines changed: 41 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def load_function_stats(self) -> None:
7979
logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}")
8080
self._function_stats = {}
8181

82-
def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None:
82+
def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
8383
target_filename = function_to_optimize.file_path.name
8484
for key, stats in self._function_stats.items():
8585
if stats.get("function_name") == function_to_optimize.function_name and (
@@ -93,66 +93,58 @@ def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict
9393
return None
9494

9595
def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float:
96-
stats = self._get_function_stats(function_to_optimize)
96+
stats = self.get_function_stats_summary(function_to_optimize)
9797
return stats["ttx_score"] if stats else 0.0
9898

9999
def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
100-
ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True)
101-
logger.debug(
102-
f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}"
103-
)
104-
return ranked
100+
"""Ranks and filters functions based on their ttX score and importance.
105101
106-
def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None:
107-
return self._get_function_stats(function_to_optimize)
102+
Filters out functions whose own_time is less than DEFAULT_IMPORTANCE_THRESHOLD
103+
of total runtime, then ranks the remaining functions by ttX score.
108104
109-
def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
110-
"""Ranks functions based on their ttX score.
105+
The ttX score prioritizes functions that are computationally heavy themselves
106+
or that make expensive calls to other functions.
111107
112-
This method calculates the ttX score for each function and returns
113-
the functions sorted in descending order of their ttX score.
114-
"""
115-
if not self._function_stats:
116-
logger.warning("No function stats available to rank functions.")
117-
return []
108+
Args:
109+
functions_to_optimize: List of functions to rank.
118110
119-
return self.rank_functions(functions_to_optimize)
111+
Returns:
112+
Important functions sorted in descending order of their ttX score.
120113
121-
def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]:
122-
"""Reranks and filters functions based on their impact on total runtime.
123-
124-
This method first calculates the total runtime of all profiled functions.
125-
It then filters out functions whose own_time is less than a specified
126-
percentage of the total runtime (importance_threshold).
127-
128-
The remaining 'important' functions are then ranked by their ttX score.
129114
"""
130-
stats_map = self._function_stats
131-
if not stats_map:
115+
if not self._function_stats:
116+
logger.warning("No function stats available to rank functions.")
132117
return []
133118

134-
total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0)
119+
total_program_time = sum(
120+
s["own_time_ns"] for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0
121+
)
135122

136123
if total_program_time == 0:
137124
logger.warning("Total program time is zero, cannot determine function importance.")
138-
return self.rank_functions(functions_to_optimize)
139-
140-
important_functions = []
141-
for func in functions_to_optimize:
142-
func_stats = self._get_function_stats(func)
143-
if func_stats and func_stats.get("own_time_ns", 0) > 0:
144-
importance = func_stats["own_time_ns"] / total_program_time
145-
if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
146-
important_functions.append(func)
147-
else:
148-
logger.debug(
149-
f"Filtering out function {func.qualified_name} with importance "
150-
f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
151-
)
152-
153-
logger.info(
154-
f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions"
125+
functions_to_rank = functions_to_optimize
126+
else:
127+
functions_to_rank = []
128+
for func in functions_to_optimize:
129+
func_stats = self.get_function_stats_summary(func)
130+
if func_stats and func_stats.get("own_time_ns", 0) > 0:
131+
importance = func_stats["own_time_ns"] / total_program_time
132+
if importance >= DEFAULT_IMPORTANCE_THRESHOLD:
133+
functions_to_rank.append(func)
134+
else:
135+
logger.debug(
136+
f"Filtering out function {func.qualified_name} with importance "
137+
f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})"
138+
)
139+
140+
logger.info(
141+
f"Filtered down to {len(functions_to_rank)} important functions "
142+
f"from {len(functions_to_optimize)} total functions"
143+
)
144+
console.rule()
145+
146+
ranked = sorted(functions_to_rank, key=self.get_function_ttx_score, reverse=True)
147+
logger.debug(
148+
f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}"
155149
)
156-
console.rule()
157-
158-
return self.rank_functions(important_functions)
150+
return ranked

tests/test_function_ranker.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,13 @@ def test_get_function_ttx_score(function_ranker, workload_functions):
8080
def test_rank_functions(function_ranker, workload_functions):
8181
ranked_functions = function_ranker.rank_functions(workload_functions)
8282

83-
assert len(ranked_functions) == len(workload_functions)
83+
# Should filter out functions below importance threshold and sort by ttX score
84+
assert len(ranked_functions) <= len(workload_functions)
85+
assert len(ranked_functions) > 0 # At least some functions should pass the threshold
86+
87+
# funcA should pass the importance threshold
88+
func_a_in_results = any(f.function_name == "funcA" for f in ranked_functions)
89+
assert func_a_in_results
8490

8591
# Verify functions are sorted by ttX score in descending order
8692
for i in range(len(ranked_functions) - 1):
@@ -89,17 +95,6 @@ def test_rank_functions(function_ranker, workload_functions):
8995
assert current_score >= next_score
9096

9197

92-
def test_rerank_and_filter_functions(function_ranker, workload_functions):
93-
filtered_ranked = function_ranker.rerank_and_filter_functions(workload_functions)
94-
95-
# Should filter out functions below importance threshold
96-
assert len(filtered_ranked) <= len(workload_functions)
97-
98-
# funcA should pass the importance threshold (0.33% > 0.1%)
99-
func_a_in_results = any(f.function_name == "funcA" for f in filtered_ranked)
100-
assert func_a_in_results
101-
102-
10398
def test_get_function_stats_summary(function_ranker, workload_functions):
10499
func_a = None
105100
for func in workload_functions:

0 commit comments

Comments
 (0)