diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 8743ab2ac..f86ffe077 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -625,6 +625,60 @@ def get_optimization_review( console.rule() return "" + def generate_workflow_steps( + self, + repo_files: dict[str, str], + directory_structure: dict[str, Any], + codeflash_config: dict[str, Any] | None = None, + ) -> str | None: + """Generate GitHub Actions workflow steps based on repository analysis. + + :param repo_files: Dictionary mapping file paths to their contents + :param directory_structure: 2-level nested directory structure + :param codeflash_config: Optional codeflash configuration + :return: YAML string for workflow steps section, or None on error + """ + payload = { + "repo_files": repo_files, + "directory_structure": directory_structure, + "codeflash_config": codeflash_config, + } + + logger.debug( + f"[aiservice.py:generate_workflow_steps] Sending request to AI service with {len(repo_files)} files, " + f"{len(directory_structure)} top-level directories" + ) + + try: + response = self.make_ai_service_request("/workflow-gen", payload=payload, timeout=60) + except requests.exceptions.RequestException as e: + # AI service unavailable - this is expected, will fall back to static workflow + logger.debug( + f"[aiservice.py:generate_workflow_steps] Request exception (falling back to static workflow): {e}" + ) + return None + + if response.status_code == 200: + response_data = response.json() + workflow_steps = cast("str", response_data.get("workflow_steps")) + logger.debug( + f"[aiservice.py:generate_workflow_steps] Successfully received workflow steps " + f"({len(workflow_steps) if workflow_steps else 0} chars)" + ) + return workflow_steps + # AI service unavailable or endpoint not found - this is expected, will fall back to static workflow + logger.debug( + f"[aiservice.py:generate_workflow_steps] AI service returned status {response.status_code}, " + f"falling back to static workflow generation" + ) + try: + error_response = response.json() + error = cast("str", error_response.get("error", "Unknown error")) + logger.debug(f"[aiservice.py:generate_workflow_steps] Error: {error}") + except Exception: + logger.debug("[aiservice.py:generate_workflow_steps] Could not parse error response") + return None + class LocalAiServiceClient(AiServiceClient): """Client for interacting with the local AI service.""" diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index fad6eaa4d..688f35278 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -55,13 +55,17 @@ def make_cfapi_request( *, api_key: str | None = None, suppress_errors: bool = False, + params: dict[str, Any] | None = None, ) -> Response: """Make an HTTP request using the specified method, URL, headers, and JSON payload. :param endpoint: The endpoint URL to send the request to. :param method: The HTTP method to use ('GET', 'POST', etc.). :param payload: Optional JSON payload to include in the POST request body. + :param extra_headers: Optional extra headers to include in the request. + :param api_key: Optional API key to use for authentication. :param suppress_errors: If True, suppress error logging for HTTP errors. + :param params: Optional query parameters for GET requests. :return: The response object from the API. """ url = f"{get_cfapi_base_urls().cfapi_base_url}/cfapi{endpoint}" @@ -75,7 +79,7 @@ def make_cfapi_request( cfapi_headers["Content-Type"] = "application/json" response = requests.post(url, data=json_payload, headers=cfapi_headers, timeout=60) else: - response = requests.get(url, headers=cfapi_headers, timeout=60) + response = requests.get(url, headers=cfapi_headers, params=params, timeout=60) response.raise_for_status() return response # noqa: TRY300 except requests.exceptions.HTTPError: @@ -239,6 +243,20 @@ def create_pr( return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload) +def setup_github_actions(owner: str, repo: str, base_branch: str, workflow_content: str) -> Response: + """Set up GitHub Actions workflow by creating a PR with the workflow file. + + :param owner: Repository owner (username or organization) + :param repo: Repository name + :param base_branch: Base branch to create PR against (e.g., "main", "master") + :param workflow_content: Content of the GitHub Actions workflow file (YAML) + :return: Response object with pr_url and pr_number on success + """ + payload = {"owner": owner, "repo": repo, "baseBranch": base_branch, "workflowContent": workflow_content} + + return make_cfapi_request(endpoint="/setup-github-actions", method="POST", payload=payload) + + def create_staging( original_code: dict[Path, str], new_code: dict[Path, str], diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index e63f52b77..c1960a7cc 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -21,7 +21,8 @@ from rich.table import Table from rich.text import Text -from codeflash.api.cfapi import get_user_id, is_github_app_installed_on_repo +from codeflash.api.aiservice import AiServiceClient +from codeflash.api.cfapi import get_user_id, is_github_app_installed_on_repo, setup_github_actions from codeflash.cli_cmds.cli_common import apologize_and_exit from codeflash.cli_cmds.console import console, logger from codeflash.cli_cmds.extension import install_vscode_extension @@ -29,7 +30,7 @@ from codeflash.code_utils.compat import LF from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.env_utils import check_formatter_installed, get_codeflash_api_key -from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name +from codeflash.code_utils.git_utils import get_current_branch, get_git_remotes, get_repo_owner_and_name from codeflash.code_utils.github_utils import get_github_secrets_page_url from codeflash.code_utils.oauth_handler import perform_oauth_signin from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell, save_api_key_to_rc @@ -679,6 +680,61 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n workflows_path = git_root / ".github" / "workflows" optimize_yaml_path = workflows_path / "codeflash.yaml" + # Check if workflow file already exists locally BEFORE showing prompt + if optimize_yaml_path.exists(): + # Workflow file already exists locally - skip prompt and setup + already_exists_message = "βœ… GitHub Actions workflow file already exists.\n\n" + already_exists_message += "No changes needed - your repository is already configured!" + + already_exists_panel = Panel( + Text(already_exists_message, style="green", justify="center"), + title="βœ… Already Configured", + border_style="bright_green", + ) + console.print(already_exists_panel) + console.print() + + logger.info("[cmd_init.py:install_github_actions] Workflow file already exists locally, skipping setup") + return + + # Get repository information for API call + git_remote = config.get("git_remote", "origin") + # get_current_branch handles detached HEAD and other edge cases internally + try: + base_branch = get_current_branch(repo) + except Exception as e: + logger.warning( + f"[cmd_init.py:install_github_actions] Could not determine current branch: {e}. Falling back to 'main'." + ) + base_branch = "main" + + # Generate workflow content + from importlib.resources import files + + benchmark_mode = False + benchmarks_root = config.get("benchmarks_root", "").strip() + if benchmarks_root and benchmarks_root != "": + benchmark_panel = Panel( + Text( + "πŸ“Š Benchmark Mode Available\n\n" + "I noticed you've configured a benchmarks_root in your config. " + "Benchmark mode will show the performance impact of Codeflash's optimizations on your benchmarks.", + style="cyan", + ), + title="πŸ“Š Benchmark Mode", + border_style="bright_cyan", + ) + console.print(benchmark_panel) + console.print() + + benchmark_questions = [ + inquirer.Confirm("benchmark_mode", message="Run GitHub Actions in benchmark mode?", default=True) + ] + + benchmark_answers = inquirer.prompt(benchmark_questions, theme=CodeflashTheme()) + benchmark_mode = benchmark_answers["benchmark_mode"] if benchmark_answers else False + + # Show prompt only if workflow doesn't exist locally actions_panel = Panel( Text( "πŸ€– GitHub Actions Setup\n\n" @@ -692,32 +748,11 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n console.print(actions_panel) console.print() - # Check if the workflow file already exists - if optimize_yaml_path.exists(): - overwrite_questions = [ - inquirer.Confirm( - "confirm_overwrite", - message=f"GitHub Actions workflow already exists at {optimize_yaml_path}. Overwrite?", - default=False, - ) - ] - - overwrite_answers = inquirer.prompt(overwrite_questions, theme=CodeflashTheme()) - if not overwrite_answers or not overwrite_answers["confirm_overwrite"]: - skip_panel = Panel( - Text("⏩️ Skipping workflow creation.", style="yellow"), title="⏩️ Skipped", border_style="yellow" - ) - console.print(skip_panel) - ph("cli-github-workflow-skipped") - return - ph( - "cli-github-optimization-confirm-workflow-overwrite", - {"confirm_overwrite": overwrite_answers["confirm_overwrite"]}, - ) - creation_questions = [ inquirer.Confirm( - "confirm_creation", message="Set up GitHub Actions for continuous optimization?", default=True + "confirm_creation", + message="Set up GitHub Actions for continuous optimization? We'll open a pull request with the workflow file.", + default=True, ) ] @@ -733,60 +768,285 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n "cli-github-optimization-confirm-workflow-creation", {"confirm_creation": creation_answers["confirm_creation"]}, ) + + # Generate workflow content AFTER user confirmation + logger.info("[cmd_init.py:install_github_actions] User confirmed, generating workflow content...") + optimize_yml_content = ( + files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8") + ) + materialized_optimize_yml_content = generate_dynamic_workflow_content( + optimize_yml_content, config, git_root, benchmark_mode + ) + workflows_path.mkdir(parents=True, exist_ok=True) - from importlib.resources import files - benchmark_mode = False - benchmarks_root = config.get("benchmarks_root", "").strip() - if benchmarks_root and benchmarks_root != "": - benchmark_panel = Panel( + pr_created_via_api = False + pr_url = None + + try: + owner, repo_name = get_repo_owner_and_name(repo, git_remote) + except Exception as e: + logger.error(f"[cmd_init.py:install_github_actions] Failed to get repository owner and name: {e}") + # Fall back to local file creation + workflows_path.mkdir(parents=True, exist_ok=True) + with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file: + optimize_yml_file.write(materialized_optimize_yml_content) + workflow_success_panel = Panel( Text( - "πŸ“Š Benchmark Mode Available\n\n" - "I noticed you've configured a benchmarks_root in your config. " - "Benchmark mode will show the performance impact of Codeflash's optimizations on your benchmarks.", - style="cyan", + f"βœ… Created GitHub action workflow at {optimize_yaml_path}\n\n" + "Your repository is now configured for continuous optimization!", + style="green", + justify="center", ), - title="πŸ“Š Benchmark Mode", - border_style="bright_cyan", + title="πŸŽ‰ Workflow Created!", + border_style="bright_green", ) - console.print(benchmark_panel) + console.print(workflow_success_panel) console.print() + else: + # Try to create PR via API + try: + # Workflow file doesn't exist on remote or content differs - proceed with PR creation + console.print("Creating PR with GitHub Actions workflow...") + logger.info( + f"[cmd_init.py:install_github_actions] Calling setup_github_actions API for {owner}/{repo_name} on branch {base_branch}" + ) - benchmark_questions = [ - inquirer.Confirm("benchmark_mode", message="Run GitHub Actions in benchmark mode?", default=True) - ] + response = setup_github_actions( + owner=owner, + repo=repo_name, + base_branch=base_branch, + workflow_content=materialized_optimize_yml_content, + ) - benchmark_answers = inquirer.prompt(benchmark_questions, theme=CodeflashTheme()) - benchmark_mode = benchmark_answers["benchmark_mode"] if benchmark_answers else False + if response.status_code == 200: + response_data = response.json() + if response_data.get("success"): + pr_url = response_data.get("pr_url") + + if pr_url: + pr_created_via_api = True + success_message = f"βœ… PR created: {pr_url}\n\n" + success_message += "Your repository is now configured for continuous optimization!" + + workflow_success_panel = Panel( + Text(success_message, style="green", justify="center"), + title="πŸŽ‰ Workflow PR Created!", + border_style="bright_green", + ) + console.print(workflow_success_panel) + console.print() + + logger.info( + f"[cmd_init.py:install_github_actions] Successfully created PR #{response_data.get('pr_number')} for {owner}/{repo_name}" + ) + else: + # File already exists with same content + pr_created_via_api = True # Mark as handled (no PR needed) + already_exists_message = "βœ… Workflow file already exists with the same content.\n\n" + already_exists_message += "No changes needed - your repository is already configured!" + + already_exists_panel = Panel( + Text(already_exists_message, style="green", justify="center"), + title="βœ… Already Configured", + border_style="bright_green", + ) + console.print(already_exists_panel) + console.print() + else: + # API returned success=false, extract error details + error_data = response_data + error_msg = error_data.get("error", "Unknown error") + error_message = error_data.get("message", error_msg) + error_help = error_data.get("help", "") + installation_url = error_data.get("installation_url") + + # For permission errors, don't fall back - show a focused message and abort early + if response.status_code == 403: + logger.error( + f"[cmd_init.py:install_github_actions] Permission denied for {owner}/{repo_name}" + ) + # Extract installation_url if available, otherwise use default + installation_url_403 = error_data.get( + "installation_url", "https://github.com/apps/codeflash-ai/installations/select_target" + ) + + permission_error_panel = Panel( + Text( + "❌ Access Denied\n\n" + f"The GitHub App may not be installed on {owner}/{repo_name}, or it doesn't have the required permissions.\n\n" + "πŸ’‘ To fix this:\n" + "1. Install the CodeFlash GitHub App on your repository\n" + "2. Ensure the app has 'Contents: write', 'Workflows: write', and 'Pull requests: write' permissions\n" + "3. Make sure you have write access to the repository\n\n" + f"πŸ”— Install GitHub App: {installation_url_403}", + style="red", + ), + title="❌ Setup Failed", + border_style="red", + ) + console.print(permission_error_panel) + console.print() + click.echo( + f"Please install the CodeFlash GitHub App and ensure it has the required permissions.{LF}" + f"Visit: {installation_url_403}{LF}" + ) + apologize_and_exit() + + # Show detailed error panel for all other errors + error_panel_text = f"❌ {error_msg}\n\n{error_message}\n" + if error_help: + error_panel_text += f"\nπŸ’‘ {error_help}\n" + if installation_url: + error_panel_text += f"\nπŸ”— Install GitHub App: {installation_url}" + + error_panel = Panel( + Text(error_panel_text, style="red"), title="❌ Setup Failed", border_style="red" + ) + console.print(error_panel) + console.print() + + # For GitHub App not installed, don't fall back - show clear instructions + if response.status_code == 404 and installation_url: + logger.error( + f"[cmd_init.py:install_github_actions] GitHub App not installed on {owner}/{repo_name}" + ) + click.echo( + f"Please install the CodeFlash GitHub App on your repository to continue.{LF}" + f"Visit: {installation_url}{LF}" + ) + return + + # For other errors, fall back to local file creation + raise Exception(error_message) # noqa: TRY002, TRY301 + else: + # API call returned non-200 status, try to parse error response + try: + error_data = response.json() + error_msg = error_data.get("error", "API request failed") + error_message = error_data.get("message", f"API returned status {response.status_code}") + error_help = error_data.get("help", "") + installation_url = error_data.get("installation_url") + + # For permission errors, don't fall back - show a focused message and abort early + if response.status_code == 403: + logger.error( + f"[cmd_init.py:install_github_actions] Permission denied for {owner}/{repo_name}" + ) + # Extract installation_url if available, otherwise use default + installation_url_403 = error_data.get( + "installation_url", "https://github.com/apps/codeflash-ai/installations/select_target" + ) + + permission_error_panel = Panel( + Text( + "❌ Access Denied\n\n" + f"The GitHub App may not be installed on {owner}/{repo_name}, or it doesn't have the required permissions.\n\n" + "πŸ’‘ To fix this:\n" + "1. Install the CodeFlash GitHub App on your repository\n" + "2. Ensure the app has 'Contents: write', 'Workflows: write', and 'Pull requests: write' permissions\n" + "3. Make sure you have write access to the repository\n\n" + f"πŸ”— Install GitHub App: {installation_url_403}", + style="red", + ), + title="❌ Setup Failed", + border_style="red", + ) + console.print(permission_error_panel) + console.print() + click.echo( + f"Please install the CodeFlash GitHub App and ensure it has the required permissions.{LF}" + f"Visit: {installation_url_403}{LF}" + ) + apologize_and_exit() + + # Show detailed error panel for all other errors + error_panel_text = f"❌ {error_msg}\n\n{error_message}\n" + if error_help: + error_panel_text += f"\nπŸ’‘ {error_help}\n" + if installation_url: + error_panel_text += f"\nπŸ”— Install GitHub App: {installation_url}" + + error_panel = Panel( + Text(error_panel_text, style="red"), title="❌ Setup Failed", border_style="red" + ) + console.print(error_panel) + console.print() + + # For GitHub App not installed, don't fall back - show clear instructions + if response.status_code == 404 and installation_url: + logger.error( + f"[cmd_init.py:install_github_actions] GitHub App not installed on {owner}/{repo_name}" + ) + click.echo( + f"Please install the CodeFlash GitHub App on your repository to continue.{LF}" + f"Visit: {installation_url}{LF}" + ) + return + + # For authentication errors, don't fall back + if response.status_code == 401: + logger.error( + f"[cmd_init.py:install_github_actions] Authentication failed for {owner}/{repo_name}" + ) + click.echo(f"Authentication failed. Please check your API key and try again.{LF}") + return + + # For other errors, fall back to local file creation + raise Exception(error_message) # noqa: TRY002 + except (ValueError, KeyError) as parse_error: + # Couldn't parse error response, use generic message + status_msg = f"API returned status {response.status_code}" + raise Exception(status_msg) from parse_error # noqa: TRY002 + + except Exception as api_error: + # Fall back to local file creation if API call fails (for non-critical errors) + logger.warning( + f"[cmd_init.py:install_github_actions] API call failed, falling back to local file creation: {api_error}" + ) + workflows_path.mkdir(parents=True, exist_ok=True) + with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file: + optimize_yml_file.write(materialized_optimize_yml_content) + workflow_success_panel = Panel( + Text( + f"βœ… Created GitHub action workflow at {optimize_yaml_path}\n\n" + "Your repository is now configured for continuous optimization!", + style="green", + justify="center", + ), + title="πŸŽ‰ Workflow Created!", + border_style="bright_green", + ) + console.print(workflow_success_panel) + console.print() - optimize_yml_content = ( - files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8") - ) - materialized_optimize_yml_content = customize_codeflash_yaml_content( - optimize_yml_content, config, git_root, benchmark_mode - ) - with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file: - optimize_yml_file.write(materialized_optimize_yml_content) - # Success panel for workflow creation - workflow_success_panel = Panel( - Text( - f"βœ… Created GitHub action workflow at {optimize_yaml_path}\n\n" - "Your repository is now configured for continuous optimization!", - style="green", - justify="center", - ), - title="πŸŽ‰ Workflow Created!", - border_style="bright_green", - ) - console.print(workflow_success_panel) - console.print() + # Show appropriate message based on whether PR was created via API + if pr_created_via_api: + if pr_url: + click.echo( + f"πŸš€ Codeflash is now configured to automatically optimize new Github PRs!{LF}" + f"Once you merge the PR, the workflow will be active.{LF}" + ) + else: + # File already exists + click.echo( + f"πŸš€ Codeflash is now configured to automatically optimize new Github PRs!{LF}" + f"The workflow is ready to use.{LF}" + ) + else: + # Fell back to local file creation + click.echo( + f"Please edit, commit and push this GitHub actions file to your repo, and you're all set!{LF}" + f"πŸš€ Codeflash is now configured to automatically optimize new Github PRs!{LF}" + ) + # Show GitHub secrets setup panel (needed in both cases - PR created via API or local file) try: existing_api_key = get_codeflash_api_key() except OSError: existing_api_key = None - # GitHub secrets setup panel + # GitHub secrets setup panel - always shown since secrets are required for the workflow to work secrets_message = ( "πŸ” Next Step: Add API Key as GitHub Secret\n\n" "You'll need to add your CODEFLASH_API_KEY as a secret to your GitHub repository.\n\n" @@ -823,11 +1083,7 @@ def install_github_actions(override_formatter_check: bool = False) -> None: # n ) console.print(launch_panel) click.pause() - click.echo() - click.echo( - f"Please edit, commit and push this GitHub actions file to your repo, and you're all set!{LF}" - f"πŸš€ Codeflash is now configured to automatically optimize new Github PRs!{LF}" - ) + console.print() ph("cli-github-workflow-created") except KeyboardInterrupt: apologize_and_exit() @@ -879,7 +1135,9 @@ def get_dependency_installation_commands(dep_manager: DependencyManager) -> tupl pip install poetry poetry install --all-extras""" if dep_manager == DependencyManager.UV: - return "uv sync --all-extras" + return """| + uv sync --all-extras + uv pip install --upgrade codeflash""" # PIP or UNKNOWN return """| python -m pip install --upgrade pip @@ -910,6 +1168,215 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str: working-directory: ./{working_dir}""" +def collect_repo_files_for_workflow(git_root: Path) -> dict[str, Any]: + """Collect important repository files and directory structure for workflow generation. + + :param git_root: Root directory of the git repository + :return: Dictionary with 'files' (path -> content) and 'directory_structure' (nested dict) + """ + # Important files to collect with contents + important_files = [ + "pyproject.toml", + "requirements.txt", + "requirements-dev.txt", + "requirements/requirements.txt", + "requirements/dev.txt", + "Pipfile", + "Pipfile.lock", + "poetry.lock", + "uv.lock", + "setup.py", + "setup.cfg", + "Dockerfile", + "docker-compose.yml", + "docker-compose.yaml", + "Makefile", + "README.md", + "README.rst", + ] + + # Also collect GitHub workflows + workflows_path = git_root / ".github" / "workflows" + if workflows_path.exists(): + important_files.extend( + str(workflow_file.relative_to(git_root)) for workflow_file in workflows_path.glob("*.yml") + ) + important_files.extend( + str(workflow_file.relative_to(git_root)) for workflow_file in workflows_path.glob("*.yaml") + ) + + files_dict: dict[str, str] = {} + max_file_size = 8 * 1024 # 8KB limit per file + + for file_path_str in important_files: + file_path = git_root / file_path_str + if file_path.exists() and file_path.is_file(): + try: + content = file_path.read_text(encoding="utf-8", errors="ignore") + # Limit file size + if len(content) > max_file_size: + content = content[:max_file_size] + "\n... (truncated)" + files_dict[file_path_str] = content + except Exception as e: + logger.warning(f"[cmd_init.py:collect_repo_files_for_workflow] Failed to read {file_path_str}: {e}") + + # Collect 2-level directory structure + directory_structure: dict[str, Any] = {} + try: + for item in sorted(git_root.iterdir()): + if item.name.startswith(".") and item.name not in [".github", ".git"]: + continue # Skip hidden files/folders except .github + + if item.is_dir(): + # Level 1: directory + dir_dict: dict[str, Any] = {"type": "directory", "contents": {}} + try: + # Level 2: contents of directory + for subitem in sorted(item.iterdir()): + if subitem.name.startswith("."): + continue + if subitem.is_dir(): + dir_dict["contents"][subitem.name] = {"type": "directory"} + else: + dir_dict["contents"][subitem.name] = {"type": "file"} + except PermissionError: + pass # Skip directories we can't read + directory_structure[item.name] = dir_dict + elif item.is_file(): + directory_structure[item.name] = {"type": "file"} + except Exception as e: + logger.warning(f"[cmd_init.py:collect_repo_files_for_workflow] Error collecting directory structure: {e}") + + return {"files": files_dict, "directory_structure": directory_structure} + + +def generate_dynamic_workflow_content( + optimize_yml_content: str, + config: tuple[dict[str, Any], Path], + git_root: Path, + benchmark_mode: bool = False, # noqa: FBT001, FBT002 +) -> str: + """Generate workflow content with dynamic steps from AI service, falling back to static template. + + :param optimize_yml_content: Base workflow template content + :param config: Codeflash configuration tuple (dict, Path) + :param git_root: Root directory of the git repository + :param benchmark_mode: Whether to enable benchmark mode + :return: Complete workflow YAML content + """ + # First, do the basic replacements that are always needed + module_path = str(Path(config["module_root"]).relative_to(git_root) / "**") + optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path) + + # Get working directory + toml_path = Path.cwd() / "pyproject.toml" + try: + with toml_path.open(encoding="utf8") as pyproject_file: + pyproject_data = tomlkit.parse(pyproject_file.read()) + except FileNotFoundError: + click.echo( + f"I couldn't find a pyproject.toml in the current directory.{LF}" + f"Please create a new empty pyproject.toml file here, OR if you use poetry then run `poetry init`, OR run `codeflash init` again from a directory with an existing pyproject.toml file." + ) + apologize_and_exit() + + working_dir = get_github_action_working_directory(toml_path, git_root) + optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir) + + # Try to generate dynamic steps using AI service + try: + repo_data = collect_repo_files_for_workflow(git_root) + + # Prepare codeflash config for AI + codeflash_config = { + "module_root": config["module_root"], + "tests_root": config.get("tests_root", ""), + "benchmark_mode": benchmark_mode, + } + + aiservice_client = AiServiceClient() + dynamic_steps = aiservice_client.generate_workflow_steps( + repo_files=repo_data["files"], + directory_structure=repo_data["directory_structure"], + codeflash_config=codeflash_config, + ) + + if dynamic_steps: + # Replace the entire steps section with AI-generated steps + # Find the steps section in the template + steps_start = optimize_yml_content.find(" steps:") + if steps_start != -1: + # Find the end of the steps section (next line at same or less indentation) + lines = optimize_yml_content.split("\n") + steps_start_line = optimize_yml_content[:steps_start].count("\n") + steps_end_line = len(lines) + + # Find where steps section ends (next job or end of file) + for i in range(steps_start_line + 1, len(lines)): + line = lines[i] + # Stop if we hit a line that's not indented (new job or end of jobs) + if line and not line.startswith(" ") and not line.startswith("\t"): + steps_end_line = i + break + + # Extract steps content from AI response (remove "steps:" prefix if present) + steps_content = dynamic_steps + if steps_content.startswith("steps:"): + # Remove "steps:" and leading newline + steps_content = steps_content[6:].lstrip("\n") + + # Ensure proper indentation (8 spaces for steps section in YAML) + indented_steps = [] + for line in steps_content.split("\n"): + if line.strip(): + # If line doesn't start with enough spaces, add them + if not line.startswith(" "): + indented_steps.append(" " + line) + else: + # Preserve existing indentation but ensure minimum 8 spaces + current_indent = len(line) - len(line.lstrip()) + if current_indent < 8: + indented_steps.append(" " * 8 + line.lstrip()) + else: + indented_steps.append(line) + else: + indented_steps.append("") + + # Add codeflash command step at the end + dep_manager = determine_dependency_manager(pyproject_data) + codeflash_cmd = get_codeflash_github_action_command(dep_manager) + if benchmark_mode: + codeflash_cmd += " --benchmark" + + # Format codeflash command properly + if "|" in codeflash_cmd: + # Multi-line command + cmd_lines = codeflash_cmd.split("\n") + codeflash_step = f" - name: ⚑️Codeflash Optimization\n run: {cmd_lines[0].strip()}" + for cmd_line in cmd_lines[1:]: + codeflash_step += f"\n {cmd_line.strip()}" + else: + codeflash_step = f" - name: ⚑️Codeflash Optimization\n run: {codeflash_cmd}" + + indented_steps.append(codeflash_step) + + # Reconstruct the workflow + return "\n".join([*lines[:steps_start_line], " steps:", *indented_steps, *lines[steps_end_line:]]) + logger.warning("[cmd_init.py:generate_dynamic_workflow_content] Could not find steps section in template") + else: + logger.debug( + "[cmd_init.py:generate_dynamic_workflow_content] AI service returned no steps, falling back to static" + ) + + except Exception as e: + logger.warning( + f"[cmd_init.py:generate_dynamic_workflow_content] Error generating dynamic workflow, falling back to static: {e}" + ) + + # Fallback to static template + return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode) + + def customize_codeflash_yaml_content( optimize_yml_content: str, config: tuple[dict[str, Any], Path], diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index 40a725692..9abe86403 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -70,12 +70,43 @@ def get_git_diff( def get_current_branch(repo: Repo | None = None) -> str: """Return the name of the current branch in the given repository. + Handles detached HEAD state and other edge cases by falling back to + the default branch (main or master) or "main" if no default branch exists. + :param repo: An optional Repo object. If not provided, the function will search for a repository in the current and parent directories. - :return: The name of the current branch. + :return: The name of the current branch, or "main" if HEAD is detached or + the branch cannot be determined. """ repository: Repo = repo if repo else git.Repo(search_parent_directories=True) - return repository.active_branch.name + + # Check if HEAD is detached (active_branch will be None) + if repository.head.is_detached: + logger.warning( + "HEAD is detached. Cannot determine current branch. Falling back to 'main'. " + "Consider checking out a branch before running Codeflash." + ) + # Try to find the default branch (main or master) + for default_branch in ["main", "master"]: + try: + if default_branch in repository.branches: + logger.info(f"Using '{default_branch}' as fallback branch.") + return default_branch + except Exception as e: + logger.debug(f"Error checking for branch '{default_branch}': {e}") + continue + # If no default branch found, return "main" as a safe default + return "main" + + # HEAD is not detached, safe to access active_branch + try: + return repository.active_branch.name + except (AttributeError, TypeError) as e: + logger.warning( + f"Could not determine active branch: {e}. Falling back to 'main'. " + "This may indicate the repository is in an unusual state." + ) + return "main" def get_remote_url(repo: Repo | None = None, git_remote: str | None = "origin") -> str: @@ -126,8 +157,19 @@ def confirm_proceeding_with_no_git_repo() -> str | bool: def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", *, wait_for_push: bool = False) -> bool: - current_branch = repo.active_branch - current_branch_name = current_branch.name + # Check if HEAD is detached + if repo.head.is_detached: + logger.warning("⚠️ HEAD is detached. Cannot push branch. Please check out a branch before creating a PR.") + return False + + # Safe to access active_branch when HEAD is not detached + try: + current_branch = repo.active_branch + current_branch_name = current_branch.name + except (AttributeError, TypeError) as e: + logger.warning(f"⚠️ Could not determine active branch: {e}. Cannot push branch.") + return False + remote = repo.remote(name=git_remote) # Check if the branch is pushed diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index dd40521cc..b82f87ac3 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -67,6 +67,8 @@ def test_check_running_in_git_repo_not_in_git_repo_non_interactive(self, mock_is @patch("codeflash.code_utils.git_utils.Confirm.ask", return_value=True) def test_check_and_push_branch(self, mock_confirm, mock_isatty, mock_repo): mock_repo_instance = mock_repo.return_value + # Mock HEAD not being detached + mock_repo_instance.head.is_detached = False mock_repo_instance.active_branch.name = "test-branch" mock_repo_instance.refs = [] @@ -87,6 +89,8 @@ def test_check_and_push_branch(self, mock_confirm, mock_isatty, mock_repo): @patch("codeflash.code_utils.git_utils.sys.__stdin__.isatty", return_value=False) def test_check_and_push_branch_non_tty(self, mock_isatty, mock_repo): mock_repo_instance = mock_repo.return_value + # Mock HEAD not being detached + mock_repo_instance.head.is_detached = False mock_repo_instance.active_branch.name = "test-branch" mock_repo_instance.refs = [] @@ -97,6 +101,19 @@ def test_check_and_push_branch_non_tty(self, mock_isatty, mock_repo): mock_origin.push.assert_not_called() mock_origin.push.reset_mock() + @patch("codeflash.code_utils.git_utils.git.Repo") + def test_check_and_push_branch_detached_head(self, mock_repo): + mock_repo_instance = mock_repo.return_value + # Mock HEAD being detached + mock_repo_instance.head.is_detached = True + + mock_origin = mock_repo_instance.remote.return_value + mock_origin.push.return_value = None + + # Should return False when HEAD is detached + assert not check_and_push_branch(mock_repo_instance) + mock_origin.push.assert_not_called() + if __name__ == "__main__": unittest.main()