
Commit 5af82f0

Merge branch 'main' into comparator-ast-recursion-depth
2 parents: ff91a0a + 1fd7e6a

4 files changed: +94, -44 lines

codeflash/either.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from __future__ import annotations
+from __future__ import annotations
 
 from typing import Generic, TypeVar
 

codeflash/optimization/optimizer.py

Lines changed: 47 additions & 41 deletions
@@ -551,10 +551,6 @@ def determine_best_candidate(
             )
             speedup_ratios[candidate.optimization_id] = perf_gain
 
-            speedup_stats = compare_function_runtime_distributions(
-                original_code_runtime_distribution, candidate_runtime_distribution
-            )
-
             tree = Tree(f"Candidate #{candidate_index} - Sum of Minimum Runtimes")
             if speedup_critic(
                 candidate_result, original_code_baseline.runtime, best_runtime_until_now
@@ -588,28 +584,33 @@ def determine_best_candidate(
             console.print(tree)
             console.rule()
 
-            tree = Tree(f"Candidate #{candidate_index} - Bayesian Bootstrapping Nonparametric Analysis")
-            tree.add(
-                f"Expected candidate runtime (95% Credible Interval) = ["
-                f"{humanize_runtime(candidate_runtime_statistics['credible_interval_lower_bound'])}, "
-                f"{humanize_runtime(candidate_runtime_statistics['credible_interval_upper_bound'])}], "
-                f"\nmedian = {humanize_runtime(candidate_runtime_statistics['median'])}"
-                f"\nSpeedup ratio of candidate vs original:"
-                f"\n95% Credible Interval = [{speedup_stats['credible_interval_lower_bound']:.3f}X, "
-                f"{speedup_stats['credible_interval_upper_bound']:.3f}X]"
-                f"\nmedian = {speedup_stats['median']:.3f}X"
-            )
-            if speedup_stats["credible_interval_lower_bound"] > 1.0:
-                tree.add("The candidate is faster than the original code with a 95% probability.")
-                if speedup_stats["median"] > best_speedup_ratio_until_now:
-                    best_speedup_ratio_until_now = speedup_stats["median"]
-                    tree.add("This candidate is the best candidate so far.")
+            if candidate_runtime_distribution.any() and candidate_runtime_statistics:
+                speedup_stats = compare_function_runtime_distributions(
+                    original_code_runtime_distribution, candidate_runtime_distribution
+                )
+                tree = Tree(f"Candidate #{candidate_index} - Bayesian Bootstrapping Nonparametric Analysis")
+                tree.add(
+                    f"Expected candidate summed runtime (95% Credible Interval) = ["
+                    f"{humanize_runtime(round(candidate_runtime_statistics['credible_interval_lower_bound']))}"
+                    f", "
+                    f"{humanize_runtime(round(candidate_runtime_statistics['credible_interval_upper_bound']))}]"
+                    f"\nMedian = {humanize_runtime(round(candidate_runtime_statistics['median']))}"
+                    f"\nSpeedup ratio of candidate vs original:"
+                    f"\n95% Credible Interval = [{speedup_stats['credible_interval_lower_bound']:.3f}X, "
+                    f"{speedup_stats['credible_interval_upper_bound']:.3f}X]"
+                    f"\nmedian = {speedup_stats['median']:.3f}X"
+                )
+                if speedup_stats["credible_interval_lower_bound"] > 1.0:
+                    tree.add("The candidate is faster than the original code with a 95% probability.")
+                    if speedup_stats["median"] > best_speedup_ratio_until_now:
+                        best_speedup_ratio_until_now = float(speedup_stats["median"])
+                        tree.add("This candidate is the best candidate so far.")
+                    else:
+                        tree.add("This candidate is not faster than the current fastest candidate.")
                 else:
-                    tree.add("This candidate is not faster than the current fastest candidate.")
-            else:
-                tree.add("It is inconclusive whether the candidate is faster than the original code.")
-            console.print(tree)
-            console.rule()
+                    tree.add("It is inconclusive whether the candidate is faster than the original code.")
+                console.print(tree)
+                console.rule()
 
             self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
         except KeyboardInterrupt as e:
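
The new guard only reports the Bayesian comparison when the candidate produced a non-empty bootstrap distribution and statistics. The sketch below illustrates the reported quantity with plain NumPy: a 95% credible interval for the speedup ratio taken from two bootstrap distributions of summed runtimes. Here, compare_distributions is a hypothetical stand-in, not codeflash's compare_function_runtime_distributions, and pairing the draws element-wise is just one simple choice.

    import numpy as np


    def compare_distributions(original_draws, candidate_draws, gamma=0.95):
        # Hypothetical stand-in: speedup-ratio draws are original runtime draws
        # divided by candidate runtime draws, summarized by percentiles.
        ratios = original_draws / candidate_draws
        lower_p = (1.0 - gamma) / 2 * 100
        return {
            "credible_interval_lower_bound": float(np.percentile(ratios, lower_p)),
            "credible_interval_upper_bound": float(np.percentile(ratios, 100 - lower_p)),
            "median": float(np.median(ratios)),
        }


    rng = np.random.default_rng(0)
    original_draws = rng.normal(2_000_000.0, 50_000.0, 100_000)   # bootstrap draws of summed runtime (ns)
    candidate_draws = rng.normal(1_500_000.0, 40_000.0, 100_000)

    if candidate_draws.any():  # mirrors the new guard against an empty distribution
        speedup_stats = compare_distributions(original_draws, candidate_draws)
        if speedup_stats["credible_interval_lower_bound"] > 1.0:
            print(f"Faster with ~95% probability; median speedup {speedup_stats['median']:.3f}X")
        else:
            print("Inconclusive whether the candidate is faster than the original code.")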
@@ -1054,9 +1055,6 @@ def establish_original_code_baseline(
         console.rule()
 
         total_timing = benchmarking_results.total_passed_runtime()  # caution: doesn't handle the loop index
-        runtime_distribution, runtime_statistics = benchmarking_results.bayesian_nonparametric_bootstrap_analysis(
-            100_000
-        )
         functions_to_remove = [
             result.id.test_function_name
             for result in behavioral_results
@@ -1090,9 +1088,12 @@ def establish_original_code_baseline(
         console.rule()
         logger.debug(f"Total original code summed runtime (ns): {total_timing}")
         console.rule()
+        runtime_distribution, runtime_statistics = benchmarking_results.bayesian_nonparametric_bootstrap_analysis(
+            100_000
+        )
         logger.info(
             f"Bayesian Bootstrapping Nonparametric Analysis"
-            f"\nExpected original code runtime (95% Credible Interval) = ["
+            f"\nExpected original code summed runtime (95% Credible Interval) = ["
             f"{humanize_runtime(round(runtime_statistics['credible_interval_lower_bound']))}, "
             f"{humanize_runtime(round(runtime_statistics['credible_interval_upper_bound']))}], "
             f"\nmedian: {humanize_runtime(round(runtime_statistics['median']))}"
@@ -1196,18 +1197,23 @@ def run_optimized_candidate(
         if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0:
             logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
             console.rule()
-        runtime_distribution, runtime_statistics = (
-            candidate_benchmarking_results.bayesian_nonparametric_bootstrap_analysis(100_000)
-        )
-
-        logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
-        console.rule()
-        logger.debug(
-            f"Overall code runtime (95% Credible Interval) = ["
-            f"{humanize_runtime(round(runtime_statistics['credible_interval_lower_bound']))}, "
-            f"{humanize_runtime(round(runtime_statistics['credible_interval_upper_bound']))}], median: "
-            f"{humanize_runtime(round(runtime_statistics['median']))}"
-        )
+            runtime_distribution: npt.NDArray[np.float64] = np.array([])
+            runtime_statistics: dict[str, np.float64] = {}
+        else:
+            logger.debug(
+                f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}"
+            )
+            console.rule()
+            runtime_distribution, runtime_statistics = (
+                candidate_benchmarking_results.bayesian_nonparametric_bootstrap_analysis(100_000)
+            )
+            logger.debug(
+                f"Overall code summed runtime (95% Credible Interval) = ["
+                f"{humanize_runtime(round(runtime_statistics['credible_interval_lower_bound']))}, "
+                f"{humanize_runtime(round(runtime_statistics['credible_interval_upper_bound']))}], median: "
+                f"{humanize_runtime(round(runtime_statistics['median']))}"
+            )
+            console.rule()
         return Success(
             (
                 OptimizedCandidateResult(
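
In run_optimized_candidate, the zero-runtime case now falls back to an empty distribution and an empty statistics dict, which the caller in determine_best_candidate treats as "skip the Bayesian report". Below is a minimal sketch of that sentinel-and-guard flow; analyze_or_skip is an illustrative helper, not part of codeflash.

    import numpy as np
    import numpy.typing as npt


    def analyze_or_skip(runtimes_ns: list[int], bootstrap_size: int = 10_000) -> tuple[npt.NDArray[np.float64], dict[str, float]]:
        # Mirrors the diff: when no runtime was measured, return falsy sentinels
        # instead of running the bootstrap analysis.
        if sum(runtimes_ns) == 0:
            return np.array([]), {}
        rng = np.random.default_rng()
        samples = np.asarray(runtimes_ns, dtype=np.float64)
        draws = np.array(
            [rng.choice(samples, size=samples.size).sum() for _ in range(bootstrap_size)]
        )
        return draws, {"median": float(np.median(draws))}


    distribution, statistics = analyze_or_skip([])        # nothing measured
    if distribution.any() and statistics:                 # same guard the optimizer now uses
        print("report Bayesian statistics")
    else:
        print("skip the Bayesian report for this candidate")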

codeflash/verification/bayesian_analysis.py

Lines changed: 42 additions & 0 deletions
@@ -84,6 +84,36 @@ def bootstrap_combined_function_input_runtime_means(
     return draws
 
 
+@nb.njit(parallel=True, fastmath=True, cache=True)
+def bootstrap_combined_function_input_runtime_sums(
+    posterior_means: list[npt.NDArray[np.float64]], rngs: tuple[np.random.Generator, ...], bootstrap_size: int
+) -> npt.NDArray[np.float64]:
+    """Given a function, we have posterior draws for each input, and get an overall expected total time across these inputs.
+
+    We make random draws from each input's distribution using the rngs random generators (one per computation thread),
+    and sum them across inputs.
+    Returns an array of shape (bootstrap_size,).
+    """
+    num_inputs = len(posterior_means)
+    num_input_means = max([len(posterior_mean) for posterior_mean in posterior_means])
+    draws = np.empty(bootstrap_size, dtype=np.float64)
+
+    num_threads = len(rngs)
+    thread_remainder = bootstrap_size % num_threads
+    num_bootstraps_per_thread = np.array([bootstrap_size // num_threads] * num_threads) + np.array(
+        [1] * thread_remainder + [0] * (num_threads - thread_remainder)
+    )
+    thread_idx = [0, *list(np.cumsum(num_bootstraps_per_thread))]
+
+    for thread_id in nb.prange(num_threads):
+        thread_draws = draws[thread_idx[thread_id] : thread_idx[thread_id + 1]]
+        for bootstrap_id in range(num_bootstraps_per_thread[thread_id]):
+            thread_draws[bootstrap_id] = sum(
+                [input_means[rngs[thread_id].integers(0, num_input_means)] for input_means in posterior_means]
+            )
+    return draws
+
+
 def compute_statistics(distribution: npt.NDArray[np.float64], gamma: float = 0.95) -> dict[str, np.float64]:
     lower_p = (1.0 - gamma) / 2 * 100
     return {
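
The new bootstrap_combined_function_input_runtime_sums kernel draws one posterior-mean sample per input and sums them across inputs, bootstrap_size times, with the work split across numba threads. Below is a single-threaded, plain-NumPy reference sketch of the same computation (not the njit kernel above), assuming every input has the same number of posterior samples.

    import numpy as np
    import numpy.typing as npt


    def bootstrap_runtime_sums_reference(
        posterior_means: list[npt.NDArray[np.float64]], bootstrap_size: int
    ) -> npt.NDArray[np.float64]:
        # For each bootstrap draw, pick one posterior-mean sample per input and sum across inputs.
        rng = np.random.default_rng()
        num_samples = max(len(means) for means in posterior_means)
        draws = np.empty(bootstrap_size, dtype=np.float64)
        for i in range(bootstrap_size):
            draws[i] = sum(means[rng.integers(0, num_samples)] for means in posterior_means)
        return draws


    # Two inputs, each with three posterior-mean samples of its runtime (ns).
    posterior = [np.array([100.0, 110.0, 90.0]), np.array([200.0, 190.0, 210.0])]
    print(bootstrap_runtime_sums_reference(posterior, 5))  # five draws of the summed runtime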
@@ -105,6 +135,18 @@ def analyze_function_runtime_data(
     return function_runtime_distribution, compute_statistics(function_runtime_distribution)
 
 
+def analyze_function_runtime_sums_data(
+    function_runtime_data: list[list[int]], bootstrap_size: int
+) -> tuple[npt.NDArray[np.float64], dict[str, np.float64]]:
+    rng = np.random.default_rng()
+    function_runtime_distribution = bootstrap_combined_function_input_runtime_sums(
+        compute_function_runtime_posterior_means(function_runtime_data, bootstrap_size),
+        tuple(rng.spawn(nb.get_num_threads())),
+        bootstrap_size,
+    )
+    return function_runtime_distribution, compute_statistics(function_runtime_distribution)
+
+
 def compare_function_runtime_distributions(
     function1_runtime_distribution: npt.NDArray[np.float64], function2_runtime_distribution: npt.NDArray[np.float64]
 ) -> dict[str, np.float64]:
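
analyze_function_runtime_sums_data hands the parallel kernel one child generator per numba thread via rng.spawn(nb.get_num_threads()), so each thread can draw from an independent stream. A small sketch of that spawning pattern follows; Generator.spawn needs a recent NumPy (1.25+), and the thread count is hard-coded here instead of coming from numba.

    import numpy as np

    rng = np.random.default_rng(42)
    num_threads = 4                                  # stand-in for nb.get_num_threads()
    child_rngs = tuple(rng.spawn(num_threads))       # independent per-thread generators

    # Each child has its own state, so parallel workers can sample
    # bootstrap indices without coordinating with one another.
    print([int(child.integers(0, 100)) for child in child_rngs])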

codeflash/verification/test_results.py

Lines changed: 4 additions & 2 deletions
@@ -18,7 +18,7 @@
 from rich.tree import Tree
 
 from codeflash.cli_cmds.console import DEBUG_MODE, logger
-from codeflash.verification.bayesian_analysis import analyze_function_runtime_data
+from codeflash.verification.bayesian_analysis import analyze_function_runtime_sums_data
 from codeflash.verification.comparator import comparator
 
 
@@ -190,7 +190,9 @@ def total_passed_runtime(self) -> int:
     def bayesian_nonparametric_bootstrap_analysis(
         self, bootstrap_size: int
     ) -> tuple[npt.NDArray[np.float64], dict[str, np.float64]]:
-        return analyze_function_runtime_data(list(self.usable_runtime_data_by_test_case().values()), bootstrap_size)
+        return analyze_function_runtime_sums_data(
+            list(self.usable_runtime_data_by_test_case().values()), bootstrap_size
+        )
 
     def __iter__(self) -> Iterator[FunctionTestInvocation]:
         return iter(self.test_results)
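
The TestResults method now forwards list(self.usable_runtime_data_by_test_case().values()) to analyze_function_runtime_sums_data, whose signature takes list[list[int]]: one inner list of per-invocation runtimes per test case (the log messages elsewhere in this diff report runtimes in nanoseconds). A hypothetical, shape-only illustration of that input follows; the numbers are made up.

    # Shaped like list(self.usable_runtime_data_by_test_case().values()).
    function_runtime_data: list[list[int]] = [
        [120_000, 118_500, 121_200],   # test case A: three timed invocations (ns)
        [450_000, 447_800, 452_300],   # test case B: three timed invocations (ns)
    ]

    # With codeflash available, this is the data the bootstrap analysis would receive:
    # from codeflash.verification.bayesian_analysis import analyze_function_runtime_sums_data
    # distribution, statistics = analyze_function_runtime_sums_data(function_runtime_data, 100_000)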
