Skip to content

Commit 3d33f8b

Browse files
committed
Merge branch 'refs/heads/main' into init_caching
2 parents ee8c94e + 1fd7e6a commit 3d33f8b

File tree

4 files changed

+94
-44
lines changed

4 files changed

+94
-44
lines changed

codeflash/either.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from __future__ import annotations
1+
from __future__ import annotations
22

33
from typing import Generic, TypeVar
44

codeflash/optimization/optimizer.py

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -572,10 +572,6 @@ def determine_best_candidate(
572572
)
573573
speedup_ratios[candidate.optimization_id] = perf_gain
574574

575-
speedup_stats = compare_function_runtime_distributions(
576-
original_code_runtime_distribution, candidate_runtime_distribution
577-
)
578-
579575
tree = Tree(f"Candidate #{candidate_index} - Sum of Minimum Runtimes")
580576
if speedup_critic(
581577
candidate_result, original_code_baseline.runtime, best_runtime_until_now
@@ -609,28 +605,33 @@ def determine_best_candidate(
609605
console.print(tree)
610606
console.rule()
611607

612-
tree = Tree(f"Candidate #{candidate_index} - Bayesian Bootstrapping Nonparametric Analysis")
613-
tree.add(
614-
f"Expected candidate runtime (95% Credible Interval) = ["
615-
f"{humanize_runtime(candidate_runtime_statistics['credible_interval_lower_bound'])}, "
616-
f"{humanize_runtime(candidate_runtime_statistics['credible_interval_upper_bound'])}], "
617-
f"\nmedian = {humanize_runtime(candidate_runtime_statistics['median'])}"
618-
f"\nSpeedup ratio of candidate vs original:"
619-
f"\n95% Credible Interval = [{speedup_stats['credible_interval_lower_bound']:.3f}X, "
620-
f"{speedup_stats['credible_interval_upper_bound']:.3f}X]"
621-
f"\nmedian = {speedup_stats['median']:.3f}X"
622-
)
623-
if speedup_stats["credible_interval_lower_bound"] > 1.0:
624-
tree.add("The candidate is faster than the original code with a 95% probability.")
625-
if speedup_stats["median"] > best_speedup_ratio_until_now:
626-
best_speedup_ratio_until_now = speedup_stats["median"]
627-
tree.add("This candidate is the best candidate so far.")
608+
if candidate_runtime_distribution.any() and candidate_runtime_statistics:
609+
speedup_stats = compare_function_runtime_distributions(
610+
original_code_runtime_distribution, candidate_runtime_distribution
611+
)
612+
tree = Tree(f"Candidate #{candidate_index} - Bayesian Bootstrapping Nonparametric Analysis")
613+
tree.add(
614+
f"Expected candidate summed runtime (95% Credible Interval) = ["
615+
f"{humanize_runtime(round(candidate_runtime_statistics['credible_interval_lower_bound']))}"
616+
f", "
617+
f"{humanize_runtime(round(candidate_runtime_statistics['credible_interval_upper_bound']))}]"
618+
f"\nMedian = {humanize_runtime(round(candidate_runtime_statistics['median']))}"
619+
f"\nSpeedup ratio of candidate vs original:"
620+
f"\n95% Credible Interval = [{speedup_stats['credible_interval_lower_bound']:.3f}X, "
621+
f"{speedup_stats['credible_interval_upper_bound']:.3f}X]"
622+
f"\nmedian = {speedup_stats['median']:.3f}X"
623+
)
624+
if speedup_stats["credible_interval_lower_bound"] > 1.0:
625+
tree.add("The candidate is faster than the original code with a 95% probability.")
626+
if speedup_stats["median"] > best_speedup_ratio_until_now:
627+
best_speedup_ratio_until_now = float(speedup_stats["median"])
628+
tree.add("This candidate is the best candidate so far.")
629+
else:
630+
tree.add("This candidate is not faster than the current fastest candidate.")
628631
else:
629-
tree.add("This candidate is not faster than the current fastest candidate.")
630-
else:
631-
tree.add("It is inconclusive whether the candidate is faster than the original code.")
632-
console.print(tree)
633-
console.rule()
632+
tree.add("It is inconclusive whether the candidate is faster than the original code.")
633+
console.print(tree)
634+
console.rule()
634635

635636
self.write_code_and_helpers(original_code, original_helper_code, function_to_optimize.file_path)
636637
except KeyboardInterrupt as e:
@@ -1087,9 +1088,6 @@ def establish_original_code_baseline(
10871088
console.rule()
10881089

10891090
total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index
1090-
runtime_distribution, runtime_statistics = benchmarking_results.bayesian_nonparametric_bootstrap_analysis(
1091-
100_000
1092-
)
10931091
functions_to_remove = [
10941092
result.id.test_function_name
10951093
for result in behavioral_results
@@ -1123,9 +1121,12 @@ def establish_original_code_baseline(
11231121
console.rule()
11241122
logger.debug(f"Total original code summed runtime (ns): {total_timing}")
11251123
console.rule()
1124+
runtime_distribution, runtime_statistics = benchmarking_results.bayesian_nonparametric_bootstrap_analysis(
1125+
100_000
1126+
)
11261127
logger.info(
11271128
f"Bayesian Bootstrapping Nonparametric Analysis"
1128-
f"\nExpected original code runtime (95% Credible Interval) = ["
1129+
f"\nExpected original code summed runtime (95% Credible Interval) = ["
11291130
f"{humanize_runtime(round(runtime_statistics['credible_interval_lower_bound']))}, "
11301131
f"{humanize_runtime(round(runtime_statistics['credible_interval_upper_bound']))}], "
11311132
f"\nmedian: {humanize_runtime(round(runtime_statistics['median']))}"
@@ -1245,18 +1246,23 @@ def run_optimized_candidate(
12451246
if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0:
12461247
logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
12471248
console.rule()
1248-
runtime_distribution, runtime_statistics = (
1249-
candidate_benchmarking_results.bayesian_nonparametric_bootstrap_analysis(100_000)
1250-
)
1251-
1252-
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
1253-
console.rule()
1254-
logger.debug(
1255-
f"Overall code runtime (95% Credible Interval) = ["
1256-
f"{humanize_runtime(round(runtime_statistics['credible_interval_lower_bound']))}, "
1257-
f"{humanize_runtime(round(runtime_statistics['credible_interval_upper_bound']))}], median: "
1258-
f"{humanize_runtime(round(runtime_statistics['median']))}"
1259-
)
1249+
runtime_distribution: npt.NDArray[np.float64] = np.array([])
1250+
runtime_statistics: dict[str, np.float64] = {}
1251+
else:
1252+
logger.debug(
1253+
f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}"
1254+
)
1255+
console.rule()
1256+
runtime_distribution, runtime_statistics = (
1257+
candidate_benchmarking_results.bayesian_nonparametric_bootstrap_analysis(100_000)
1258+
)
1259+
logger.debug(
1260+
f"Overall code summed runtime (95% Credible Interval) = ["
1261+
f"{humanize_runtime(round(runtime_statistics['credible_interval_lower_bound']))}, "
1262+
f"{humanize_runtime(round(runtime_statistics['credible_interval_upper_bound']))}], median: "
1263+
f"{humanize_runtime(round(runtime_statistics['median']))}"
1264+
)
1265+
console.rule()
12601266
return Success(
12611267
(
12621268
OptimizedCandidateResult(

codeflash/verification/bayesian_analysis.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,36 @@ def bootstrap_combined_function_input_runtime_means(
8484
return draws
8585

8686

87+
@nb.njit(parallel=True, fastmath=True, cache=True)
def bootstrap_combined_function_input_runtime_sums(
    posterior_means: list[npt.NDArray[np.float64]], rngs: tuple[np.random.Generator, ...], bootstrap_size: int
) -> npt.NDArray[np.float64]:
    """Bootstrap the distribution of a function's total (summed) runtime across its inputs.

    For each bootstrap draw, one posterior-mean sample is picked at random from every
    input's distribution and the samples are summed — unlike the ``*_means`` sibling,
    which averages them. (The previous docstring, copied from that sibling, wrongly
    said "arithmetic mean".)

    Args:
        posterior_means: Per-input arrays of posterior mean draws. Indexing uses the
            maximum array length, so all arrays are assumed equally long — a shorter
            one would be indexed out of range. TODO confirm with callers.
        rngs: One independent random generator per computation thread.
        bootstrap_size: Number of bootstrap draws to produce.

    Returns:
        Array of shape (bootstrap_size,) with the summed-runtime draws.
    """
    num_input_means = max([len(posterior_mean) for posterior_mean in posterior_means])
    draws = np.empty(bootstrap_size, dtype=np.float64)

    # Split the draws as evenly as possible across threads; the first
    # `bootstrap_size % num_threads` threads each take one extra draw.
    num_threads = len(rngs)
    thread_remainder = bootstrap_size % num_threads
    num_bootstraps_per_thread = np.array([bootstrap_size // num_threads] * num_threads) + np.array(
        [1] * thread_remainder + [0] * (num_threads - thread_remainder)
    )
    # Prefix sums give each thread its contiguous slice boundaries within `draws`.
    thread_idx = [0, *list(np.cumsum(num_bootstraps_per_thread))]

    for thread_id in nb.prange(num_threads):
        thread_draws = draws[thread_idx[thread_id] : thread_idx[thread_id + 1]]
        for bootstrap_id in range(num_bootstraps_per_thread[thread_id]):
            # One random sample per input, summed into a single total-runtime draw.
            thread_draws[bootstrap_id] = sum(
                [input_means[rngs[thread_id].integers(0, num_input_means)] for input_means in posterior_means]
            )
    return draws
115+
116+
87117
def compute_statistics(distribution: npt.NDArray[np.float64], gamma: float = 0.95) -> dict[str, np.float64]:
88118
lower_p = (1.0 - gamma) / 2 * 100
89119
return {
@@ -105,6 +135,18 @@ def analyze_function_runtime_data(
105135
return function_runtime_distribution, compute_statistics(function_runtime_distribution)
106136

107137

138+
def analyze_function_runtime_sums_data(
    function_runtime_data: list[list[int]], bootstrap_size: int
) -> tuple[npt.NDArray[np.float64], dict[str, np.float64]]:
    """Bootstrap the summed-runtime distribution for a function and summarize it.

    Returns the bootstrap distribution of summed runtimes together with its
    summary statistics (as produced by ``compute_statistics``).
    """
    posterior_means = compute_function_runtime_posterior_means(function_runtime_data, bootstrap_size)
    # Spawn one independent child generator per numba thread.
    thread_rngs = tuple(np.random.default_rng().spawn(nb.get_num_threads()))
    distribution = bootstrap_combined_function_input_runtime_sums(posterior_means, thread_rngs, bootstrap_size)
    statistics = compute_statistics(distribution)
    return distribution, statistics
148+
149+
108150
def compare_function_runtime_distributions(
109151
function1_runtime_distribution: npt.NDArray[np.float64], function2_runtime_distribution: npt.NDArray[np.float64]
110152
) -> dict[str, np.float64]:

codeflash/verification/test_results.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from rich.tree import Tree
1919

2020
from codeflash.cli_cmds.console import DEBUG_MODE, logger
21-
from codeflash.verification.bayesian_analysis import analyze_function_runtime_data
21+
from codeflash.verification.bayesian_analysis import analyze_function_runtime_sums_data
2222
from codeflash.verification.comparator import comparator
2323

2424

@@ -207,7 +207,9 @@ def total_passed_runtime(self) -> int:
207207
def bayesian_nonparametric_bootstrap_analysis(
    self, bootstrap_size: int
) -> tuple[npt.NDArray[np.float64], dict[str, np.float64]]:
    """Run the Bayesian nonparametric bootstrap over this result set's usable runtimes.

    Delegates to ``analyze_function_runtime_sums_data`` and returns the
    (distribution, statistics) pair it produces.
    """
    runtime_data = list(self.usable_runtime_data_by_test_case().values())
    return analyze_function_runtime_sums_data(runtime_data, bootstrap_size)
211213

212214
def __iter__(self) -> Iterator[FunctionTestInvocation]:
213215
return iter(self.test_results)

0 commit comments

Comments
 (0)