Skip to content

Commit 9c670db

Browse files
authored
Merge branch 'main' into feat/cli/login
2 parents 8d7bc7f + 848faa5 commit 9c670db

23 files changed

+1667
-876
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,25 +249,29 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
249249

250250

251251
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
252-
if hasattr(args, "all"):
253-
import git
254-
255-
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
256-
from codeflash.code_utils.github_utils import require_github_app_or_exit
257-
258-
# Ensure that the user can actually open PRs on the repo.
259-
try:
260-
git_repo = git.Repo(search_parent_directories=True)
261-
except git.exc.InvalidGitRepositoryError:
262-
logger.exception(
263-
"I couldn't find a git repository in the current directory. "
264-
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
265-
)
266-
apologize_and_exit()
267-
if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote):
268-
exit_with_message("Branch is not pushed...", error_on_exit=True)
269-
owner, repo = get_repo_owner_and_name(git_repo)
270-
if not args.no_pr:
252+
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
253+
no_pr = getattr(args, "no_pr", False)
254+
255+
if not no_pr:
256+
import git
257+
258+
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
259+
from codeflash.code_utils.github_utils import require_github_app_or_exit
260+
261+
# Ensure that the user can actually open PRs on the repo.
262+
try:
263+
git_repo = git.Repo(search_parent_directories=True)
264+
except git.exc.InvalidGitRepositoryError:
265+
mode = "--all" if hasattr(args, "all") else "--file"
266+
logger.exception(
267+
f"I couldn't find a git repository in the current directory. "
268+
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
269+
)
270+
apologize_and_exit()
271+
git_remote = getattr(args, "git_remote", None)
272+
if not check_and_push_branch(git_repo, git_remote=git_remote):
273+
exit_with_message("Branch is not pushed...", error_on_exit=True)
274+
owner, repo = get_repo_owner_and_name(git_repo)
271275
require_github_app_or_exit(owner, repo)
272276
if not hasattr(args, "all"):
273277
args.all = None

codeflash/cli_cmds/cmd_init.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,23 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
172172
run_end_to_end_test(args, file_path)
173173

174174

175-
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911
175+
def config_found(pyproject_toml_path: Union[str, Path]) -> tuple[bool, str]:
176+
pyproject_toml_path = Path(pyproject_toml_path)
177+
176178
if not pyproject_toml_path.exists():
177-
return False, None, f"Configuration file not found: {pyproject_toml_path}"
179+
return False, f"Configuration file not found: {pyproject_toml_path}"
180+
181+
if not pyproject_toml_path.is_file():
182+
return False, f"Configuration file is not a file: {pyproject_toml_path}"
183+
184+
if pyproject_toml_path.suffix != ".toml":
185+
return False, f"Configuration file is not a .toml file: {pyproject_toml_path}"
186+
187+
return True, ""
178188

189+
190+
def is_valid_pyproject_toml(pyproject_toml_path: Union[str, Path]) -> tuple[bool, dict[str, Any] | None, str]:
191+
pyproject_toml_path = Path(pyproject_toml_path)
179192
try:
180193
config, _ = parse_config_file(pyproject_toml_path)
181194
except Exception as e:
@@ -207,6 +220,10 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
207220

208221
pyproject_toml_path = Path.cwd() / "pyproject.toml"
209222

223+
found, _ = config_found(pyproject_toml_path)
224+
if not found:
225+
return True, None
226+
210227
valid, config, _message = is_valid_pyproject_toml(pyproject_toml_path)
211228
if not valid:
212229
# needs to be re-configured
@@ -1407,7 +1424,7 @@ def ask_for_telemetry() -> bool:
14071424
from rich.prompt import Confirm
14081425

14091426
return Confirm.ask(
1410-
"⚡️ Would you like to enable telemetry to help us improve the Codeflash experience?",
1427+
"⚡️ Help us improve Codeflash by sharing anonymous usage data (e.g. errors encountered)?",
14111428
default=True,
14121429
show_default=True,
14131430
)

codeflash/code_utils/config_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def parse_config_file(
105105
if lsp_mode:
106106
# don't fail in lsp mode if codeflash config is not found.
107107
return {}, config_file_path
108-
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to create the config file."
108+
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to add Codeflash config in the pyproject.toml config file."
109109
raise ValueError(msg) from e
110110
assert isinstance(config, dict)
111111

codeflash/code_utils/formatter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,9 @@ def is_diff_line(line: str) -> bool:
9797

9898

9999
def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str:
100+
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
101+
if formatter_name == "disabled": # nothing to do if no formatter provided
102+
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
100103
with tempfile.TemporaryDirectory() as test_dir_str:
101104
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines
102105
original_temp = Path(test_dir_str) / "original_temp.py"

codeflash/code_utils/git_worktree_utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import configparser
34
import subprocess
45
import tempfile
56
import time
@@ -18,14 +19,36 @@
1819

1920
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
2021
repository = git.Repo(worktree_dir, search_parent_directories=True)
21-
with repository.config_writer() as cw:
22+
username = None
23+
no_username = False
24+
email = None
25+
no_email = False
26+
with repository.config_reader(config_level="repository") as cr:
27+
try:
28+
username = cr.get("user", "name")
29+
except (configparser.NoSectionError, configparser.NoOptionError):
30+
no_username = True
31+
try:
32+
email = cr.get("user", "email")
33+
except (configparser.NoSectionError, configparser.NoOptionError):
34+
no_email = True
35+
with repository.config_writer(config_level="repository") as cw:
2236
if not cw.has_option("user", "name"):
2337
cw.set_value("user", "name", "Codeflash Bot")
2438
if not cw.has_option("user", "email"):
2539
cw.set_value("user", "email", "bot@codeflash.ai")
2640

2741
repository.git.add(".")
2842
repository.git.commit("-m", commit_message, "--no-verify")
43+
with repository.config_writer(config_level="repository") as cw:
44+
if username:
45+
cw.set_value("user", "name", username)
46+
elif no_username:
47+
cw.remove_option("user", "name")
48+
if email:
49+
cw.set_value("user", "email", email)
50+
elif no_email:
51+
cw.remove_option("user", "email")
2952

3053

3154
def create_detached_worktree(module_root: Path) -> Optional[Path]:

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -684,27 +684,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
684684
)
685685

