Skip to content

Commit b90cb3c

Browse files
author
Codeflash Bot
committed
Merge remote-tracking branch 'origin/main' into generated-tests-markdown
2 parents 0b425b9 + 653fa9d commit b90cb3c

37 files changed

+1707
-311
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ jobs:
2828
- name: install dependencies
2929
run: uv sync
3030

31-
- name: Install test-only dependencies (Python 3.13)
32-
if: matrix.python-version == '3.13'
31+
- name: Install test-only dependencies (Python 3.9 and 3.13)
32+
if: matrix.python-version == '3.9' || matrix.python-version == '3.13'
3333
run: uv sync --group tests
3434

3535
- name: Unit tests

code_to_optimize/bubble_sort_method.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,27 @@ def sorter(self, arr):
1515
arr[j + 1] = temp
1616
print("stderr test", file=sys.stderr)
1717
return arr
18+
19+
@classmethod
20+
def sorter_classmethod(cls, arr):
21+
print("codeflash stdout : BubbleSorter.sorter_classmethod() called")
22+
for i in range(len(arr)):
23+
for j in range(len(arr) - 1):
24+
if arr[j] > arr[j + 1]:
25+
temp = arr[j]
26+
arr[j] = arr[j + 1]
27+
arr[j + 1] = temp
28+
print("stderr test classmethod", file=sys.stderr)
29+
return arr
30+
31+
@staticmethod
32+
def sorter_staticmethod(arr):
33+
print("codeflash stdout : BubbleSorter.sorter_staticmethod() called")
34+
for i in range(len(arr)):
35+
for j in range(len(arr) - 1):
36+
if arr[j] > arr[j + 1]:
37+
temp = arr[j]
38+
arr[j] = arr[j + 1]
39+
arr[j + 1] = temp
40+
print("stderr test staticmethod", file=sys.stderr)
41+
return arr

code_to_optimize/tests/pytest/test_topological_sort.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_topological_sort():
1010
g.addEdge(2, 3)
1111
g.addEdge(3, 1)
1212

13-
assert g.topologicalSort() == [5, 4, 2, 3, 1, 0]
13+
assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]
1414

1515

1616
def test_topological_sort_2():
@@ -20,15 +20,15 @@ def test_topological_sort_2():
2020
for j in range(i + 1, 10):
2121
g.addEdge(i, j)
2222

23-
assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
23+
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
2424

2525
g = Graph(10)
2626

2727
for i in range(10):
2828
for j in range(i + 1, 10):
2929
g.addEdge(i, j)
3030

31-
assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
31+
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
3232

3333

3434
def test_topological_sort_3():
@@ -38,4 +38,4 @@ def test_topological_sort_3():
3838
for j in range(i + 1, 1000):
3939
g.addEdge(j, i)
4040

41-
assert g.topologicalSort() == list(reversed(range(1000)))
41+
assert g.topologicalSort()[0] == list(reversed(range(1000)))

codeflash-benchmark/codeflash_benchmark/plugin.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,11 @@ def pytest_addoption(parser: pytest.Parser) -> None:
5252
parser.addoption(
5353
"--codeflash-trace", action="store_true", default=False, help="Enable CodeFlash tracing for benchmarks"
5454
)
55-
# These options are ignored when --codeflash-trace is used
56-
for option, action, default, help_text in benchmark_options:
57-
help_suffix = " (ignored when --codeflash-trace is used)"
58-
parser.addoption(option, action=action, default=default, help=help_text + help_suffix)
55+
# Only add benchmark options if pytest-benchmark is not installed for backward compatibility with existing pytest-benchmark setup
56+
if not PYTEST_BENCHMARK_INSTALLED:
57+
for option, action, default, help_text in benchmark_options:
58+
help_suffix = " (ignored when --codeflash-trace is used)"
59+
parser.addoption(option, action=action, default=default, help=help_text + help_suffix)
5960

6061

6162
@pytest.fixture

