Skip to content

Commit 410ef8e

Browse files
aseembits93Codeflash Bot
authored andcommitted
inference test
1 parent 65cba54 commit 410ef8e

File tree

2 files changed

+628
-14
lines changed

2 files changed

+628
-14
lines changed

codeflash/code_utils/formatter.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,20 @@ def is_diff_line(line: str) -> bool:
9696
return len(diff_lines)
9797

9898

99-
def format_generated_code(generated_test_source: str) -> str:
100-
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
101-
# formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
102-
# if formatter_name == "disabled":
103-
# return re.sub(r"\n{2,}", "\n\n", generated_test_source)
104-
# # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed)
105-
# original_temp, test_dir_str, exit_on_failure = None, None, True
106-
# formatted_temp, formatted_code, changed = apply_formatter_cmds(
107-
# formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure
108-
# )
109-
# if not changed:
110-
# return re.sub(r"\n{2,}", "\n\n", formatted_code)
111-
# return formatted_code
99+
def format_generated_code(generated_test_source: str, formatter_cmds: Union[list[str], None]) -> str:
100+
formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled"
101+
if formatter_name == "disabled":
102+
return re.sub(r"\n{2,}", "\n\n", generated_test_source)
103+
with tempfile.TemporaryDirectory() as test_dir_str:
104+
# try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed)
105+
original_temp = Path(test_dir_str) / "original_temp.py"
106+
original_temp.write_text(generated_test_source, encoding="utf8")
107+
_, formatted_code, changed = apply_formatter_cmds(
108+
formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False
109+
)
110+
if not changed:
111+
return re.sub(r"\n{2,}", "\n\n", formatted_code)
112+
return formatted_code
112113

113114

114115
def format_code(

0 commit comments

Comments
 (0)