686686

687-
def instrument_source_module_with_async_decorators(
688-
source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
689-
) -> tuple[bool, str | None]:
690-
if not function_to_optimize.is_async:
691-
return False, None
692-
693-
try:
694-
with source_path.open(encoding="utf8") as f:
695-
source_code = f.read()
696-
697-
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
698-
699-
if decorator_added:
700-
return True, modified_code
701-
702-
except Exception:
703-
return False, None
704-
else:
705-
return False, None
706-
707-
708687
def inject_async_profiling_into_existing_test(
709688
test_path: Path,
710689
call_positions: list[CodePosition],
@@ -1288,25 +1267,29 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
12881267

12891268

12901269
def add_async_decorator_to_function(
1291-
source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1292-
) -> tuple[str, bool]:
1293-
"""Add async decorator to an async function definition.
1270+
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1271+
) -> bool:
1272+
"""Add async decorator to an async function definition and write back to file.
12941273
12951274
Args:
12961275
----
1297-
source_code: The source code to modify.
1276+
source_path: Path to the source file to modify in-place.
12981277
function: The FunctionToOptimize object representing the target async function.
12991278
mode: The testing mode to determine which decorator to apply.
13001279
13011280
Returns:
13021281
-------
1303-
Tuple of (modified_source_code, was_decorator_added).
1282+
Boolean indicating whether the decorator was successfully added.
13041283
13051284
"""
13061285
if not function.is_async:
1307-
return source_code, False
1286+
return False
13081287

13091288
try:
1289+
# Read source code
1290+
with source_path.open(encoding="utf8") as f:
1291+
source_code = f.read()
1292+
13101293
module = cst.parse_module(source_code)
13111294

13121295
# Add the decorator to the function
@@ -1318,10 +1301,17 @@ def add_async_decorator_to_function(
13181301
import_transformer = AsyncDecoratorImportAdder(mode)
13191302
module = module.visit(import_transformer)
13201303

1321-
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
1304+
modified_code = sort_imports(code=module.code, float_to_top=True)
13221305
except Exception as e:
13231306
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
1324-
return source_code, False
1307+
return False
1308+
else:
1309+
if decorator_transformer.added_decorator:
1310+
with source_path.open("w", encoding="utf8") as f:
1311+
f.write(modified_code)
1312+
logger.debug(f"Applied async {mode.value} instrumentation to {source_path}")
1313+
return True
1314+
return False
13251315

13261316

13271317
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:

codeflash/context/unused_definition_remover.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -469,22 +469,32 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
469469
qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname'
470470
471471
"""
472-
module = cst.parse_module(code)
473-
# Collect all definitions (top level classes, variables or function)
474-
definitions = collect_top_level_definitions(module)
472+
try:
473+
module = cst.parse_module(code)
474+
except Exception as e:
475+
logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}")
476+
return code
475477

476-
# Collect dependencies between definitions using the visitor pattern
477-
dependency_collector = DependencyCollector(definitions)
478-
module.visit(dependency_collector)
478+
try:
479+
# Collect all definitions (top level classes, variables or function)
480+
definitions = collect_top_level_definitions(module)
479481

480-
# Mark definitions used by specified functions, and their dependencies recursively
481-
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
482-
usage_marker.mark_used_definitions()
482+
# Collect dependencies between definitions using the visitor pattern
483+
dependency_collector = DependencyCollector(definitions)
484+
module.visit(dependency_collector)
483485

484-
# Apply the recursive removal transformation
485-
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
486+
# Mark definitions used by specified functions, and their dependencies recursively
487+
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
488+
usage_marker.mark_used_definitions()
486489

487-
return modified_module.code if modified_module else ""
490+
# Apply the recursive removal transformation
491+
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
492+
493+
return modified_module.code if modified_module else "" # noqa: TRY300
494+
except Exception as e:
495+
# If any other error occurs during processing, return the original code
496+
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
497+
return code
488498

489499

490500
def print_definitions(definitions: dict[str, UsageInfo]) -> None:

0 commit comments

Comments
 (0)