Skip to content

Commit b7ccd5b

Browse files
authored
Merge branch 'main' into import-analyser-fix
2 parents 67ae044 + ba06d24 commit b7ccd5b

File tree

8 files changed

+43
-26
lines changed

8 files changed

+43
-26
lines changed

codeflash/api/cfapi.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import json
66
import os
7-
import sys
87
from functools import lru_cache
98
from pathlib import Path
109
from typing import TYPE_CHECKING, Any, Optional
@@ -89,12 +88,13 @@ def make_cfapi_request(
8988

9089

9190
@lru_cache(maxsize=1)
92-
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
91+
def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911
9392
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
9493
9594
:param api_key: The API key to use. If None, uses get_codeflash_api_key().
9695
:return: The userid or None if the request fails.
9796
"""
97+
lsp_enabled = is_LSP_enabled()
9898
if not api_key and not ensure_codeflash_api_key():
9999
return None
100100

@@ -115,19 +115,21 @@ def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
115115
if min_version and version.parse(min_version) > version.parse(__version__):
116116
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
117117
console.print(f"[bold red]{msg}[/bold red]")
118-
if is_LSP_enabled():
118+
if lsp_enabled:
119119
logger.debug(msg)
120120
return f"Error: {msg}"
121-
sys.exit(1)
121+
exit_with_message(msg, error_on_exit=True)
122122
return userid
123123

124124
logger.error("Failed to retrieve userid from the response.")
125125
return None
126126

127-
# Handle 403 (Invalid API key) - exit with error message
128127
if response.status_code == 403:
128+
error_title = "Invalid Codeflash API key. The API key you provided is not valid."
129+
if lsp_enabled:
130+
return f"Error: {error_title}"
129131
msg = (
130-
"Invalid Codeflash API key. The API key you provided is not valid.\n"
132+
f"{error_title}\n"
131133
"Please generate a new one at https://app.codeflash.ai/app/apikeys ,\n"
132134
"then set it as a CODEFLASH_API_KEY environment variable.\n"
133135
"For more information, refer to the documentation at \n"

codeflash/code_utils/code_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from codeflash.cli_cmds.console import logger, paneled_text
1919
from codeflash.code_utils.config_parser import find_pyproject_toml, get_all_closest_config_files
20+
from codeflash.lsp.helpers import is_LSP_enabled
2021

2122
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2223

@@ -352,6 +353,10 @@ def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
352353

353354

354355
def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
356+
"""Don't Call it inside the lsp process, it will terminate the lsp server."""
357+
if is_LSP_enabled():
358+
logger.error(message)
359+
return
355360
paneled_text(message, panel_args={"style": "red"})
356361

357362
sys.exit(1 if error_on_exit else 0)

codeflash/code_utils/env_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from codeflash.code_utils.code_utils import exit_with_message
1414
from codeflash.code_utils.formatter import format_code
1515
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
16+
from codeflash.lsp.helpers import is_LSP_enabled
1617

1718

1819
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
@@ -70,7 +71,10 @@ def get_codeflash_api_key() -> str:
7071
except Exception as e:
7172
logger.debug(f"Failed to automatically save API key to shell config: {e}")
7273

73-
api_key = env_api_key or shell_api_key
74+
# Prefer the shell configuration over environment variables for lsp,
75+
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart
76+
# within the same process, the environment variable could become outdated.
77+
api_key = shell_api_key or env_api_key if is_LSP_enabled() else env_api_key or shell_api_key
7478

7579
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
7680
if not api_key:

codeflash/lsp/beta.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_functions_within_git_diff,
3030
)
3131
from codeflash.either import is_successful
32+
from codeflash.lsp.context import execution_context_vars
3233
from codeflash.lsp.features.perform_optimization import get_cancelled_reponse, sync_perform_optimization
3334
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
3435

@@ -71,11 +72,6 @@ class ValidateProjectParams:
7172
skip_validation: bool = False
7273