codeflash/api/aiservice.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
255255
"optimized_code_runtime": opt.optimized_code_runtime,
256256
"speedup": opt.speedup,
257257
"trace_id": opt.trace_id,
258+
"function_references": opt.function_references,
259+
"python_version": platform.python_version(),
258260
}
259261
for opt in request
260262
]
@@ -308,6 +310,7 @@ def get_new_explanation( # noqa: D417
308310
original_throughput: str | None = None,
309311
optimized_throughput: str | None = None,
310312
throughput_improvement: str | None = None,
313+
function_references: str | None = None,
311314
codeflash_version: str = codeflash_version,
312315
) -> str:
313316
"""Optimize the given python code for performance by making a request to the Django endpoint.
@@ -329,6 +332,7 @@ def get_new_explanation( # noqa: D417
329332
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
330333
- throughput_improvement: str | None - throughput improvement percentage
331334
- current codeflash version
335+
- function_references: str | None - where the function is called in the codebase
332336
333337
Returns
334338
-------
@@ -351,6 +355,7 @@ def get_new_explanation( # noqa: D417
351355
"original_throughput": original_throughput,
352356
"optimized_throughput": optimized_throughput,
353357
"throughput_improvement": throughput_improvement,
358+
"function_references": function_references,
354359
"codeflash_version": codeflash_version,
355360
}
356361
logger.info("loading|Generating explanation")
@@ -376,7 +381,12 @@ def get_new_explanation( # noqa: D417
376381
return ""
377382

378383
def generate_ranking( # noqa: D417
379-
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
384+
self,
385+
trace_id: str,
386+
diffs: list[str],
387+
optimization_ids: list[str],
388+
speedups: list[float],
389+
function_references: str | None = None,
380390
) -> list[int] | None:
381391
"""Optimize the given python code for performance by making a request to the Django endpoint.
382392
@@ -385,6 +395,7 @@ def generate_ranking( # noqa: D417
385395
- trace_id : unique uuid of function
386396
- diffs : list of unified diff strings of opt candidates
387397
- speedups : list of speedups of opt candidates
398+
- function_references : where the function is called in the codebase
388399
389400
Returns
390401
-------
@@ -397,6 +408,7 @@ def generate_ranking( # noqa: D417
397408
"speedups": speedups,
398409
"optimization_ids": optimization_ids,
399410
"python_version": platform.python_version(),
411+
"function_references": function_references,
400412
}
401413
logger.info("loading|Generating ranking")
402414
console.rule()
@@ -598,6 +610,7 @@ def get_optimization_review(
598610
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
599611
"codeflash_version": codeflash_version,
600612
"calling_fn_details": calling_fn_details,
613+
"python_version": platform.python_version(),
601614
}
602615
console.rule()
603616
try:

codeflash/api/cfapi.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import os
5-
import sys
5+
from dataclasses import dataclass
66
from functools import lru_cache
77
from pathlib import Path
88
from typing import TYPE_CHECKING, Any, Optional
@@ -13,6 +13,7 @@
1313
from pydantic.json import pydantic_encoder
1414

1515
from codeflash.cli_cmds.console import console, logger
16+
from codeflash.code_utils.code_utils import exit_with_message
1617
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
1718
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
1819
from codeflash.github.PrComment import FileDiffContent, PrComment
@@ -26,14 +27,24 @@
2627

2728
from packaging import version
2829

29-
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
30-
CFAPI_BASE_URL = "http://localhost:3001"
31-
CFWEBAPP_BASE_URL = "http://localhost:3000"
32-
logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
33-
console.rule()
34-
else:
35-
CFAPI_BASE_URL = "https://app.codeflash.ai"
36-
CFWEBAPP_BASE_URL = "https://app.codeflash.ai"
30+
31+
@dataclass
32+
class BaseUrls:
33+
cfapi_base_url: Optional[str] = None
34+
cfwebapp_base_url: Optional[str] = None
35+
36+
37+
@lru_cache(maxsize=1)
38+
def get_cfapi_base_urls() -> BaseUrls:
39+
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
40+
cfapi_base_url = "http://localhost:3001"
41+
cfwebapp_base_url = "http://localhost:3000"
42+
logger.info(f"Using local CF API at {cfapi_base_url}.")
43+
console.rule()
44+
else:
45+
cfapi_base_url = "https://app.codeflash.ai"
46+
cfwebapp_base_url = "https://app.codeflash.ai"
47+
return BaseUrls(cfapi_base_url=cfapi_base_url, cfwebapp_base_url=cfwebapp_base_url)
3748

3849

3950
def make_cfapi_request(
@@ -53,8 +64,9 @@ def make_cfapi_request(
5364
:param suppress_errors: If True, suppress error logging for HTTP errors.
5465
:return: The response object from the API.
5566
"""
56-
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
57-
cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"}
67+
url = f"{get_cfapi_base_urls().cfapi_base_url}/cfapi{endpoint}"
68+
final_api_key = api_key or get_codeflash_api_key()
69+
cfapi_headers = {"Authorization": f"Bearer {final_api_key}"}
5870
if extra_headers:
5971
cfapi_headers.update(extra_headers)
6072
try:
@@ -86,16 +98,22 @@ def make_cfapi_request(
8698

8799

88100
@lru_cache(maxsize=1)
89-
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
101+
def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911
90102
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
91103
104+
:param api_key: The API key to use. If None, uses get_codeflash_api_key().
92105
:return: The userid or None if the request fails.
93106
"""
107+
lsp_enabled = is_LSP_enabled()
94108
if not api_key and not ensure_codeflash_api_key():
95109
return None
96110

97111
response = make_cfapi_request(
98-
endpoint="/cli-get-user", method="GET", extra_headers={"cli_version": __version__}, api_key=api_key
112+
endpoint="/cli-get-user",
113+
method="GET",
114+
extra_headers={"cli_version": __version__},
115+
api_key=api_key,
116+
suppress_errors=True,
99117
)
100118
if response.status_code == 200:
101119
if "min_version" not in response.text:
@@ -107,15 +125,31 @@ def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
107125
if min_version and version.parse(min_version) > version.parse(__version__):
108126
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
109127
console.print(f"[bold red]{msg}[/bold red]")
110-
if is_LSP_enabled():
128+
if lsp_enabled:
111129
logger.debug(msg)
112130
return f"Error: {msg}"
113-
sys.exit(1)
131+
exit_with_message(msg, error_on_exit=True)
114132
return userid
115133

116134
logger.error("Failed to retrieve userid from the response.")
117135
return None
118136

137+
if response.status_code == 403:
138+
error_title = "Invalid Codeflash API key. The API key you provided is not valid."
139+
if lsp_enabled:
140+
return f"Error: {error_title}"
141+
msg = (
142+
f"{error_title}\n"
143+
"Please generate a new one at https://app.codeflash.ai/app/apikeys ,\n"
144+
"then set it as a CODEFLASH_API_KEY environment variable.\n"
145+
"For more information, refer to the documentation at \n"
146+
"https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#manual-setup\n"
147+
"or\n"
148+
"https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#automated-setup-recommended"
149+
)
150+
exit_with_message(msg, error_on_exit=True)
151+
152+
# For other errors, log and return None (backward compatibility)
119153
logger.error(f"Failed to look up your userid; is your CF API key valid? ({response.reason})")
120154
return None
121155

codeflash/cli_cmds/cmd_init.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
167167
console.rule()
168168

169169
if run_tests:
170-
bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args)
171-
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
170+
file_path = create_find_common_tags_file(args, "find_common_tags.py")
171+
run_end_to_end_test(args, file_path)
172172

173173

174174
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911
@@ -264,7 +264,7 @@ def get_valid_subdirs(current_dir: Optional[Path] = None) -> list[str]:
264264
]
265265

266266

267-
def get_suggestions(section: str) -> tuple(list[str], Optional[str]):
267+
def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
268268
valid_subdirs = get_valid_subdirs()
269269
if section == CommonSections.module_root:
270270
return [d for d in valid_subdirs if d != "tests"], None
@@ -391,7 +391,7 @@ def collect_setup_info() -> CLISetupInfo:
391391
tests_root_answer = tests_answers["tests_root"]
392392

393393
if tests_root_answer == create_for_me_option:
394-
tests_root = Path(curdir) / default_tests_subdir
394+
tests_root = Path(curdir) / (default_tests_subdir or "tests")
395395
tests_root.mkdir()
396396
click.echo(f"✅ Created directory {tests_root}{os.path.sep}{LF}")
397397
elif tests_root_answer == custom_dir_option:
@@ -1207,6 +1207,35 @@ def enter_api_key_and_save_to_rc() -> None:
12071207
os.environ["CODEFLASH_API_KEY"] = api_key
12081208

12091209

1210+
def create_find_common_tags_file(args: Namespace, file_name: str) -> Path:
1211+
find_common_tags_content = """def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:
1212+
if not articles:
1213+
return set()
1214+
1215+
common_tags = articles[0]["tags"]
1216+
for article in articles[1:]:
1217+
common_tags = [tag for tag in common_tags if tag in article["tags"]]
1218+
return set(common_tags)
1219+
"""
1220+
1221+
file_path = Path(args.module_root) / file_name
1222+
lsp_enabled = is_LSP_enabled()
1223+
if file_path.exists() and not lsp_enabled:
1224+
from rich.prompt import Confirm
1225+
1226+
overwrite = Confirm.ask(
1227+
f"🤔 {file_path} already exists. Do you want to overwrite it?", default=True, show_default=False
1228+
)
1229+
if not overwrite:
1230+
apologize_and_exit()
1231+
console.rule()
1232+
1233+
file_path.write_text(find_common_tags_content, encoding="utf8")
1234+
logger.info(f"Created demo optimization file: {file_path}")
1235+
1236+
return file_path
1237+
1238+
12101239
def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]:
12111240
bubble_sort_content = """from typing import Union, List
12121241
def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
@@ -1276,7 +1305,7 @@ def test_sort():
12761305
return str(bubble_sort_path), str(bubble_sort_test_path)
12771306

12781307

1279-
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
1308+
def run_end_to_end_test(args: Namespace, find_common_tags_path: Path) -> None:
12801309
try:
12811310
check_formatter_installed(args.formatter_cmds)
12821311
except Exception:
@@ -1285,7 +1314,7 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
12851314
)
12861315
return
12871316

1288-
command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
1317+
command = ["codeflash", "--file", "find_common_tags.py", "--function", "find_common_tags"]
12891318
if args.no_pr:
12901319
command.append("--no-pr")
12911320
if args.verbose:
@@ -1316,10 +1345,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
13161345
console.rule()
13171346
# Delete the bubble_sort.py file after the test
13181347
logger.info("🧹 Cleaning up…")
1319-
for path in [bubble_sort_path, bubble_sort_test_path]:
1320-
console.rule()
1321-
Path(path).unlink(missing_ok=True)
1322-
logger.info(f"🗑️ Deleted {path}")
1348+
find_common_tags_path.unlink(missing_ok=True)
1349+
logger.info(f"🗑️ Deleted {find_common_tags_path}")
13231350

13241351

13251352
def ask_for_telemetry() -> bool:

codeflash/cli_cmds/console.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,16 @@ def paneled_text(
8080
console.print(panel)
8181

8282

83-
def code_print(code_str: str, file_name: Optional[str] = None, function_name: Optional[str] = None) -> None:
83+
def code_print(
84+
code_str: str,
85+
file_name: Optional[str] = None,
86+
function_name: Optional[str] = None,
87+
lsp_message_id: Optional[str] = None,
88+
) -> None:
8489
if is_LSP_enabled():
85-
lsp_log(LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name))
90+
lsp_log(
91+
LspCodeMessage(code=code_str, file_name=file_name, function_name=function_name, message_id=lsp_message_id)
92+
)
8693
return
8794
"""Print code with syntax highlighting."""
8895
from rich.syntax import Syntax

0 commit comments

Comments
 (0)