Skip to content

Commit 6696f07

Browse files
authored
fix the tracer (#884)
* parse args correctly When there were fewer than 4 test files, the pytest_split() function returned a flat list of strings instead of a list of lists * update python path correctly * improve messaging here * Revert "improve messaging here" This reverts commit b6ab255. * improve error slightly
1 parent ed065b8 commit 6696f07

File tree

4 files changed

+29
-5
lines changed

4 files changed

+29
-5
lines changed

codeflash/code_utils/config_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
3838
dir_path = dir_path.parent
3939
msg = f"Could not find pyproject.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `codeflash init`, or pass the path to pyproject.toml with the --config-file argument."
4040

41-
raise ValueError(msg)
41+
raise ValueError(msg) from None
4242

4343

4444
def get_all_closest_config_files() -> list[Path]:
@@ -93,7 +93,7 @@ def parse_config_file(
9393
data = tomlkit.parse(f.read())
9494
except tomlkit.exceptions.ParseError as e:
9595
msg = f"Error while parsing the config file {config_file_path}. Please recheck the file for syntax errors. Error: {e}"
96-
raise ValueError(msg) from e
96+
raise ValueError(msg) from None
9797

9898
lsp_mode = is_LSP_enabled()
9999

codeflash/tracer.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import json
15+
import os
1516
import pickle
1617
import subprocess
1718
import sys
@@ -64,13 +65,15 @@ def main(args: Namespace | None = None) -> ArgumentParser:
6465
parsed_args.tracer_timeout = getattr(args, "timeout", None)
6566
parsed_args.codeflash_config = getattr(args, "config_file_path", None)
6667
parsed_args.trace_only = getattr(args, "trace_only", False)
67-
parsed_args.module = False
68+
69+
temp_parsed, unknown_args = parser.parse_known_args()
70+
parsed_args.module = temp_parsed.module
71+
sys.argv[:] = unknown_args
6872

6973
if getattr(args, "disable", False):
7074
console.rule("Codeflash: Tracer disabled by --disable option", style="bold red")
7175
return parser
7276

73-
unknown_args = []
7477
else:
7578
if not sys.argv[1:]:
7679
parser.print_usage()
@@ -127,6 +130,13 @@ def main(args: Namespace | None = None) -> ArgumentParser:
127130
else:
128131
updated_sys_argv.append(elem)
129132
args_dict["command"] = " ".join(updated_sys_argv)
133+
env = os.environ.copy()
134+
pythonpath = env.get("PYTHONPATH", "")
135+
project_root_str = str(project_root)
136+
if pythonpath:
137+
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
138+
else:
139+
env["PYTHONPATH"] = project_root_str
130140
processes.append(
131141
subprocess.Popen(
132142
[
@@ -136,6 +146,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
136146
json.dumps(args_dict),
137147
],
138148
cwd=Path.cwd(),
149+
env=env,
139150
)
140151
)
141152
for process in processes:
@@ -156,6 +167,15 @@ def main(args: Namespace | None = None) -> ArgumentParser:
156167
args_dict["output"] = str(parsed_args.outfile)
157168
args_dict["command"] = " ".join(sys.argv)
158169

170+
env = os.environ.copy()
171+
# Add project root to PYTHONPATH so imports work correctly
172+
pythonpath = env.get("PYTHONPATH", "")
173+
project_root_str = str(project_root)
174+
if pythonpath:
175+
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
176+
else:
177+
env["PYTHONPATH"] = project_root_str
178+
159179
subprocess.run(
160180
[
161181
SAFE_SYS_EXECUTABLE,
@@ -164,6 +184,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
164184
json.dumps(args_dict),
165185
],
166186
cwd=Path.cwd(),
187+
env=env,
167188
check=False,
168189
)
169190
try:

codeflash/tracing/pytest_parallelization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def pytest_split(
6262
# If we have fewer test files than 4 * num_splits, reduce num_splits
6363
max_possible_splits = len(test_files) // 4
6464
if max_possible_splits == 0:
65-
return test_files, test_paths
65+
return [test_files], test_paths
6666

6767
num_splits = min(num_splits, max_possible_splits)
6868

codeflash/update_license_version.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ def main() -> None:
99
# Use the version tuple from version.py
1010
version = __version__
1111

12+
if ".dev" in version or "+" in version or "post" in version:
13+
return
14+
1215
# Use the major and minor version components from the version tuple
1316
major_minor_version = ".".join(map(str, version.split(".")[:2]))
1417

0 commit comments

Comments
 (0)