7374

74-
@dataclass
75-
class OnPatchAppliedParams:
76-
task_id: str
77-
78-
7975
@dataclass
8076
class OptimizableFunctionsInCommitParams:
8177
commit_hash: str
@@ -90,6 +86,11 @@ class WriteConfigParams:
9086
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
9187

9288

89+
@server.feature("server/listFeatures")
90+
def list_features(_params: any) -> list[str]:
91+
return list(server.protocol.fm.features)
92+
93+
9394
@server.feature("getOptimizableFunctionsInCurrentDiff")
9495
def get_functions_in_current_git_diff(_params: OptimizableFunctionsParams) -> dict[str, str | dict[str, list[str]]]:
9596
functions = get_functions_within_git_diff(uncommitted_changes=True)
@@ -250,7 +251,7 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]:
250251
"existingConfig": config,
251252
}
252253

253-
args = _init()
254+
args = process_args()
254255
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
255256

256257

@@ -268,8 +269,9 @@ def _check_api_key_validity(api_key: Optional[str]) -> dict[str, str]:
268269
if user_id is None:
269270
return {"status": "error", "message": "api key not found or invalid"}
270271

271-
if user_id.startswith("Error: "):
272-
error_msg = user_id[7:]
272+
error_prefix = "Error: "
273+
if user_id.startswith(error_prefix):
274+
error_msg = user_id[len(error_prefix) :]
273275
return {"status": "error", "message": error_msg}
274276

275277
return {"status": "success", "user_id": user_id}
@@ -336,12 +338,12 @@ def provide_api_key(params: ProvideApiKeyParams) -> dict[str, str]:
336338
def execution_context(**kwargs: str) -> None:
337339
"""Temporarily set context values for the current async task."""
338340
# Create a fresh copy per use
339-
current = {**server.execution_context_vars.get(), **kwargs}
340-
token = server.execution_context_vars.set(current)
341+
current = {**execution_context_vars.get(), **kwargs}
342+
token = execution_context_vars.set(current)
341343
try:
342344
yield
343345
finally:
344-
server.execution_context_vars.reset(token)
346+
execution_context_vars.reset(token)
345347

346348

347349
@server.feature("cleanupCurrentOptimizerSession")

codeflash/lsp/context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from __future__ import annotations
2+
3+
import contextvars
4+
5+
# Shared execution context for tracking task IDs and other metadata
6+
execution_context_vars: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
7+
"execution_context_vars",
8+
default={}, # noqa: B039
9+
)

codeflash/lsp/lsp_message.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def type(self) -> str:
4444
def serialize(self) -> str:
4545
if not is_LSP_enabled():
4646
return ""
47-
from codeflash.lsp.beta import server
47+
from codeflash.lsp.context import execution_context_vars
4848

49-
execution_ctx = server.execution_context_vars.get()
49+
execution_ctx = execution_context_vars.get()
5050
current_task_id = execution_ctx.get("task_id", None)
5151
data = self._loop_through(asdict(self))
5252
ordered = {"type": self.type(), "task_id": current_task_id, **data}

codeflash/lsp/server.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import contextvars
43
from typing import TYPE_CHECKING
54

65
from lsprotocol.types import LogMessageParams, MessageType
@@ -25,10 +24,6 @@ def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerPro
2524
self.optimizer: Optimizer | None = None
2625
self.args = None
2726
self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None
28-
self.execution_context_vars: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
29-
"execution_context_vars",
30-
default={}, # noqa: B039
31-
)
3227

3328
def prepare_optimizer_arguments(self, config_file: Path) -> None:
3429
from codeflash.cli_cmds.cli import parse_args

codeflash/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# These version placeholders will be replaced by uv-dynamic-versioning during build.
2-
__version__ = "0.18.1"
2+
__version__ = "0.18.2"

0 commit comments

Comments
 (0)