diff --git a/.gitignore b/.gitignore
index ebf794cc2..568f9eaca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -254,4 +254,9 @@ fabric.properties
 
 # Mac
 .DS_Store
-WARP.MD
\ No newline at end of file
+WARP.MD
+
+.mcp.json
+.tessl/
+CLAUDE.md
+tessl.json
diff --git a/AGENTS.md b/AGENTS.md
index 360ec78be..fe6acacc4 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -315,4 +315,8 @@ Language Server Protocol support in `codeflash/lsp/` enables IDE integration dur
 ### Performance Optimization
 - Profile before and after changes
 - Use benchmarks to validate improvements
-- Generate detailed performance reports
\ No newline at end of file
+- Generate detailed performance reports
+
+# Agent Rules
+
+@.tessl/RULES.md follow the [instructions](.tessl/RULES.md)
diff --git a/code_to_optimize/code_directories/simple_tracer_e2e/codeflash.sqlite3 b/code_to_optimize/code_directories/simple_tracer_e2e/codeflash.sqlite3
new file mode 100644
index 000000000..6e3ea527d
Binary files /dev/null and b/code_to_optimize/code_directories/simple_tracer_e2e/codeflash.sqlite3 differ
diff --git a/codeflash/benchmarking/function_ranker.py b/codeflash/benchmarking/function_ranker.py
index 9d1d8ec14..21d146c05 100644
--- a/codeflash/benchmarking/function_ranker.py
+++ b/codeflash/benchmarking/function_ranker.py
@@ -2,7 +2,7 @@
 
 from typing import TYPE_CHECKING
 
-from codeflash.cli_cmds.console import console, logger
+from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.config_consts import DEFAULT_IMPORTANCE_THRESHOLD
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.tracing.profile_stats import ProfileStats
@@ -12,18 +12,44 @@
     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 
+
+pytest_patterns = {
+    "<string>",  # Dynamically evaluated code
+    "_pytest/",  # Pytest internals
+    "pytest",  # Pytest files
+    "pluggy/",  # Plugin system
+    "_pydev",  # PyDev debugger
+    "runpy.py",  # Python module runner
+}
+pytest_func_patterns = {"pytest_", "_pytest", "runtest"}
+
+
+def is_pytest_infrastructure(filename: str, function_name: str) -> bool:
+    """Check if a function is part of pytest infrastructure that should be excluded from ranking.
+
+    This filters out pytest internal functions, hooks, and test framework code that
+    would otherwise dominate the ranking but aren't candidates for optimization.
+    """
+    # Check filename patterns
+    for pattern in pytest_patterns:
+        if pattern in filename:
+            return True
+
+    return any(pattern in function_name.lower() for pattern in pytest_func_patterns)
+
 
 class FunctionRanker:
-    """Ranks and filters functions based on a ttX score derived from profiling data.
+    """Ranks and filters functions based on their addressable time, derived from profiling data.
 
-    The ttX score is calculated as:
-        ttX = own_time + (time_spent_in_callees / call_count)
+    Addressable time is calculated as:
+        addressable_time = own_time + (time_spent_in_callees / call_count)
 
-    This score prioritizes functions that are computationally heavy themselves (high `own_time`)
-    or that make expensive calls to other functions (high average `time_spent_in_callees`).
+    This represents the runtime of a function plus the average runtime of its immediate dependent
+    functions per call. It prioritizes functions that are computationally heavy themselves
+    (high `own_time`) or that make expensive calls to other functions (high average `time_spent_in_callees`).
 
Functions are first filtered by an importance threshold based on their `own_time` as a - fraction of the total runtime. The remaining functions are then ranked by their ttX score + fraction of the total runtime. The remaining functions are then ranked by their % of addressable time to identify the best candidates for optimization. """ @@ -31,10 +57,18 @@ def __init__(self, trace_file_path: Path) -> None: self.trace_file_path = trace_file_path self._profile_stats = ProfileStats(trace_file_path.as_posix()) self._function_stats: dict[str, dict] = {} + self._function_stats_by_name: dict[str, list[tuple[str, dict]]] = {} self.load_function_stats() + # Build index for faster lookups: map function_name to list of (key, stats) + for key, stats in self._function_stats.items(): + func_name = stats.get("function_name") + if func_name: + self._function_stats_by_name.setdefault(func_name, []).append((key, stats)) + def load_function_stats(self) -> None: try: + pytest_filtered_count = 0 for (filename, line_number, func_name), ( call_count, _num_callers, @@ -45,6 +79,10 @@ def load_function_stats(self) -> None: if call_count <= 0: continue + if is_pytest_infrastructure(filename, func_name): + pytest_filtered_count += 1 + continue + # Parse function name to handle methods within classes class_name, qualified_name, base_function_name = (None, func_name, func_name) if "." in func_name and not func_name.startswith("<"): @@ -56,8 +94,8 @@ def load_function_stats(self) -> None: own_time_ns = total_time_ns time_in_callees_ns = cumulative_time_ns - total_time_ns - # Calculate ttX score - ttx_score = own_time_ns + (time_in_callees_ns / call_count) + # Calculate addressable time (own time + avg time in immediate callees) + addressable_time_ns = own_time_ns + (time_in_callees_ns / call_count) function_key = f"{filename}:{qualified_name}" self._function_stats[function_key] = { @@ -70,21 +108,30 @@ def load_function_stats(self) -> None: "own_time_ns": own_time_ns, "cumulative_time_ns": cumulative_time_ns, "time_in_callees_ns": time_in_callees_ns, - "ttx_score": ttx_score, + "addressable_time_ns": addressable_time_ns, } - logger.debug(f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats") + logger.debug( + f"Loaded timing stats for {len(self._function_stats)} functions from trace using ProfileStats " + f"(filtered {pytest_filtered_count} pytest infrastructure functions)" + ) except Exception as e: logger.warning(f"Failed to process function stats from trace file {self.trace_file_path}: {e}") self._function_stats = {} - def _get_function_stats(self, function_to_optimize: FunctionToOptimize) -> dict | None: + def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None: target_filename = function_to_optimize.file_path.name - for key, stats in self._function_stats.items(): - if stats.get("function_name") == function_to_optimize.function_name and ( - key.endswith(f"/{target_filename}") or target_filename in key - ): + candidates = self._function_stats_by_name.get(function_to_optimize.function_name) + if not candidates: + logger.debug( + f"Could not find stats for function {function_to_optimize.function_name} in file {target_filename}" + ) + return None + + for key, stats in candidates: + # The check preserves exact logic: "key.endswith(f"/{target_filename}") or target_filename in key" + if key.endswith(f"/{target_filename}") or target_filename in key: return stats logger.debug( @@ -92,67 +139,87 @@ def _get_function_stats(self, function_to_optimize: 
FunctionToOptimize) -> dict ) return None - def get_function_ttx_score(self, function_to_optimize: FunctionToOptimize) -> float: - stats = self._get_function_stats(function_to_optimize) - return stats["ttx_score"] if stats else 0.0 + def get_function_addressable_time(self, function_to_optimize: FunctionToOptimize) -> float: + """Get the addressable time in nanoseconds for a function. - def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]: - ranked = sorted(functions_to_optimize, key=self.get_function_ttx_score, reverse=True) - logger.debug( - f"Function ranking order: {[f'{func.function_name} (ttX={self.get_function_ttx_score(func):.2f})' for func in ranked]}" - ) - return ranked + Addressable time = own_time + (time_in_callees / call_count) + This represents the runtime of the function plus runtime of immediate dependent functions. + """ + stats = self.get_function_stats_summary(function_to_optimize) + return stats["addressable_time_ns"] if stats else 0.0 - def get_function_stats_summary(self, function_to_optimize: FunctionToOptimize) -> dict | None: - return self._get_function_stats(function_to_optimize) + def rank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]: + """Ranks and filters functions based on their % of addressable time and importance. - def rerank_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]: - """Ranks functions based on their ttX score. + Filters out functions whose own_time is less than DEFAULT_IMPORTANCE_THRESHOLD + of file-relative runtime, then ranks the remaining functions by addressable time. - This method calculates the ttX score for each function and returns - the functions sorted in descending order of their ttX score. - """ - if not self._function_stats: - logger.warning("No function stats available to rank functions.") - return [] + Importance is calculated relative to functions in the same file(s) rather than + total program time. This avoids filtering out functions due to test infrastructure + overhead. - return self.rank_functions(functions_to_optimize) + The addressable time metric (own_time + avg time in immediate callees) prioritizes + functions that are computationally heavy themselves or that make expensive calls + to other functions. - def rerank_and_filter_functions(self, functions_to_optimize: list[FunctionToOptimize]) -> list[FunctionToOptimize]: - """Reranks and filters functions based on their impact on total runtime. + Args: + functions_to_optimize: List of functions to rank. - This method first calculates the total runtime of all profiled functions. - It then filters out functions whose own_time is less than a specified - percentage of the total runtime (importance_threshold). + Returns: + Important functions sorted in descending order of their addressable time. - The remaining 'important' functions are then ranked by their ttX score. 
""" - stats_map = self._function_stats - if not stats_map: + if not self._function_stats: + logger.warning("No function stats available to rank functions.") return [] - total_program_time = sum(s["own_time_ns"] for s in stats_map.values() if s.get("own_time_ns", 0) > 0) + # Calculate total time from functions in the same file(s) as functions to optimize + if functions_to_optimize: + # Get unique files from functions to optimize + target_files = {func.file_path.name for func in functions_to_optimize} + # Calculate total time only from functions in these files + total_program_time = sum( + s["own_time_ns"] + for s in self._function_stats.values() + if s.get("own_time_ns", 0) > 0 + and any( + str(s.get("filename", "")).endswith("/" + target_file) or s.get("filename") == target_file + for target_file in target_files + ) + ) + logger.debug( + f"Using file-relative importance for {len(target_files)} file(s): {target_files}. " + f"Total file time: {total_program_time:,} ns" + ) + else: + total_program_time = sum( + s["own_time_ns"] for s in self._function_stats.values() if s.get("own_time_ns", 0) > 0 + ) if total_program_time == 0: logger.warning("Total program time is zero, cannot determine function importance.") - return self.rank_functions(functions_to_optimize) - - important_functions = [] - for func in functions_to_optimize: - func_stats = self._get_function_stats(func) - if func_stats and func_stats.get("own_time_ns", 0) > 0: - importance = func_stats["own_time_ns"] / total_program_time - if importance >= DEFAULT_IMPORTANCE_THRESHOLD: - important_functions.append(func) - else: - logger.debug( - f"Filtering out function {func.qualified_name} with importance " - f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})" - ) - - logger.info( - f"Filtered down to {len(important_functions)} important functions from {len(functions_to_optimize)} total functions" + functions_to_rank = functions_to_optimize + else: + functions_to_rank = [] + for func in functions_to_optimize: + func_stats = self.get_function_stats_summary(func) + if func_stats and func_stats.get("own_time_ns", 0) > 0: + importance = func_stats["own_time_ns"] / total_program_time + if importance >= DEFAULT_IMPORTANCE_THRESHOLD: + functions_to_rank.append(func) + else: + logger.debug( + f"Filtering out function {func.qualified_name} with importance " + f"{importance:.2%} (below threshold {DEFAULT_IMPORTANCE_THRESHOLD:.2%})" + ) + + logger.info( + f"Filtered down to {len(functions_to_rank)} important functions " + f"from {len(functions_to_optimize)} total functions" + ) + + ranked = sorted(functions_to_rank, key=self.get_function_addressable_time, reverse=True) + logger.debug( + f"Function ranking order: {[f'{func.function_name} (addressable_time={self.get_function_addressable_time(func):.2f}ns)' for func in ranked]}" ) - console.rule() - - return self.rank_functions(important_functions) + return ranked diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index e9f66dc8a..5fc9ab720 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -64,10 +64,7 @@ def get_unique_test_name(module: str, function_name: str, benchmark_name: str, c def create_trace_replay_test_code( - trace_file: str, - functions_data: list[dict[str, Any]], - test_framework: str = "pytest", - max_run_count=256, # noqa: ANN001 + trace_file: str, functions_data: list[dict[str, Any]], max_run_count: int = 256 ) -> str: """Create a replay test for functions based on trace data. 
@@ -75,7 +72,6 @@ def create_trace_replay_test_code( ---- trace_file: Path to the SQLite database file functions_data: List of dictionaries with function info extracted from DB - test_framework: 'pytest' or 'unittest' max_run_count: Maximum number of runs to include in the test Returns: @@ -83,11 +79,8 @@ def create_trace_replay_test_code( A string containing the test code """ - assert test_framework in ["pytest", "unittest"] - # Create Imports - imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle -{"import unittest" if test_framework == "unittest" else ""} + imports = """from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle from codeflash.benchmarking.replay_test import get_next_arg_and_return """ @@ -158,13 +151,7 @@ def create_trace_replay_test_code( ) # Create main body - - if test_framework == "unittest": - self = "self" - test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n" - else: - test_template = "" - self = "" + test_template = "" for func in functions_data: module_name = func.get("module_name") @@ -223,30 +210,26 @@ def create_trace_replay_test_code( filter_variables=filter_variables, ) - formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ") + formatted_test_body = textwrap.indent(test_body, " ") - test_template += " " if test_framework == "unittest" else "" unique_test_name = get_unique_test_name(module_name, function_name, benchmark_function_name, class_name) - test_template += f"def test_{unique_test_name}({self}):\n{formatted_test_body}\n" + test_template += f"def test_{unique_test_name}():\n{formatted_test_body}\n" return imports + "\n" + metadata + "\n" + test_template -def generate_replay_test( - trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100 -) -> int: +def generate_replay_test(trace_file_path: Path, output_dir: Path, max_run_count: int = 100) -> int: """Generate multiple replay tests from the traced function calls, grouped by benchmark. 
Args: ---- trace_file_path: Path to the SQLite database file output_dir: Directory to write the generated tests (if None, only returns the code) - test_framework: 'pytest' or 'unittest' max_run_count: Maximum number of runs to include per function Returns: ------- - Dictionary mapping benchmark names to generated test code + The number of replay tests generated """ count = 0 @@ -293,10 +276,7 @@ def generate_replay_test( continue # Generate the test code for this benchmark test_code = create_trace_replay_test_code( - trace_file=trace_file_path.as_posix(), - functions_data=functions_data, - test_framework=test_framework, - max_run_count=max_run_count, + trace_file=trace_file_path.as_posix(), functions_data=functions_data, max_run_count=max_run_count ) test_code = sort_imports(code=test_code) output_file = get_test_file_path( diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 624102b73..496525b9d 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -655,8 +655,10 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s if target is None or len(name_parts) == 1: return target - if not isinstance(target, ast.ClassDef): + if not isinstance(target, ast.ClassDef) or len(name_parts) < 2: return None + # At this point, name_parts has at least 2 elements + method_name: str = name_parts[1] # type: ignore[misc] class_skeleton.add((target.lineno, target.body[0].lineno - 1)) cbody = target.body if isinstance(cbody[0], ast.expr): # Is a docstring @@ -669,7 +671,7 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s if ( isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)) and len(cnode_name := cnode.name) > 4 - and cnode_name != name_parts[1] + and cnode_name != method_name and cnode_name.isascii() and cnode_name.startswith("__") and cnode_name.endswith("__") @@ -677,7 +679,7 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s contextual_dunder_methods.add((target.name, cnode_name)) class_skeleton.add((cnode.lineno, cnode.end_lineno)) - return find_target(target.body, name_parts[1:]) + return find_target(target.body, (method_name,)) with file_path.open(encoding="utf8") as file: source_code: str = file.read() @@ -708,9 +710,14 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s ) return None, set() for qualified_name_parts in qualified_name_parts_list: - target_node: ast.AST | None = find_target(module_node.body, qualified_name_parts) + target_node = find_target(module_node.body, qualified_name_parts) if target_node is None: continue + # find_target returns FunctionDef, AsyncFunctionDef, ClassDef, Assign, or AnnAssign - all have lineno/end_lineno + if not isinstance( + target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Assign, ast.AnnAssign) + ): + continue if ( isinstance(target_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index 4987e6d8d..20b22181a 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -155,8 +155,15 @@ def get_cached_gh_event_data() -> dict[str, Any]: event_path = os.getenv("GITHUB_EVENT_PATH") if not event_path: return {} - with Path(event_path).open(encoding="utf-8") as f: - return json.load(f) # type: ignore # noqa + # Use json.load directly without variable assignment for micro-optimization + # Read file and 
load JSON in one step + try: + with open(event_path, "rb") as f: # using binary mode for slightly faster reading + return json.loads(f.read().decode("utf-8")) # type: ignore # noqa + except Exception: + # Fallback for unexpected file IO/decoding errors as original code does not handle these + # Matches behavior: if the file cannot be read or json decoding fails, just let exception propagate + raise def is_repo_a_fork() -> bool: diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 5821d23b1..3958f40cf 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -40,6 +40,10 @@ from codeflash.verification.verification_utils import TestConfig from rich.text import Text +_property_id = "property" + +_ast_name = ast.Name + @dataclass(frozen=True) class FunctionProperties: @@ -774,4 +778,8 @@ def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool: - return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list) + for node in function_node.decorator_list: # noqa: SIM110 + # Use isinstance rather than type(...) is ... for better performance with single inheritance trees like ast + if isinstance(node, _ast_name) and node.id == _property_id: + return True + return False diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 0ea380059..483488fdc 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -620,12 +620,14 @@ class TestResults(BaseModel): # noqa: PLW1641 def add(self, function_test_invocation: FunctionTestInvocation) -> None: unique_id = function_test_invocation.unique_invocation_loop_id - if unique_id in self.test_result_idx: + test_result_idx = self.test_result_idx + if unique_id in test_result_idx: if DEBUG_MODE: logger.warning(f"Test result with id {unique_id} already exists. 
SKIPPING") return - self.test_result_idx[unique_id] = len(self.test_results) - self.test_results.append(function_test_invocation) + test_results = self.test_results + test_result_idx[unique_id] = len(test_results) + test_results.append(function_test_invocation) def merge(self, other: TestResults) -> None: original_len = len(self.test_results) diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index b3e1f8d12..ac757e6a9 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from argparse import Namespace + from codeflash.benchmarking.function_ranker import FunctionRanker from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import BenchmarkKey, FunctionCalledInTest @@ -53,6 +54,7 @@ def __init__(self, args: Namespace) -> None: self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.replay_tests_dir = None + self.trace_file: Path | None = None self.functions_checkpoint: CodeflashRunCheckpoint | None = None self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP self.current_function_optimizer: FunctionOptimizer | None = None @@ -87,24 +89,26 @@ def run_benchmarks( file_path_to_source_code[file] = f.read() try: instrument_codeflash_trace_decorator(file_to_funcs_to_optimize) - trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace" - if trace_file.exists(): - trace_file.unlink() + self.trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace" + if self.trace_file.exists(): + self.trace_file.unlink() self.replay_tests_dir = Path( tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root) ) trace_benchmarks_pytest( - self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file + self.args.benchmarks_root, self.args.tests_root, self.args.project_root, self.trace_file ) # Run all tests that use pytest-benchmark - replay_count = generate_replay_test(trace_file, self.replay_tests_dir) + replay_count = generate_replay_test(self.trace_file, self.replay_tests_dir) if replay_count == 0: logger.info( f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization" ) else: - function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) - total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) + function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings( + self.trace_file + ) + total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(self.trace_file) function_to_results = validate_and_format_benchmark_table( function_benchmark_timings, total_benchmark_timings ) @@ -251,6 +255,145 @@ def discover_tests( ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) return function_to_tests, num_discovered_tests + def display_global_ranking( + self, globally_ranked: list[tuple[Path, FunctionToOptimize]], ranker: FunctionRanker, show_top_n: int = 15 + ) -> None: + from rich.table import Table + + if not globally_ranked: + return + + # Show top N functions + display_count = min(show_top_n, len(globally_ranked)) + + table = Table( + title=f"Function Ranking (Top {display_count} of {len(globally_ranked)})", + title_style="bold cyan", 
+ border_style="cyan", + show_lines=False, + ) + + table.add_column("Priority", style="bold yellow", justify="center", width=8) + table.add_column("Function", style="cyan", width=40) + table.add_column("File", style="dim", width=25) + table.add_column("Addressable Time", justify="right", style="green", width=12) + table.add_column("Impact", justify="center", style="bold", width=8) + + # Get addressable time for display + for i, (file_path, func) in enumerate(globally_ranked[:display_count], 1): + addressable_time = ranker.get_function_addressable_time(func) + + # Format function name + func_name = func.qualified_name + if len(func_name) > 38: + func_name = func_name[:35] + "..." + + # Format file name + file_name = file_path.name + if len(file_name) > 23: + file_name = "..." + file_name[-20:] + + # Format addressable time + if addressable_time >= 1e9: + time_display = f"{addressable_time / 1e9:.2f}s" + elif addressable_time >= 1e6: + time_display = f"{addressable_time / 1e6:.1f}ms" + elif addressable_time >= 1e3: + time_display = f"{addressable_time / 1e3:.1f}µs" + else: + time_display = f"{addressable_time:.0f}ns" + + # Impact indicator + if i <= 5: + impact = "🔥" + impact_style = "bold red" + elif i <= 10: + impact = "⚡" + impact_style = "bold yellow" + else: + impact = "💡" + impact_style = "bold blue" + + table.add_row(f"#{i}", func_name, file_name, time_display, impact, style=impact_style if i <= 5 else None) + + console.print(table) + + if len(globally_ranked) > display_count: + console.print(f"[dim]... and {len(globally_ranked) - display_count} more functions[/dim]") + + def rank_all_functions_globally( + self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], trace_file_path: Path | None + ) -> list[tuple[Path, FunctionToOptimize]]: + """Rank all functions globally across all files based on trace data. + + This performs global ranking instead of per-file ranking, ensuring that + high-impact functions are optimized first regardless of which file they're in. + + Args: + file_to_funcs_to_optimize: Mapping of file paths to functions to optimize + trace_file_path: Path to trace file with performance data + + Returns: + List of (file_path, function) tuples in globally ranked order by addressable time. + If no trace file or ranking fails, returns functions in original file order. 
+ + """ + all_functions: list[tuple[Path, FunctionToOptimize]] = [] + for file_path, functions in file_to_funcs_to_optimize.items(): + all_functions.extend((file_path, func) for func in functions) + + # If no trace file, return in original order + if not trace_file_path or not trace_file_path.exists(): + logger.debug("No trace file available, using original function order") + return all_functions + + try: + from codeflash.benchmarking.function_ranker import FunctionRanker + + console.rule() + logger.info("loading|Ranking functions globally by performance impact...") + console.rule() + # Create ranker with trace data + ranker = FunctionRanker(trace_file_path) + + # Extract just the functions for ranking (without file paths) + functions_only = [func for _, func in all_functions] + + # Rank globally + ranked_functions = ranker.rank_functions(functions_only) + + # Reconstruct with file paths by looking up original file for each ranked function + # Build reverse mapping: function -> file path + # Since FunctionToOptimize is unhashable (contains list), we compare by identity + func_to_file_map = {} + for file_path, func in all_functions: + # Use a tuple of unique identifiers as the key + key: tuple[Path, str, int | None] = (func.file_path, func.qualified_name, func.starting_line) + func_to_file_map[key] = file_path + globally_ranked = [] + for func in ranked_functions: + key = (func.file_path, func.qualified_name, func.starting_line) + file_path = func_to_file_map.get(key) + if file_path: + globally_ranked.append((file_path, func)) + + console.rule() + logger.info( + f"Globally ranked {len(ranked_functions)} functions by addressable time " + f"(filtered {len(functions_only) - len(ranked_functions)} low-importance functions)" + ) + + # Display ranking table for user visibility + self.display_global_ranking(globally_ranked, ranker) + console.rule() + + except Exception as e: + logger.warning(f"Could not perform global ranking: {e}") + logger.debug("Falling back to original function order") + return all_functions + else: + return globally_ranked + def run(self) -> None: from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint @@ -266,6 +409,12 @@ def run(self) -> None: if self.args.worktree: self.worktree_mode() + if not self.args.replay_test and self.test_cfg.tests_root.exists(): + leftover_trace_files = list(self.test_cfg.tests_root.glob("*.trace")) + if leftover_trace_files: + logger.debug(f"Cleaning up {len(leftover_trace_files)} leftover trace file(s) from previous runs") + cleanup_paths(leftover_trace_files) + cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root)) function_optimizer = None @@ -297,84 +446,77 @@ def run(self) -> None: if self.args.all: self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root) - for original_module_path in file_to_funcs_to_optimize: - module_prep_result = self.prepare_module_for_optimization(original_module_path) - if module_prep_result is None: - continue - - validated_original_code, original_module_ast = module_prep_result + # GLOBAL RANKING: Rank all functions together before optimizing + globally_ranked_functions = self.rank_all_functions_globally(file_to_funcs_to_optimize, trace_file_path) + # Cache for module preparation (avoid re-parsing same files) + prepared_modules: dict[Path, tuple[dict[Path, ValidCode], ast.Module]] = {} - functions_to_optimize = file_to_funcs_to_optimize[original_module_path] - if trace_file_path and trace_file_path.exists() and len(functions_to_optimize) > 1: - try: - 
from codeflash.benchmarking.function_ranker import FunctionRanker + # Optimize functions in globally ranked order + for i, (original_module_path, function_to_optimize) in enumerate(globally_ranked_functions): + # Prepare module if not already cached + if original_module_path not in prepared_modules: + module_prep_result = self.prepare_module_for_optimization(original_module_path) + if module_prep_result is None: + logger.warning(f"Skipping functions in {original_module_path} due to preparation error") + continue + prepared_modules[original_module_path] = module_prep_result - ranker = FunctionRanker(trace_file_path) - functions_to_optimize = ranker.rank_functions(functions_to_optimize) - logger.info( - f"Ranked {len(functions_to_optimize)} functions by performance impact in {original_module_path}" - ) - console.rule() - except Exception as e: - logger.debug(f"Could not rank functions in {original_module_path}: {e}") + validated_original_code, original_module_ast = prepared_modules[original_module_path] - for i, function_to_optimize in enumerate(functions_to_optimize): - function_iterator_count = i + 1 - logger.info( - f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: " - f"{function_to_optimize.qualified_name}" + function_iterator_count = i + 1 + logger.info( + f"Optimizing function {function_iterator_count} of {len(globally_ranked_functions)}: " + f"{function_to_optimize.qualified_name} (in {original_module_path.name})" + ) + console.rule() + function_optimizer = None + try: + function_optimizer = self.create_function_optimizer( + function_to_optimize, + function_to_tests=function_to_tests, + function_to_optimize_source_code=validated_original_code[original_module_path].source_code, + function_benchmark_timings=function_benchmark_timings, + total_benchmark_timings=total_benchmark_timings, + original_module_ast=original_module_ast, + original_module_path=original_module_path, ) - console.rule() - function_optimizer = None - try: - function_optimizer = self.create_function_optimizer( - function_to_optimize, - function_to_tests=function_to_tests, - function_to_optimize_source_code=validated_original_code[original_module_path].source_code, - function_benchmark_timings=function_benchmark_timings, - total_benchmark_timings=total_benchmark_timings, - original_module_ast=original_module_ast, - original_module_path=original_module_path, - ) - if function_optimizer is None: - continue + if function_optimizer is None: + continue - self.current_function_optimizer = ( - function_optimizer # needed to clean up from the outside of this function + self.current_function_optimizer = ( + function_optimizer # needed to clean up from the outside of this function + ) + best_optimization = function_optimizer.optimize_function() + if self.functions_checkpoint: + self.functions_checkpoint.add_function_to_checkpoint( + function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) ) - best_optimization = function_optimizer.optimize_function() - if self.functions_checkpoint: - self.functions_checkpoint.add_function_to_checkpoint( - function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root) + if is_successful(best_optimization): + optimizations_found += 1 + # create a diff patch for successful optimization + if self.current_worktree: + best_opt = best_optimization.unwrap() + read_writable_code = best_opt.code_context.read_writable_code + relative_file_paths = [ + code_string.file_path for code_string in read_writable_code.code_strings + ] + 
patch_path = create_diff_patch_from_worktree( + self.current_worktree, relative_file_paths, fto_name=function_to_optimize.qualified_name ) - if is_successful(best_optimization): - optimizations_found += 1 - # create a diff patch for successful optimization - if self.current_worktree: - best_opt = best_optimization.unwrap() - read_writable_code = best_opt.code_context.read_writable_code - relative_file_paths = [ - code_string.file_path for code_string in read_writable_code.code_strings - ] - patch_path = create_diff_patch_from_worktree( - self.current_worktree, - relative_file_paths, - fto_name=function_to_optimize.qualified_name, + self.patch_files.append(patch_path) + if i < len(globally_ranked_functions) - 1: + _, next_func = globally_ranked_functions[i + 1] + create_worktree_snapshot_commit( + self.current_worktree, f"Optimizing {next_func.qualified_name}" ) - self.patch_files.append(patch_path) - if i < len(functions_to_optimize) - 1: - create_worktree_snapshot_commit( - self.current_worktree, - f"Optimizing {functions_to_optimize[i + 1].qualified_name}", - ) - else: - logger.warning(best_optimization.failure()) - console.rule() - continue - finally: - if function_optimizer is not None: - function_optimizer.executor.shutdown(wait=True) - function_optimizer.cleanup_generated_files() + else: + logger.warning(best_optimization.failure()) + console.rule() + continue + finally: + if function_optimizer is not None: + function_optimizer.executor.shutdown(wait=True) + function_optimizer.cleanup_generated_files() ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found}) if len(self.patch_files) > 0: @@ -421,9 +563,15 @@ def find_leftover_instrumented_test_files(test_root: Path) -> list[Path]: ] def cleanup_replay_tests(self) -> None: + paths_to_cleanup = [] if self.replay_tests_dir and self.replay_tests_dir.exists(): logger.debug(f"Cleaning up replay tests directory: {self.replay_tests_dir}") - cleanup_paths([self.replay_tests_dir]) + paths_to_cleanup.append(self.replay_tests_dir) + if self.trace_file and self.trace_file.exists(): + logger.debug(f"Cleaning up trace file: {self.trace_file}") + paths_to_cleanup.append(self.trace_file) + if paths_to_cleanup: + cleanup_paths(paths_to_cleanup) def cleanup_temporary_paths(self) -> None: if hasattr(get_run_tmp_file, "tmpdir"): @@ -436,7 +584,14 @@ def cleanup_temporary_paths(self) -> None: if self.current_function_optimizer: self.current_function_optimizer.cleanup_generated_files() - cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir]) + paths_to_cleanup = [self.test_cfg.concolic_test_root_dir, self.replay_tests_dir] + if self.trace_file: + paths_to_cleanup.append(self.trace_file) + if self.test_cfg.tests_root.exists(): + for trace_file in self.test_cfg.tests_root.glob("*.trace"): + if trace_file not in paths_to_cleanup: + paths_to_cleanup.append(trace_file) + cleanup_paths(paths_to_cleanup) def worktree_mode(self) -> None: if self.current_worktree: diff --git a/codeflash/tracer.py b/codeflash/tracer.py index d1c4cd176..eb011befa 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -120,9 +120,6 @@ def main(args: Namespace | None = None) -> ArgumentParser: result_pickle_file_path = get_run_tmp_file(Path(f"tracer_results_file_{i}.pkl")) result_pickle_file_paths.append(result_pickle_file_path) args_dict["result_pickle_file_path"] = str(result_pickle_file_path) - outpath = parsed_args.outfile - outpath = outpath.parent / f"{outpath.stem}_{i}{outpath.suffix}" - args_dict["output"] = str(outpath) 
updated_sys_argv = [] for elem in sys.argv: if elem in test_paths_set: @@ -164,7 +161,6 @@ def main(args: Namespace | None = None) -> ArgumentParser: else: result_pickle_file_path = get_run_tmp_file(Path("tracer_results_file.pkl")) args_dict["result_pickle_file_path"] = str(result_pickle_file_path) - args_dict["output"] = str(parsed_args.outfile) args_dict["command"] = " ".join(sys.argv) env = os.environ.copy() diff --git a/codeflash/tracing/replay_test.py b/codeflash/tracing/replay_test.py index d2b8c07b1..b1b10f56e 100644 --- a/codeflash/tracing/replay_test.py +++ b/codeflash/tracing/replay_test.py @@ -43,16 +43,8 @@ def get_function_alias(module: str, function_name: str) -> str: return "_".join(module.split(".")) + "_" + function_name -def create_trace_replay_test( - trace_file: str, - functions: list[FunctionModules], - test_framework: str = "pytest", - max_run_count=100, # noqa: ANN001 -) -> str: - assert test_framework in {"pytest", "unittest"} - - imports = f"""import dill as pickle -{"import unittest" if test_framework == "unittest" else ""} +def create_trace_replay_test(trace_file: str, functions: list[FunctionModules], max_run_count: int = 100) -> str: + imports = """import dill as pickle from codeflash.tracing.replay_test import get_next_arg_and_return """ @@ -112,12 +104,7 @@ def create_trace_replay_test( ret = {class_name_alias}{method_name}(**args) """ ) - if test_framework == "unittest": - self = "self" - test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n" - else: - test_template = "" - self = "" + test_template = "" for func, func_property in zip(functions, function_properties): if func_property is None: continue @@ -167,9 +154,8 @@ def create_trace_replay_test( max_run_count=max_run_count, filter_variables=filter_variables, ) - formatted_test_body = textwrap.indent(test_body, " " if test_framework == "unittest" else " ") + formatted_test_body = textwrap.indent(test_body, " ") - test_template += " " if test_framework == "unittest" else "" - test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n" + test_template += f"def test_{alias}():\n{formatted_test_body}\n" return imports + "\n" + metadata + "\n" + test_template diff --git a/codeflash/tracing/tracing_new_process.py b/codeflash/tracing/tracing_new_process.py index ec1794f09..d4daedd26 100644 --- a/codeflash/tracing/tracing_new_process.py +++ b/codeflash/tracing/tracing_new_process.py @@ -70,7 +70,6 @@ def __init__( self, config: dict, result_pickle_file_path: Path, - output: str = "codeflash.trace", functions: list[str] | None = None, disable: bool = False, # noqa: FBT001, FBT002 project_root: Path | None = None, @@ -80,7 +79,6 @@ def __init__( ) -> None: """Use this class to trace function calls. - :param output: The path to the output trace file :param functions: List of functions to trace. 
If None, trace all functions :param disable: Disable the tracer if True :param max_function_count: Maximum number of times to trace one function @@ -110,7 +108,6 @@ def __init__( self._db_lock = threading.Lock() self.con = None - self.output_file = Path(output).resolve() self.functions = functions self.function_modules: list[FunctionModules] = [] self.function_count = defaultdict(int) @@ -126,6 +123,15 @@ def __init__( self.ignored_functions = {"", "", "", "", "", ""} self.sanitized_filename = self.sanitize_to_filename(command) + # Place trace file next to replay tests in the tests directory + from codeflash.verification.verification_utils import get_test_file_path + + function_path = "_".join(functions) if functions else self.sanitized_filename + test_file_path = get_test_file_path( + test_dir=Path(config["tests_root"]), function_name=function_path, test_type="replay" + ) + trace_filename = test_file_path.stem + ".trace" + self.output_file = test_file_path.parent / trace_filename self.result_pickle_file_path = result_pickle_file_path assert timeout is None or timeout > 0, "Timeout should be greater than 0" @@ -142,7 +148,6 @@ def __init__( self.timer = time.process_time_ns self.total_tt = 0 self.simulate_call("profiler") - assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file" self.t = self.timer() # Store command information for metadata table @@ -273,10 +278,7 @@ def __exit__( from codeflash.verification.verification_utils import get_test_file_path replay_test = create_trace_replay_test( - trace_file=self.output_file, - functions=self.function_modules, - test_framework=self.config["test_framework"], - max_run_count=self.max_function_count, + trace_file=self.output_file, functions=self.function_modules, max_run_count=self.max_function_count ) function_path = "_".join(self.functions) if self.functions else self.sanitized_filename test_file_path = get_test_file_path( @@ -770,11 +772,11 @@ def make_pstats_compatible(self) -> None: self.files = [] self.top_level = [] new_stats = {} - for func, (cc, ns, tt, ct, callers) in self.stats.items(): + for func, (cc, ns, tt, ct, callers) in list(self.stats.items()): new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()} new_stats[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers) new_timings = {} - for func, (cc, ns, tt, ct, callers) in self.timings.items(): + for func, (cc, ns, tt, ct, callers) in list(self.timings.items()): new_callers = {(k[0], k[1], k[2]): v for k, v in callers.items()} new_timings[(func[0], func[1], func[2])] = (cc, ns, tt, ct, new_callers) self.stats = new_stats @@ -857,7 +859,6 @@ def runctx(self, cmd: str, global_vars: dict[str, Any], local_vars: dict[str, An args_dict["config"]["tests_root"] = Path(args_dict["config"]["tests_root"]) tracer = Tracer( config=args_dict["config"], - output=Path(args_dict["output"]), functions=args_dict["functions"], max_function_count=args_dict["max_function_count"], timeout=args_dict["timeout"], diff --git a/tests/test_function_ranker.py b/tests/test_function_ranker.py index 0cb1bb776..b5f216c0c 100644 --- a/tests/test_function_ranker.py +++ b/tests/test_function_ranker.py @@ -51,7 +51,7 @@ def test_load_function_stats(function_ranker): expected_keys = { "filename", "function_name", "qualified_name", "class_name", "line_number", "call_count", "own_time_ns", "cumulative_time_ns", - "time_in_callees_ns", "ttx_score" + "time_in_callees_ns", "addressable_time_ns" } assert set(func_a_stats.keys()) == expected_keys @@ -62,7 +62,7 
@@ def test_load_function_stats(function_ranker): assert func_a_stats["cumulative_time_ns"] == 5443000 -def test_get_function_ttx_score(function_ranker, workload_functions): +def test_get_function_addressable_time(function_ranker, workload_functions): func_a = None for func in workload_functions: if func.function_name == "funcA": @@ -70,34 +70,29 @@ def test_get_function_ttx_score(function_ranker, workload_functions): break assert func_a is not None - ttx_score = function_ranker.get_function_ttx_score(func_a) + addressable_time = function_ranker.get_function_addressable_time(func_a) - # Expected ttX score: own_time + (time_in_callees / call_count) + # Expected addressable time: own_time + (time_in_callees / call_count) # = 63000 + ((5443000 - 63000) / 1) = 5443000 - assert ttx_score == 5443000 + assert addressable_time == 5443000 def test_rank_functions(function_ranker, workload_functions): ranked_functions = function_ranker.rank_functions(workload_functions) - assert len(ranked_functions) == len(workload_functions) + # Should filter out functions below importance threshold and sort by addressable time + assert len(ranked_functions) <= len(workload_functions) + assert len(ranked_functions) > 0 # At least some functions should pass the threshold - # Verify functions are sorted by ttX score in descending order - for i in range(len(ranked_functions) - 1): - current_score = function_ranker.get_function_ttx_score(ranked_functions[i]) - next_score = function_ranker.get_function_ttx_score(ranked_functions[i + 1]) - assert current_score >= next_score - - -def test_rerank_and_filter_functions(function_ranker, workload_functions): - filtered_ranked = function_ranker.rerank_and_filter_functions(workload_functions) - - # Should filter out functions below importance threshold - assert len(filtered_ranked) <= len(workload_functions) - - # funcA should pass the importance threshold (0.33% > 0.1%) - func_a_in_results = any(f.function_name == "funcA" for f in filtered_ranked) + # funcA should pass the importance threshold + func_a_in_results = any(f.function_name == "funcA" for f in ranked_functions) assert func_a_in_results + + # Verify functions are sorted by addressable time in descending order + for i in range(len(ranked_functions) - 1): + current_time = function_ranker.get_function_addressable_time(ranked_functions[i]) + next_time = function_ranker.get_function_addressable_time(ranked_functions[i + 1]) + assert current_time >= next_time def test_get_function_stats_summary(function_ranker, workload_functions): @@ -114,7 +109,7 @@ def test_get_function_stats_summary(function_ranker, workload_functions): assert stats["function_name"] == "funcA" assert stats["own_time_ns"] == 63000 assert stats["cumulative_time_ns"] == 5443000 - assert stats["ttx_score"] == 5443000 + assert stats["addressable_time_ns"] == 5443000 @@ -154,13 +149,13 @@ def test_simple_model_predict_stats(function_ranker, workload_functions): assert stats["call_count"] == 1 assert stats["own_time_ns"] == 2289000 assert stats["cumulative_time_ns"] == 4017000 - assert stats["ttx_score"] == 4017000 + assert stats["addressable_time_ns"] == 4017000 - # Test ttX score calculation - ttx_score = function_ranker.get_function_ttx_score(predict_func) - # Expected ttX score: own_time + (time_in_callees / call_count) + # Test addressable time calculation + addressable_time = function_ranker.get_function_addressable_time(predict_func) + # Expected addressable time: own_time + (time_in_callees / call_count) # = 2289000 + ((4017000 - 2289000) / 1) = 
4017000 - assert ttx_score == 4017000 + assert addressable_time == 4017000 # Test importance calculation for predict function total_program_time = sum( diff --git a/tests/test_tracer.py b/tests/test_tracer.py index b00449100..b9b8a7b26 100644 --- a/tests/test_tracer.py +++ b/tests/test_tracer.py @@ -104,10 +104,10 @@ def test_tracer_disabled_by_environment(self, trace_config: TraceConfig) -> None """Test that tracer is disabled when CODEFLASH_TRACER_DISABLE is set.""" with patch.dict("os.environ", {"CODEFLASH_TRACER_DISABLE": "1"}): tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) assert tracer.disable is True @@ -120,10 +120,10 @@ def dummy_profiler(_frame: object, _event: str, _arg: object) -> object: sys.setprofile(dummy_profiler) try: tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) assert tracer.disable is True finally: @@ -132,17 +132,16 @@ def dummy_profiler(_frame: object, _event: str, _arg: object) -> object: def test_tracer_initialization_normal(self, trace_config: TraceConfig) -> None: """Test normal tracer initialization.""" tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, functions=["test_func"], max_function_count=128, timeout=10, ) assert tracer.disable is False - assert tracer.output_file == trace_config.trace_file.resolve() assert tracer.functions == ["test_func"] assert tracer.max_function_count == 128 assert tracer.timeout == 10 @@ -152,35 +151,37 @@ def test_tracer_initialization_normal(self, trace_config: TraceConfig) -> None: def test_tracer_timeout_validation(self, trace_config: TraceConfig) -> None: with pytest.raises(AssertionError): Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, timeout=0, ) with pytest.raises(AssertionError): Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, timeout=-5, ) def test_tracer_context_manager_disabled(self, trace_config: TraceConfig) -> None: tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, disable=True, ) with tracer: pass - assert not trace_config.trace_file.exists() + # When disabled, the tracer doesn't create a trace file + # Note: output_file attribute won't exist when disabled, so we check if disable is True + assert tracer.disable is True def test_tracer_function_filtering(self, trace_config: TraceConfig) -> None: """Test that tracer respects function filtering.""" @@ -194,10 +195,10 @@ def other_function() -> int: return 24 tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, 
result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, functions=["test_function"], ) @@ -205,8 +206,8 @@ def other_function() -> int: test_function() other_function() - if trace_config.trace_file.exists(): - con = sqlite3.connect(trace_config.trace_file) + if tracer.output_file.exists(): + con = sqlite3.connect(tracer.output_file) cursor = con.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") @@ -224,10 +225,10 @@ def counting_function(n: int) -> int: return n * 2 tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, max_function_count=3, ) @@ -243,10 +244,10 @@ def slow_function() -> str: return "done" tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, timeout=1, # 1 second timeout ) @@ -261,10 +262,10 @@ def thread_function(n: int) -> None: results.append(n * 2) tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) with tracer: @@ -282,10 +283,10 @@ def thread_function(n: int) -> None: def test_simulate_call(self, trace_config: TraceConfig) -> None: """Test the simulate_call method.""" tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) tracer.simulate_call("test_simulation") @@ -293,10 +294,10 @@ def test_simulate_call(self, trace_config: TraceConfig) -> None: def test_simulate_cmd_complete(self, trace_config: TraceConfig) -> None: """Test the simulate_cmd_complete method.""" tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) tracer.simulate_call("test") @@ -305,10 +306,10 @@ def test_simulate_cmd_complete(self, trace_config: TraceConfig) -> None: def test_runctx_method(self, trace_config: TraceConfig) -> None: """Test the runctx method for executing code with tracing.""" tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) global_vars = {"x": 10} @@ -338,10 +339,10 @@ def static_method() -> str: return "static" tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) with tracer: @@ -350,8 +351,8 @@ def static_method() -> str: class_result = TestClass.class_method() static_result = TestClass.static_method() - if trace_config.trace_file.exists(): - con = sqlite3.connect(trace_config.trace_file) + if tracer.output_file.exists(): + con = sqlite3.connect(tracer.output_file) cursor = con.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'") @@ -378,10 +379,10 @@ def 
failing_function() -> None: raise ValueError("Test exception") tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) with tracer, contextlib.suppress(ValueError): @@ -394,10 +395,10 @@ def complex_function( return len(data_dict) + len(nested_list) tracer = Tracer( - output=str(trace_config.trace_file), config=trace_config.trace_config, project_root=trace_config.project_root, result_pickle_file_path=trace_config.result_pickle_file_path, + command=trace_config.command, ) expected_dict = {"key": "value", "nested": {"inner": "data"}} @@ -410,8 +411,8 @@ def complex_function( pickled = pickle.load(trace_config.result_pickle_file_path.open("rb")) assert pickled["replay_test_file_path"].exists() - if trace_config.trace_file.exists(): - con = sqlite3.connect(trace_config.trace_file) + if tracer.output_file.exists(): + con = sqlite3.connect(tracer.output_file) cursor = con.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='function_calls'")
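
A minimal, self-contained sketch of the addressable-time ranking that `FunctionRanker.rank_functions` implements in this patch. It is illustrative only: the `ProfiledFunction` dataclass, the helper names, and the 0.001 threshold (standing in for `DEFAULT_IMPORTANCE_THRESHOLD`) are assumptions made for the example, not code from the diff. The real implementation reads these numbers from `ProfileStats` and computes importance relative to the target file's own time.

```python
from __future__ import annotations

from dataclasses import dataclass

IMPORTANCE_THRESHOLD = 0.001  # illustrative stand-in for DEFAULT_IMPORTANCE_THRESHOLD


@dataclass
class ProfiledFunction:
    name: str
    own_time_ns: int          # time spent in the function body itself
    cumulative_time_ns: int   # own time plus time spent in all callees
    call_count: int


def addressable_time_ns(func: ProfiledFunction) -> float:
    # addressable_time = own_time + (time_in_callees / call_count)
    time_in_callees = func.cumulative_time_ns - func.own_time_ns
    return func.own_time_ns + time_in_callees / func.call_count


def rank_by_addressable_time(funcs: list[ProfiledFunction]) -> list[ProfiledFunction]:
    # Importance filter: drop functions whose own time is a negligible share of the
    # candidate set's total own time (the patch computes this share file-relatively).
    total_own = sum(f.own_time_ns for f in funcs if f.own_time_ns > 0)
    if total_own == 0:
        return sorted(funcs, key=addressable_time_ns, reverse=True)
    important = [
        f for f in funcs
        if f.own_time_ns > 0 and f.own_time_ns / total_own >= IMPORTANCE_THRESHOLD
    ]
    return sorted(important, key=addressable_time_ns, reverse=True)


if __name__ == "__main__":
    sample = [
        ProfiledFunction("funcA", own_time_ns=63_000, cumulative_time_ns=5_443_000, call_count=1),
        ProfiledFunction("helper", own_time_ns=2_000, cumulative_time_ns=2_500, call_count=10),
    ]
    for f in rank_by_addressable_time(sample):
        print(f"{f.name}: {addressable_time_ns(f):,.0f} ns addressable")
```

The `funcA` numbers mirror the expectations in `tests/test_function_ranker.py`: own time of 63,000 ns and cumulative time of 5,443,000 ns over one call give an addressable time of 63,000 + (5,443,000 - 63,000) / 1 = 5,443,000 ns, which is why it ranks first.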