Skip to content

Commit 63b6e77

Browse files
mashraf-222HeshamHM28KRRT7
authored
saving the API key correctly for windows PowerShell (#940)
* remove emoji * save cf api key correctly for powershell * fix test_shell_utils test * fix linting and tests * FIX ALL TESTS * revert tests/test_trace_benchmarks.py --------- Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com> Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
1 parent 024ef1a commit 63b6e77

File tree

4 files changed

+200
-37
lines changed

4 files changed

+200
-37
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from codeflash.code_utils.git_utils import get_git_remotes, get_repo_owner_and_name
3333
from codeflash.code_utils.github_utils import get_github_secrets_page_url
3434
from codeflash.code_utils.oauth_handler import perform_oauth_signin
35-
from codeflash.code_utils.shell_utils import get_shell_rc_path, save_api_key_to_rc
35+
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell, save_api_key_to_rc
3636
from codeflash.either import is_successful
3737
from codeflash.lsp.helpers import is_LSP_enabled
3838
from codeflash.telemetry.posthog_cf import ph
@@ -136,7 +136,10 @@ def init_codeflash() -> None:
136136
completion_message += (
137137
"\n\n🐚 Don't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
138138
)
139-
reload_cmd = f"call {get_shell_rc_path()}" if os.name == "nt" else f"source {get_shell_rc_path()}"
139+
if os.name == "nt":
140+
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
141+
else:
142+
reload_cmd = f"source {get_shell_rc_path()}"
140143
completion_message += f"\nOr run: {reload_cmd}"
141144

142145
completion_panel = Panel(
@@ -1087,7 +1090,7 @@ def configure_pyproject_toml(
10871090

10881091
with toml_path.open("w", encoding="utf8") as pyproject_file:
10891092
pyproject_file.write(tomlkit.dumps(pyproject_data))
1090-
click.echo(f"Added Codeflash configuration to {toml_path}")
1093+
click.echo(f"Added Codeflash configuration to {toml_path}")
10911094
click.echo()
10921095
return True
10931096

@@ -1264,7 +1267,8 @@ def enter_api_key_and_save_to_rc() -> None:
12641267
browser_launched = True # This does not work on remote consoles
12651268
shell_rc_path = get_shell_rc_path()
12661269
if not shell_rc_path.exists() and os.name == "nt":
1267-
# On Windows, create a batch file in the user's home directory (not auto-run, just used to store api key)
1270+
# On Windows, create the appropriate file (PowerShell .ps1 or CMD .bat) in the user's home directory
1271+
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
12681272
shell_rc_path.touch()
12691273
click.echo(f"✅ Created {shell_rc_path}")
12701274
get_user_id(api_key=api_key) # Used to verify whether the API key is valid.

codeflash/code_utils/env_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,26 @@ def get_codeflash_api_key() -> str:
5959
# Check environment variable first
6060
env_api_key = os.environ.get("CODEFLASH_API_KEY")
6161
shell_api_key = read_api_key_from_shell_config()
62-
62+
logger.debug(
63+
f"env_utils.py:get_codeflash_api_key - env_api_key: {'***' + env_api_key[-4:] if env_api_key else None}, shell_api_key: {'***' + shell_api_key[-4:] if shell_api_key else None}"
64+
)
6365
# If we have an env var but it's not in shell config, save it for persistence
6466
if env_api_key and not shell_api_key:
6567
try:
6668
from codeflash.either import is_successful
6769

70+
logger.debug("env_utils.py:get_codeflash_api_key - Saving API key from environment to shell config")
6871
result = save_api_key_to_rc(env_api_key)
6972
if is_successful(result):
70-
logger.debug(f"Automatically saved API key from environment to shell config: {result.unwrap()}")
73+
logger.debug(
74+
f"env_utils.py:get_codeflash_api_key - Automatically saved API key from environment to shell config: {result.unwrap()}"
75+
)
76+
else:
77+
logger.debug(f"env_utils.py:get_codeflash_api_key - Failed to save API key: {result.failure()}")
7178
except Exception as e:
72-
logger.debug(f"Failed to automatically save API key to shell config: {e}")
79+
logger.debug(
80+
f"env_utils.py:get_codeflash_api_key - Failed to automatically save API key to shell config: {e}"
81+
)
7382

7483
# Prefer the shell configuration over environment variables for lsp,
7584
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart
Lines changed: 179 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,107 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import os
45
import re
56
from pathlib import Path
67
from typing import TYPE_CHECKING, Optional
78

9+
from codeflash.cli_cmds.console import logger
810
from codeflash.code_utils.compat import LF
911
from codeflash.either import Failure, Success
1012

1113
if TYPE_CHECKING:
1214
from codeflash.either import Result
1315

14-
if os.name == "nt": # Windows
15-
SHELL_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
16-
SHELL_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
17-
else:
18-
SHELL_RC_EXPORT_PATTERN = re.compile(
19-
r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE
20-
)
21-
SHELL_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
16+
# PowerShell patterns and prefixes
17+
POWERSHELL_RC_EXPORT_PATTERN = re.compile(
18+
r'^\$env:CODEFLASH_API_KEY\s*=\s*(?:"|\')?(cf-[^\s"\']+)(?:"|\')?\s*$', re.MULTILINE
19+
)
20+
POWERSHELL_RC_EXPORT_PREFIX = "$env:CODEFLASH_API_KEY = "
21+
22+
# CMD/Batch patterns and prefixes
23+
CMD_RC_EXPORT_PATTERN = re.compile(r"^set CODEFLASH_API_KEY=(cf-.*)$", re.MULTILINE)
24+
CMD_RC_EXPORT_PREFIX = "set CODEFLASH_API_KEY="
25+
26+
# Unix shell patterns and prefixes
27+
UNIX_RC_EXPORT_PATTERN = re.compile(r'^(?!#)export CODEFLASH_API_KEY=(?:"|\')?(cf-[^\s"\']+)(?:"|\')?$', re.MULTILINE)
28+
UNIX_RC_EXPORT_PREFIX = "export CODEFLASH_API_KEY="
29+
30+
31+
def is_powershell() -> bool:
32+
"""Detect if we're running in PowerShell on Windows.
33+
34+
Uses multiple heuristics:
35+
1. PSModulePath environment variable (PowerShell always sets this)
36+
2. COMSPEC pointing to powershell.exe
37+
3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell)
38+
"""
39+
if os.name != "nt":
40+
return False
41+
42+
# Primary check: PSMODULEPATH is set by PowerShell
43+
# This is the most reliable indicator as PowerShell always sets this
44+
ps_module_path = os.environ.get("PSMODULEPATH")
45+
if ps_module_path:
46+
logger.debug("shell_utils.py:is_powershell - Detected PowerShell via PSModulePath")
47+
return True
48+
49+
# Secondary check: COMSPEC points to PowerShell
50+
comspec = os.environ.get("COMSPEC", "").lower()
51+
if "powershell" in comspec:
52+
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via COMSPEC: {comspec}")
53+
return True
54+
55+
# Tertiary check: Windows Terminal often uses PowerShell by default
56+
# But we only use this if other indicators are ambiguous
57+
term_program = os.environ.get("TERM_PROGRAM", "").lower()
58+
# Check if we can find evidence of CMD (cmd.exe in COMSPEC)
59+
# If not, assume PowerShell for Windows Terminal
60+
if "windows" in term_program and "terminal" in term_program and "cmd.exe" not in comspec:
61+
logger.debug(f"shell_utils.py:is_powershell - Detected PowerShell via Windows Terminal (COMSPEC: {comspec})")
62+
return True
63+
64+
logger.debug(f"shell_utils.py:is_powershell - Not PowerShell (COMSPEC: {comspec}, TERM_PROGRAM: {term_program})")
65+
return False
2266

2367

2468
def read_api_key_from_shell_config() -> Optional[str]:
69+
"""Read API key from shell configuration file."""
70+
shell_rc_path = get_shell_rc_path()
71+
# Ensure shell_rc_path is a Path object for consistent handling
72+
if not isinstance(shell_rc_path, Path):
73+
shell_rc_path = Path(shell_rc_path)
74+
75+
# Determine the correct pattern to use based on the file extension and platform
76+
if os.name == "nt": # Windows
77+
pattern = POWERSHELL_RC_EXPORT_PATTERN if shell_rc_path.suffix == ".ps1" else CMD_RC_EXPORT_PATTERN
78+
else: # Unix-like
79+
pattern = UNIX_RC_EXPORT_PATTERN
80+
2581
try:
26-
shell_rc_path = get_shell_rc_path()
27-
with open(shell_rc_path, encoding="utf8") as shell_rc: # noqa: PTH123
82+
# Convert Path to string using as_posix() for cross-platform path compatibility
83+
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)
84+
with open(shell_rc_path_str, encoding="utf8") as shell_rc: # noqa: PTH123
2885
shell_contents = shell_rc.read()
29-
matches = SHELL_RC_EXPORT_PATTERN.findall(shell_contents)
30-
return matches[-1] if matches else None
86+
matches = pattern.findall(shell_contents)
87+
if matches:
88+
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Found API key in file: {shell_rc_path}")
89+
return matches[-1]
90+
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - No API key found in file: {shell_rc_path}")
91+
return None
3192
except FileNotFoundError:
93+
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - File not found: {shell_rc_path}")
94+
return None
95+
except Exception as e:
96+
logger.debug(f"shell_utils.py:read_api_key_from_shell_config - Error reading file: {e}")
3297
return None
3398

3499

35100
def get_shell_rc_path() -> Path:
36101
"""Get the path to the user's shell configuration file."""
37-
if os.name == "nt": # on Windows, we use a batch file in the user's home directory
102+
if os.name == "nt": # Windows
103+
if is_powershell():
104+
return Path.home() / "codeflash_env.ps1"
38105
return Path.home() / "codeflash_env.bat"
39106
shell = os.environ.get("SHELL", "/bin/bash").split("/")[-1]
40107
shell_rc_filename = {"zsh": ".zshrc", "ksh": ".kshrc", "csh": ".cshrc", "tcsh": ".cshrc", "dash": ".profile"}.get(
@@ -44,40 +111,123 @@ def get_shell_rc_path() -> Path:
44111

45112

46113
def get_api_key_export_line(api_key: str) -> str:
47-
return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"'
114+
"""Get the appropriate export line based on the shell type."""
115+
if os.name == "nt": # Windows
116+
if is_powershell():
117+
return f'{POWERSHELL_RC_EXPORT_PREFIX}"{api_key}"'
118+
return f'{CMD_RC_EXPORT_PREFIX}"{api_key}"'
119+
# Unix-like
120+
return f'{UNIX_RC_EXPORT_PREFIX}"{api_key}"'
48121

49122

50123
def save_api_key_to_rc(api_key: str) -> Result[str, str]:
124+
"""Save API key to the appropriate shell configuration file."""
51125
shell_rc_path = get_shell_rc_path()
126+
# Ensure shell_rc_path is a Path object for consistent handling
127+
if not isinstance(shell_rc_path, Path):
128+
shell_rc_path = Path(shell_rc_path)
52129
api_key_line = get_api_key_export_line(api_key)
130+
131+
logger.debug(f"shell_utils.py:save_api_key_to_rc - Saving API key to: {shell_rc_path}")
132+
logger.debug(f"shell_utils.py:save_api_key_to_rc - API key line format: {api_key_line[:30]}...")
133+
134+
# Determine the correct pattern to use for replacement
135+
if os.name == "nt": # Windows
136+
if is_powershell():
137+
pattern = POWERSHELL_RC_EXPORT_PATTERN
138+
logger.debug("shell_utils.py:save_api_key_to_rc - Using PowerShell pattern")
139+
else:
140+
pattern = CMD_RC_EXPORT_PATTERN
141+
logger.debug("shell_utils.py:save_api_key_to_rc - Using CMD pattern")
142+
else: # Unix-like
143+
pattern = UNIX_RC_EXPORT_PATTERN
144+
logger.debug("shell_utils.py:save_api_key_to_rc - Using Unix pattern")
145+
53146
try:
54-
with open(shell_rc_path, "r+", encoding="utf8") as shell_file: # noqa: PTH123
55-
shell_contents = shell_file.read()
56-
if os.name == "nt" and not shell_contents: # on windows we're writing to a batch file
147+
# Create directory if it doesn't exist (ignore errors - file operation will fail if needed)
148+
# Directory creation failed, but we'll still try to open the file
149+
# The file operation itself will raise the appropriate exception if there are permission issues
150+
with contextlib.suppress(OSError, PermissionError):
151+
shell_rc_path.parent.mkdir(parents=True, exist_ok=True)
152+
153+
# Convert Path to string using as_posix() for cross-platform path compatibility
154+
shell_rc_path_str = shell_rc_path.as_posix() if isinstance(shell_rc_path, Path) else str(shell_rc_path)
155+
156+
# Try to open in r+ mode (read and write in single operation)
157+
# Handle FileNotFoundError if file doesn't exist (r+ requires file to exist)
158+
try:
159+
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
160+
shell_contents = shell_file.read()
161+
logger.debug(f"shell_utils.py:save_api_key_to_rc - Read existing file, length: {len(shell_contents)}")
162+
163+
# Initialize empty file with header for batch files if needed
164+
if not shell_contents:
165+
logger.debug("shell_utils.py:save_api_key_to_rc - File is empty, initializing")
166+
if os.name == "nt" and not is_powershell():
167+
shell_contents = "@echo off"
168+
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")
169+
170+
# Check if API key already exists in the current file
171+
matches = pattern.findall(shell_contents)
172+
existing_in_file = bool(matches)
173+
logger.debug(f"shell_utils.py:save_api_key_to_rc - Existing key in file: {existing_in_file}")
174+
175+
if existing_in_file:
176+
# Replace the existing API key line in this file
177+
updated_shell_contents = re.sub(pattern, api_key_line, shell_contents)
178+
action = "Updated CODEFLASH_API_KEY in"
179+
logger.debug("shell_utils.py:save_api_key_to_rc - Replaced existing API key")
180+
else:
181+
# Append the new API key line
182+
if shell_contents and not shell_contents.endswith(LF):
183+
updated_shell_contents = shell_contents + LF + api_key_line + LF
184+
else:
185+
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
186+
action = "Added CODEFLASH_API_KEY to"
187+
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key")
188+
189+
# Write the updated contents
190+
shell_file.seek(0)
191+
shell_file.write(updated_shell_contents)
192+
shell_file.truncate()
193+
except FileNotFoundError:
194+
# File doesn't exist, create it first with initial content
195+
logger.debug("shell_utils.py:save_api_key_to_rc - File does not exist, creating new")
196+
shell_contents = ""
197+
# Initialize with header for batch files if needed
198+
if os.name == "nt" and not is_powershell():
57199
shell_contents = "@echo off"
58-
existing_api_key = read_api_key_from_shell_config()
200+
logger.debug("shell_utils.py:save_api_key_to_rc - Added @echo off header for batch file")
201+
202+
# Create the file by opening in write mode
203+
with open(shell_rc_path_str, "w", encoding="utf8") as shell_file: # noqa: PTH123
204+
shell_file.write(shell_contents)
59205

60-
if existing_api_key:
61-
# Replace the existing API key line
62-
updated_shell_contents = re.sub(SHELL_RC_EXPORT_PATTERN, api_key_line, shell_contents)
63-
action = "Updated CODEFLASH_API_KEY in"
64-
else:
206+
# Re-open in r+ mode to add the API key (r+ allows both read and write)
207+
with open(shell_rc_path_str, "r+", encoding="utf8") as shell_file: # noqa: PTH123
65208
# Append the new API key line
66209
updated_shell_contents = shell_contents.rstrip() + f"{LF}{api_key_line}{LF}"
67210
action = "Added CODEFLASH_API_KEY to"
211+
logger.debug("shell_utils.py:save_api_key_to_rc - Appended new API key to new file")
212+
213+
# Write the updated contents
214+
shell_file.seek(0)
215+
shell_file.write(updated_shell_contents)
216+
shell_file.truncate()
217+
218+
logger.debug(f"shell_utils.py:save_api_key_to_rc - Successfully wrote to {shell_rc_path}")
68219

69-
shell_file.seek(0)
70-
shell_file.write(updated_shell_contents)
71-
shell_file.truncate()
72220
return Success(f"✅ {action} {shell_rc_path}")
73-
except PermissionError:
221+
except PermissionError as e:
222+
logger.debug(f"shell_utils.py:save_api_key_to_rc - Permission error: {e}")
74223
return Failure(
75224
f"💡 I tried adding your Codeflash API key to {shell_rc_path} - but seems like I don't have permissions to do so.{LF}"
76225
f"You'll need to open it yourself and add the following line:{LF}{LF}{api_key_line}{LF}"
77226
)
78-
except FileNotFoundError:
227+
except Exception as e:
228+
logger.debug(f"shell_utils.py:save_api_key_to_rc - Error: {e}")
79229
return Failure(
80-
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but noticed that it doesn't exist.{LF}"
230+
f"💡 I went to save your Codeflash API key to {shell_rc_path}, but encountered an error: {e}{LF}"
81231
f"To ensure your Codeflash API key is automatically loaded into your environment at startup, you can create {shell_rc_path} and add the following line:{LF}"
82232
f"{LF}{api_key_line}{LF}"
83233
)

tests/test_shell_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def test_save_api_key_to_rc_success(self, mock_get_shell_rc_path, mock_file):
1515
api_key = "cf-12345"
1616
result = save_api_key_to_rc(api_key)
1717
self.assertTrue(isinstance(result, Success))
18-
mock_file.assert_called_with("/fake/path/.bashrc", encoding="utf8")
18+
mock_file.assert_called_with("/fake/path/.bashrc", "r+", encoding="utf8")
1919
handle = mock_file()
2020
handle.write.assert_called_once()
2121
handle.truncate.assert_called_once()

0 commit comments

Comments
 (0)