feat: Destination-passing Style #125
Conversation
Walkthrough
This PR refactors the apply, benchmark, and compile subsystems to support dual calling conventions (value-returning and destination-passing styles). It reorganizes input/output data structures from dictionaries to lists, migrates ApplyKey from a dataclass to a Pydantic-based frozen model, updates the tracing runtime to accept explicit args/kwargs, and renames parameters across runners and utilities for consistency.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Summary of Changes
Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces comprehensive support for Destination-Passing Style (DPS) across the FlashInfer-Bench framework.
Code Review
This pull request introduces significant refactoring to support Destination-Passing Style (DPS) for solutions, which is a great enhancement for performance-oriented code. The changes span across the apply, tracing, timing, and benchmark systems, and the overall design for handling different calling conventions is clean and well-abstracted, especially within the Runnable class. I've identified a few issues related to argument handling and documentation consistency that should be addressed to improve the robustness and clarity of the new APIs. My main concerns are around ensuring duplicate arguments are handled correctly and that fallback functions are called in a consistent and flexible manner.
```python
if kwargs:
    args = definition.merge_kwargs_to_args(args, kwargs)
```
The args tuple is being reassigned here, which causes the original positional arguments to be lost. This becomes an issue later when calling the fallback function, as it can no longer be called with the original *args and **kwargs, breaking consistency and flexibility. It's better to store the result of the merge in a new variable, e.g., merged_args, and use that for subsequent logic related to the specialized kernel call, while preserving the original args and kwargs for the fallback call. You will need to update subsequent uses of args to merged_args.
```diff
-if kwargs:
-    args = definition.merge_kwargs_to_args(args, kwargs)
+merged_args = args
+if kwargs:
+    merged_args = definition.merge_kwargs_to_args(args, kwargs)
```
```python
if not kwargs:
    return args

param_names = list(self.inputs.keys()) + list(self.outputs.keys())
result = list(args)
for i in range(len(args), len(param_names)):
    name = param_names[i]
    if name in kwargs:
        result.append(kwargs[name])
    else:
        break
return tuple(result)
```
The merge_kwargs_to_args method does not check for duplicate arguments that are provided both positionally and as a keyword. This can lead to surprising behavior where a keyword argument is silently ignored if it corresponds to a parameter already provided positionally. A robust implementation should raise a TypeError in this case, similar to how Python's function calls behave.
```python
if not kwargs:
    return args
param_names = list(self.inputs.keys()) + list(self.outputs.keys())
if len(args) > len(param_names):
    raise TypeError(
        f"Too many positional arguments: got {len(args)}, expected at most {len(param_names)}"
    )
# Check for duplicate arguments
positional_arg_names = set(param_names[:len(args)])
for name in kwargs:
    if name in positional_arg_names:
        raise TypeError(f"Got multiple values for argument '{name}'")
result = list(args)
for i in range(len(args), len(param_names)):
    name = param_names[i]
    if name in kwargs:
        result.append(kwargs[name])
    else:
        break
return tuple(result)
```

The docstring under discussion:

```
The kernel name, or a resolver ``fn(*args) -> str`` that maps runtime
arguments to a kernel name (definition name).
```
The docstring for def_name_or_resolver states its signature is fn(*args) -> str, but the implementation in _dispatch_apply_or_tracing calls it with *args, **kwargs. The docstring should be updated to match the more flexible implementation.
```diff
-The kernel name, or a resolver ``fn(*args) -> str`` that maps runtime
-arguments to a kernel name (definition name).
+The kernel name, or a resolver ``fn(*args, **kwargs) -> str`` that maps runtime
+arguments to a kernel name (definition name).
```
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
flashinfer_bench/apply/apply_api.py (1)
208-208: Update docstring examples to use the new API.
The docstring examples in enable_apply still reference the old runtime_kwargs parameter, which no longer exists in the apply() function signature. Update these to use the new args/kwargs parameters.

```diff
 >>> # Direct usage
 >>> enable_apply("/path/to/traceset", cfg)
 >>> # Apply is now enabled
->>> out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
+>>> out = apply("rmsnorm_d4096", args=(...,), fallback=ref_fn)
 >>> disable_apply()
 >>> # Apply is now disabled.
 >>> # Context manager usage
 >>> with enable_apply("/path/to/traceset", cfg):
-...     out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
+...     out = apply("rmsnorm_d4096", args=(...,), fallback=ref_fn)
 >>> # Apply is now disabled.
```

Also applies to: 214-214
flashinfer_bench/tracing/runtime.py (1)
196-224: Stale docstring: parameters no longer match the signature.
The docstring still references the definition, axes, and name parameters that were removed. Update it to reflect the new val and dtype parameters.

```diff
 def _convert_arg_to_tensor(
     self, val: Union[int, float, bool, list, tuple, torch.Tensor], dtype: str
 ) -> Optional[torch.Tensor]:
     """Convert a runtime argument to a tensor for further dumping.

     If the conversion fails, log an error and return None.

     Parameters
     ----------
-    definition : Definition
-        The workload definition containing axis specifications.
-    axes : Dict[str, int]
-        Runtime axis values provided during tracing.
-    name : str
-        Name of the argument to convert.
-    val : Any
-        The runtime argument to convert.
+    val : Union[int, float, bool, list, tuple, torch.Tensor]
+        The runtime argument to convert (scalar, list, or tensor).
+    dtype : str
+        Target dtype string for scalar/list conversion.

     Returns
     -------
-    Optional[torch.Tensor]
-        The converted tensor. None if conversion fails.
+    torch.Tensor
+        The converted tensor.
+
+    Raises
+    ------
+    ValueError
+        If the value type is unsupported.
     """
```
🧹 Nitpick comments (15)
flashinfer_bench/bench/timing.py (2)
163-163: Consider validating rep > 0.
If rep=0 is passed, times will be empty, causing failures in _summarize_statistics (e.g., min() on an empty sequence). Adding a validation ensures clearer error messages.

```diff
 assert return_mode in ["min", "max", "mean", "median", "all"]
+assert rep > 0, "rep must be at least 1"
```
202-202: Add strict=True to zip for safety.
Both start_events and end_events are guaranteed to have the same length here, but adding strict=True is a good defensive practice and satisfies the static analysis hint.

```diff
-times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
+times = [s.elapsed_time(e) for s, e in zip(start_events, end_events, strict=True)]
```

flashinfer_bench/bench/benchmark.py (1)
119-155: LGTM - Consistent workload naming throughout benchmark execution.
The variable rename from wl to workload improves code readability and aligns with the project-wide standardization. All references to workload.uuid are consistent, and the functional logic remains unchanged.

Optional: Consider using logging.exception for better error context. On line 136, you could use logger.exception instead of logger.error to automatically include the exception traceback:

```diff
-logger.error(f"Failed to run workload {workload.uuid}: {e}")
+logger.exception(f"Failed to run workload {workload.uuid}")
```

This would provide more debugging context when workload execution fails.
flashinfer_bench/bench/evaluators/utils.py (1)
38-41: Consider adding strict=True to zip for defensive programming.
The zip() call pairs output_shapes with dtypes. Adding strict=True ensures both lists have the same length and catches potential inconsistencies early. Apply this diff:

```diff
 dtypes = definition.torch_output_dtypes
 return [
     torch.empty(shape, dtype=dtype, device=device)
-    for shape, dtype in zip(output_shapes, dtypes)
+    for shape, dtype in zip(output_shapes, dtypes, strict=True)
 ]
```

flashinfer_bench/bench/evaluators/lowbit.py (1)
66-66: Add strict=True to zip to catch length mismatches.
The zip(out, ref_out) without strict=True could silently ignore cases where the number of output tensors differs between the solution and reference. This could mask bugs where a solution produces a different number of outputs.

```diff
-for sol_tensor, ref_tensor in zip(out, ref_out):
+for sol_tensor, ref_tensor in zip(out, ref_out, strict=True):
```

flashinfer_bench/tracing/config.py (2)
141-141: Add strict=True to zip() to catch length mismatches.
If names and values have different lengths, zip() will silently truncate. Since this indicates a programming error, it should fail explicitly.

```diff
-name_to_value = dict(zip(names, values))
+name_to_value = dict(zip(names, values, strict=True))
```
148-151: Consider using TypeError for type validation errors.
When the input_dump_policy is neither a list nor a callable, or when the callable returns a non-list, these are type errors rather than value errors. Static analysis suggests TypeError is more appropriate here.

```diff
 else:
-    raise ValueError("input_dump_policy must be a list of strings or a callable")
+    raise TypeError("input_dump_policy must be a list of strings or a callable")
 if not isinstance(names_to_dump, list):
-    raise ValueError("input_dump_policy callable must return a list of strings")
+    raise TypeError("input_dump_policy callable must return a list of strings")
```

flashinfer_bench/tracing/runtime.py (1)
159-161: Use logger.exception to preserve the stack trace.
When catching exceptions, logger.exception automatically includes the stack trace, which aids debugging. The same applies to line 173.

```diff
 except ValueError as e:
-    logger.error(f"Error getting axis values for {def_name}: {e}")
+    logger.exception(f"Error getting axis values for {def_name}")
     return
```

Similarly for line 173:

```diff
 except ValueError as e:
-    logger.error(f"Error converting argument '{name}' to tensor for {def_name}: {e}")
+    logger.exception(f"Error converting argument '{name}' to tensor for {def_name}")
     return
```

flashinfer_bench/compile/runnable.py (2)
224-226: Add strict=True to zip() for shape/dtype pairing.
If output_shapes and dtype_list have different lengths (indicating a definition inconsistency), this should fail explicitly rather than silently truncating.

```diff
-for shape, dtype in zip(output_shapes, dtype_list):
+for shape, dtype in zip(output_shapes, dtype_list, strict=True):
     shape = shape if shape is not None else ()
     output_tensors.append(torch.empty(shape, dtype=dtype, device=device))
```
127-129: Minor: unnecessary string split in the error message.
The string concatenation "...must " "be set." can be a single string.

```diff
 raise ValueError(
-    "When calling in keyword passing style, metadata.full_definition must " "be set."
+    "When calling in keyword passing style, metadata.definition must be set."
 )
```

Note: The error message also references full_definition, but the attribute is definition.

flashinfer_bench/bench/evaluators/default.py (1)
122-122: Add strict=True to zip() for output comparison.
If out and ref_out have different lengths, this indicates a mismatch in the number of outputs between the solution and reference, which should be caught explicitly rather than silently skipping outputs.

```diff
-for sol_tensor, ref_tensor in zip(out, ref_out):
+for sol_tensor, ref_tensor in zip(out, ref_out, strict=True):
```

flashinfer_bench/bench/evaluators/sampling.py (1)
328-353: Consider adding strict=True to zip() for safety.
The zip(input_names, inputs) at line 330 assumes both iterables have the same length. While this should always be true by construction (both derived from the same definition), adding strict=True would catch any inconsistency early.

```diff
-for name, value in zip(input_names, inputs):
+for name, value in zip(input_names, inputs, strict=True):
```

flashinfer_bench/data/definition.py (3)
279-296: Add strict=True to the outer zip for safety.
At line 279, the zip between self.inputs.items() and input_shapes should have matching lengths. Adding strict=True would catch mismatches early rather than silently truncating. The inner zip at line 287 is protected by the dimension check at lines 282-286, so it's less critical.

```diff
-for (inp_name, inp_spec), inp_shape in zip(self.inputs.items(), input_shapes):
+for (inp_name, inp_spec), inp_shape in zip(self.inputs.items(), input_shapes, strict=True):
```
305-321: Minor: docstring references an old method name.
Line 308 mentions "get_var_axes_values" but the actual method is get_axes_values. This appears to be a documentation artifact from renaming.

```diff
-Convenience method that combines extract_shapes and get_var_axes_values.
+Convenience method that combines shape extraction and get_axes_values.
```
180-186: Redundant None check on constraints.
Since constraints is defined with default_factory=list (line 132), it will never be None; it will be an empty list at minimum. The if self.constraints is not None: check is redundant.

```diff
-if self.constraints is not None:
-    for constraint in self.constraints:
-        try:
-            ast.parse(constraint, mode="eval")
-        except SyntaxError as e:
-            raise ValueError(f"Constraints must be valid Python expressions: {e}") from e
+for constraint in self.constraints:
+    try:
+        ast.parse(constraint, mode="eval")
+    except SyntaxError as e:
+        raise ValueError(f"Constraints must be valid Python expressions: {e}") from e
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (43)
- examples/win_at_p.py (1 hunks)
- flashinfer_bench/apply/apply_api.py (4 hunks)
- flashinfer_bench/apply/key.py (2 hunks)
- flashinfer_bench/apply/runtime.py (4 hunks)
- flashinfer_bench/apply/table.py (3 hunks)
- flashinfer_bench/bench/benchmark.py (1 hunks)
- flashinfer_bench/bench/evaluators/default.py (5 hunks)
- flashinfer_bench/bench/evaluators/evaluator.py (5 hunks)
- flashinfer_bench/bench/evaluators/lowbit.py (3 hunks)
- flashinfer_bench/bench/evaluators/sampling.py (6 hunks)
- flashinfer_bench/bench/evaluators/utils.py (1 hunks)
- flashinfer_bench/bench/runner/isolated_runner.py (4 hunks)
- flashinfer_bench/bench/runner/persistent_runner.py (5 hunks)
- flashinfer_bench/bench/runner/runner.py (2 hunks)
- flashinfer_bench/bench/timing.py (3 hunks)
- flashinfer_bench/bench/utils.py (2 hunks)
- flashinfer_bench/compile/builder.py (3 hunks)
- flashinfer_bench/compile/builders/python_builder.py (1 hunks)
- flashinfer_bench/compile/builders/torch_builder.py (1 hunks)
- flashinfer_bench/compile/builders/tvm_ffi_builder.py (2 hunks)
- flashinfer_bench/compile/runnable.py (4 hunks)
- flashinfer_bench/data/__init__.py (2 hunks)
- flashinfer_bench/data/definition.py (10 hunks)
- flashinfer_bench/data/solution.py (1 hunks)
- flashinfer_bench/data/trace.py (1 hunks)
- flashinfer_bench/data/workload.py (1 hunks)
- flashinfer_bench/tracing/__init__.py (1 hunks)
- flashinfer_bench/tracing/builtin/policies.py (1 hunks)
- flashinfer_bench/tracing/config.py (2 hunks)
- flashinfer_bench/tracing/filter_policy.py (0 hunks)
- flashinfer_bench/tracing/runtime.py (6 hunks)
- flashinfer_bench/utils.py (4 hunks)
- tests/apply/test_runtime.py (6 hunks)
- tests/bench/test_evaluator.py (7 hunks)
- tests/bench/test_isolated_runner.py (7 hunks)
- tests/bench/test_persistent_runner.py (14 hunks)
- tests/compile/test_builder.py (2 hunks)
- tests/compile/test_runnable.py (2 hunks)
- tests/data/test_definition.py (1 hunks)
- tests/data/test_load_dump_schema.py (3 hunks)
- tests/data/test_trace.py (2 hunks)
- tests/tracing/test_tracing_config.py (4 hunks)
- tests/tracing/test_tracing_runtime.py (2 hunks)
💤 Files with no reviewable changes (1)
- flashinfer_bench/tracing/filter_policy.py
🧰 Additional context used
🧬 Code graph analysis (31)
tests/data/test_trace.py (2)
- flashinfer_bench/data/workload.py (1): Workload (52-66)
- flashinfer_bench/data/trace.py (1): Trace (153-193)

flashinfer_bench/bench/evaluators/utils.py (2)
- flashinfer_bench/data/definition.py (3): get_axes_values_from_inputs (305-321), get_output_shapes (386-406), torch_output_dtypes (420-428)
- flashinfer_bench/bench/utils.py (1): to_tensor (55-60)

flashinfer_bench/bench/runner/isolated_runner.py (3)
- flashinfer_bench/data/workload.py (1): Workload (52-66)
- web/apps/web/lib/schemas/trace.ts (1): Workload (142-142)
- flashinfer_bench/bench/runner/persistent_runner.py (1): run_ref (246-262)

flashinfer_bench/data/trace.py (3)
- flashinfer_bench/data/utils.py (1): BaseModelWithDocstrings (12-15)
- flashinfer_bench/data/workload.py (1): Workload (52-66)
- web/apps/web/lib/schemas/trace.ts (1): Workload (142-142)

tests/data/test_load_dump_schema.py (3)
- flashinfer_bench/data/workload.py (1): Workload (52-66)
- web/apps/web/lib/schemas/trace.ts (2): Workload (142-142), Trace (144-144)
- flashinfer_bench/data/trace.py (1): Trace (153-193)

flashinfer_bench/bench/benchmark.py (3)
- flashinfer_bench/bench/runner/runner.py (1): run_workload (39-46)
- flashinfer_bench/data/trace.py (2): Trace (153-193), EvaluationStatus (67-87)
- web/apps/web/lib/schemas/trace.ts (1): Trace (144-144)

flashinfer_bench/apply/apply_api.py (3)
- tests/apply/test_runtime.py (1): fallback (507-509)
- flashinfer_bench/apply/runtime.py (2): get_apply_runtime (39-50), dispatch (108-209)
- flashinfer_bench/tracing/runtime.py (2): get_tracing_runtime (405-413), collect (92-194)

flashinfer_bench/bench/utils.py (3)
- flashinfer_bench/data/definition.py (1): get_input_shapes (364-384)
- flashinfer_bench/data/workload.py (1): Workload (52-66)
- flashinfer_bench/utils.py (1): dtype_str_to_torch_dtype (49-55)

flashinfer_bench/tracing/config.py (2)
- flashinfer_bench/tracing/workload_entry.py (1): WorkloadEntry (6-23)
- flashinfer_bench/tracing/builtin/policies.py (15): submit (65-67, 89-92, 124-130, 192-202, 222-224), drain (69-73, 94-98, 132-136, 204-208, 226-228), reset (75-77, 100-102, 138-141, 210-213, 230-232)

flashinfer_bench/apply/table.py (2)
- flashinfer_bench/env.py (1): get_fib_cache_path (46-57)
- flashinfer_bench/apply/key.py (1): ApplyKey (11-22)

flashinfer_bench/compile/builders/torch_builder.py (3)
- flashinfer_bench/compile/builders/python_builder.py (2): cleaner (63-77), _get_cleaner (44-79)
- flashinfer_bench/compile/builder.py (1): _try_validate_signature (129-194)
- flashinfer_bench/compile/runnable.py (1): Runnable (40-279)

flashinfer_bench/apply/key.py (2)
- flashinfer_bench/data/definition.py (1): get_axes_values_from_inputs (305-321)
- flashinfer_bench/data/workload.py (1): Workload (52-66)

flashinfer_bench/bench/evaluators/lowbit.py (3)
- flashinfer_bench/compile/runnable.py (1): Runnable (40-279)
- flashinfer_bench/data/definition.py (1): Definition (101-458)
- flashinfer_bench/bench/evaluators/utils.py (2): allocate_outputs (14-41), normalize_result (44-90)

flashinfer_bench/apply/runtime.py (3)
- tests/apply/test_runtime.py (1): fallback (507-509)
- flashinfer_bench/data/definition.py (1): merge_kwargs_to_args (430-458)
- flashinfer_bench/apply/key.py (4): ApplyKeyFactory (74-89), specialize (87-89), build_from_args (30-32, 49-51)

tests/data/test_definition.py (1)
- flashinfer_bench/data/definition.py (6): AxisVar (35-46), TensorSpec (81-94), merge_kwargs_to_args (430-458), AxisConst (19-32), get_axes_values_from_inputs (305-321), get_axes_values (259-303)

flashinfer_bench/data/workload.py (1)
- flashinfer_bench/data/utils.py (1): BaseModelWithDocstrings (12-15)

tests/bench/test_evaluator.py (5)
- tests/bench/test_persistent_runner.py (1): _simple_def (42-50)
- flashinfer_bench/bench/evaluators/default.py (1): DefaultEvaluator (32-218)
- flashinfer_bench/bench/evaluators/evaluator.py (1): evaluate (66-108)
- flashinfer_bench/bench/evaluators/sampling.py (1): SamplingEvaluator (42-232)
- flashinfer_bench/bench/evaluators/lowbit.py (1): LowBitEvaluator (19-119)

tests/compile/test_builder.py (1)
- flashinfer_bench/compile/runnable.py (1): RunnableMetadata (17-37)

flashinfer_bench/bench/runner/runner.py (1)
- flashinfer_bench/data/workload.py (1): Workload (52-66)

flashinfer_bench/compile/runnable.py (2)
- flashinfer_bench/data/definition.py (3): Definition (101-458), get_axes_values_from_inputs (305-321), get_output_shapes (386-406)
- flashinfer_bench/utils.py (1): dtype_str_to_torch_dtype (49-55)

tests/compile/test_runnable.py (2)
- flashinfer_bench/compile/runnable.py (6): Runnable (40-279), RunnableMetadata (17-37), call_destination_passing (156-202), call_value_returning (230-266), call_kwargs (136-144), _allocate_output_tensors (204-228)
- flashinfer_bench/data/definition.py (3): AxisConst (19-32), AxisVar (35-46), Definition (101-458)

tests/bench/test_isolated_runner.py (5)
- flashinfer_bench/data/workload.py (3): Workload (52-66), RandomInput (8-16), ScalarInput (19-29)
- web/apps/web/lib/schemas/trace.ts (1): Workload (142-142)
- flashinfer_bench/bench/utils.py (2): gen_inputs (206-239), load_safetensors (168-203)
- flashinfer_bench/bench/runner/isolated_runner.py (1): run_ref (44-60)
- flashinfer_bench/bench/runner/persistent_runner.py (1): run_ref (246-262)

flashinfer_bench/bench/evaluators/evaluator.py (1)
- flashinfer_bench/bench/runner/runner.py (1): DeviceBaseline (26-32)

flashinfer_bench/compile/builder.py (2)
- flashinfer_bench/data/definition.py (1): Definition (101-458)
- flashinfer_bench/data/solution.py (1): Solution (101-214)

flashinfer_bench/compile/builders/python_builder.py (1)
- flashinfer_bench/compile/builder.py (1): _try_validate_signature (129-194)

flashinfer_bench/bench/evaluators/sampling.py (4)
- flashinfer_bench/data/trace.py (3): Correctness (13-36), Evaluation (90-150), EvaluationStatus (67-87)
- flashinfer_bench/data/definition.py (1): Definition (101-458)
- flashinfer_bench/bench/evaluators/default.py (1): DefaultEvaluator (32-218)
- flashinfer_bench/bench/evaluators/utils.py (2): allocate_outputs (14-41), normalize_result (44-90)

flashinfer_bench/utils.py (1)
- flashinfer_bench/data/trace.py (1): Environment (54-64)

flashinfer_bench/bench/timing.py (1)
- flashinfer_bench/compile/runnable.py (1): Runnable (40-279)

tests/tracing/test_tracing_runtime.py (5)
- flashinfer_bench/tracing/runtime.py (2): collect (92-194), TracingRuntime (29-387)
- flashinfer_bench/data/trace_set.py (1): TraceSet (22-476)
- flashinfer_bench/tracing/config.py (1): TracingConfig (67-158)
- flashinfer_bench/tracing/builtin/policies.py (1): KeepFirstKPolicy (80-102)
- flashinfer_bench/data/definition.py (3): Definition (101-458), AxisVar (35-46), TensorSpec (81-94)

flashinfer_bench/tracing/__init__.py (1)
- flashinfer_bench/tracing/config.py (2): FilterPolicy (24-59), TracingConfig (67-158)

flashinfer_bench/data/definition.py (4)
- flashinfer_bench/data/utils.py (1): BaseModelWithDocstrings (12-15)
- flashinfer_bench/utils.py (1): dtype_str_to_torch_dtype (49-55)
- web/apps/web/lib/schemas/trace.ts (1): Definition (139-139)
- web/apps/web/scripts/prebuild.mjs (2): result (27-31), i (83-83)
🪛 Ruff (0.14.7)
flashinfer_bench/bench/evaluators/utils.py
40-40: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
81-83: Avoid specifying long messages outside the exception class
(TRY003)
88-88: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/bench/benchmark.py
136-136: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
flashinfer_bench/apply/apply_api.py
179-179: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/bench/utils.py
174-174: Avoid specifying long messages outside the exception class
(TRY003)
226-226: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/tracing/config.py
141-141: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
148-148: Prefer TypeError exception for invalid type
(TRY004)
148-148: Avoid specifying long messages outside the exception class
(TRY003)
151-151: Prefer TypeError exception for invalid type
(TRY004)
151-151: Avoid specifying long messages outside the exception class
(TRY003)
156-156: Avoid specifying long messages outside the exception class
(TRY003)
tests/apply/test_runtime.py
178-178: Unused lambda argument: args
(ARG005)
178-178: Unused lambda argument: kwargs
(ARG005)
305-305: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
323-323: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
340-340: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
357-357: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
458-458: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
479-479: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
498-498: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
507-507: Unused function argument: args
(ARG001)
522-522: Unpacked variable d is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
flashinfer_bench/compile/builders/tvm_ffi_builder.py
292-292: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/apply/key.py
57-57: Unused method argument: args
(ARG002)
flashinfer_bench/bench/evaluators/lowbit.py
58-58: Do not catch blind exception: Exception
(BLE001)
66-66: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
flashinfer_bench/apply/runtime.py
154-154: Avoid specifying long messages outside the exception class
(TRY003)
167-171: Avoid specifying long messages outside the exception class
(TRY003)
201-201: Avoid specifying long messages outside the exception class
(TRY003)
tests/tracing/test_tracing_config.py
370-370: Unused function argument: inputs
(ARG001)
381-381: Unused function argument: inputs
(ARG001)
399-399: Unused function argument: inputs
(ARG001)
tests/data/test_definition.py
204-204: Unused method argument: definition
(ARG002)
tests/bench/test_evaluator.py
149-149: Unused function argument: args
(ARG001)
149-149: Unused function argument: kwargs
(ARG001)
150-150: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/tracing/runtime.py
160-160: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
173-173: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
224-224: Prefer TypeError exception for invalid type
(TRY004)
224-224: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/bench/evaluators/default.py
114-114: Do not catch blind exception: Exception
(BLE001)
122-122: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
flashinfer_bench/compile/runnable.py
127-129: Avoid specifying long messages outside the exception class
(TRY003)
150-152: Avoid specifying long messages outside the exception class
(TRY003)
175-178: Avoid specifying long messages outside the exception class
(TRY003)
191-194: Avoid specifying long messages outside the exception class
(TRY003)
196-196: Loop control variable result overrides iterable it iterates
(B020)
196-196: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
198-201: Prefer TypeError exception for invalid type
(TRY004)
198-201: Avoid specifying long messages outside the exception class
(TRY003)
209-212: Avoid specifying long messages outside the exception class
(TRY003)
224-224: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
tests/compile/test_runnable.py
28-28: Unused function argument: args
(ARG001)
143-143: Unused function argument: A
(ARG001)
304-304: Unused lambda argument: args
(ARG005)
335-335: Unused lambda argument: args
(ARG005)
flashinfer_bench/compile/builder.py
158-162: Avoid specifying long messages outside the exception class
(TRY003)
165-168: Avoid specifying long messages outside the exception class
(TRY003)
178-181: Avoid specifying long messages outside the exception class
(TRY003)
185-188: Avoid specifying long messages outside the exception class
(TRY003)
191-194: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/bench/evaluators/sampling.py
30-30: Avoid specifying long messages outside the exception class
(TRY003)
140-140: Do not catch blind exception: Exception
(BLE001)
330-330: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
flashinfer_bench/utils.py
51-51: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer_bench/bench/timing.py
79-79: Avoid specifying long messages outside the exception class
(TRY003)
110-110: Avoid specifying long messages outside the exception class
(TRY003)
202-202: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
tests/tracing/test_tracing_runtime.py
100-100: Loop control variable i not used within loop body
Rename unused i to _i
(B007)
142-142: Unused method argument: minimal_traceset
(ARG002)
flashinfer_bench/data/definition.py
168-168: Avoid specifying long messages outside the exception class
(TRY003)
279-279: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
283-286: Avoid specifying long messages outside the exception class
(TRY003)
287-287: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
291-294: Avoid specifying long messages outside the exception class
(TRY003)
299-302: Avoid specifying long messages outside the exception class
(TRY003)
358-358: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
- GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (79)
flashinfer_bench/bench/timing.py (4)
42-69: Verify multiprocessing lock semantics match the intended use case.
The multiprocessing.Lock objects stored in _device_locks will only provide cross-process synchronization if processes are forked (inheriting the locks) rather than spawned. On spawn-based platforms (Windows, macOS with Python 3.8+), each process gets independent copies, defeating serialization.

If cross-process serialization is required with spawn, consider using multiprocessing.Manager().Lock() instead, as sketched below. If this is only used within a single process (multithreading) or with fork, the current approach is fine.
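For illustration, a spawn-safe variant can be sketched like this (a minimal example with invented names, not the project's actual timing code):

```python
import multiprocessing as mp

def make_device_locks(devices):
    # Manager-backed locks are proxies to locks living in the manager's
    # server process, so even spawned children synchronize on the same lock.
    manager = mp.Manager()
    return {dev: manager.Lock() for dev in devices}

def worker(device, lock):
    with lock:
        pass  # benchmark body: one process at a time per device

if __name__ == "__main__":
    locks = make_device_locks(["cuda:0", "cuda:1"])
    procs = [mp.Process(target=worker, args=(d, lk)) for d, lk in locks.items()]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```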
89-110: LGTM! The statistical summarization logic is correct and covers the documented return modes appropriately.

206-212: LGTM! The cloning logic correctly handles tensor arguments while passing through non-tensor values. This ensures each benchmark iteration operates on independent data.

215-250: LGTM! The time_runnable function correctly:
- Serializes access per-device via locking
- Uses the setup mechanism to exclude argument cloning from timing
- Properly unpacks the cloned arguments when invoking the Runnable

This aligns well with the Runnable contract described in the relevant code snippets.
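As a rough illustration of the clone-outside-timing idea (a simplified sketch using plain CUDA events, not the actual time_runnable implementation):

```python
import torch

def time_once_ms(fn, args, device="cuda"):
    # Clone tensor arguments before timing so each run sees fresh data
    # and the cloning cost stays outside the measured region.
    cloned = [a.clone() if isinstance(a, torch.Tensor) else a for a in args]
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize(device)
    start.record()
    fn(*cloned)
    end.record()
    torch.cuda.synchronize(device)
    return start.elapsed_time(end)
```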
examples/win_at_p.py (1)
112-113: LGTM - Variable naming improved for clarity.
The rename from wl to workload enhances readability and aligns with the broader codebase standardization introduced in this PR.

flashinfer_bench/data/workload.py (4)
8-17: LGTM - RandomInput model is well-defined.
The model correctly uses Pydantic validation with a literal type discriminator, enabling proper union type discrimination for InputSpec.

19-30: LGTM - ScalarInput model properly validates scalar values.
The Union[int, float, bool] type for the value field appropriately covers the scalar input types mentioned in the docstring.

32-46: LGTM - SafetensorsInput model is well-structured.
The use of NonEmptyString for path and tensor_key ensures validation at the model level, preventing empty values from being accepted.

52-66: LGTM - Workload model provides comprehensive configuration.
The model appropriately uses Dict types for axes and inputs, with proper validation constraints (NonNegativeInt, NonEmptyString). The uuid field enables workload tracking across the benchmark system.
tests/data/test_trace.py (1)
99-116: LGTM - Test updated with improved variable naming.
The rename from wl to workload in the test maintains consistency with the broader refactoring while preserving all test logic and assertions.

flashinfer_bench/data/trace.py (1)
5-10: LGTM - Import reorganization supports modularization.
The changes correctly import Workload from the new workload module, improving code organization by separating workload specifications from trace definitions.

tests/data/test_load_dump_schema.py (2)
tests/data/test_load_dump_schema.py (2)
50-59: LGTM - Consistent variable naming in test fixtures.The variable rename enhances test readability and aligns with the project-wide standardization.
74-75: LGTM - Test updated with improved naming.The workload variable naming is consistent with the rest of the test suite.
flashinfer_bench/tracing/__init__.py (1)
14-14: LGTM - Import source updated to reflect new module organization.The InputDumpPolicyFunction is now imported from the config module, which appears to centralize policy-related types. The public API via
__all__remains unchanged, maintaining backward compatibility.tests/data/test_definition.py (4)
121-132: LGTM - Important validation test added.This test ensures that input and output parameter names remain distinct, preventing ambiguity in parameter resolution. This is a critical constraint for the DPS/VR calling convention support.
134-176: LGTM - Comprehensive test coverage for argument merging.The test class thoroughly validates the merge_kwargs_to_args behavior across multiple scenarios:
- Empty kwargs
- Kwargs-only
- Mixed args and kwargs
- Output parameters as kwargs
- Partial kwargs with missing parameters
This coverage is essential for ensuring correct parameter handling in the new calling conventions.
178-224: LGTM - Axis extraction tests cover key scenarios.The tests validate axis value extraction from both tensor inputs (with
.shapeattribute) and scalar inputs (without.shape). The test correctly handles the case where scalar inputs are skipped during axis extraction.Note: The static analysis warning about the unused
definitionparameter in line 204 is a false positive - thedefinitionparameter is the pytest fixture, andscalar_defis intentionally created within the test to demonstrate scalar input handling.
226-254: LGTM - Shape-based axis extraction properly tested.The tests validate both successful axis extraction from shape tuples and proper error handling when axis values are inconsistent across inputs. This is crucial for catching shape mismatches early.
flashinfer_bench/bench/runner/isolated_runner.py (1)
414-421: LGTM! Consistent parameter naming.The rename from
wltoworkloadimproves readability and aligns with the broader API standardization across the codebase.flashinfer_bench/bench/runner/runner.py (2)
42-42: LGTM! Consistent parameter naming.The rename from
wltoworkloadimproves code clarity.
30-31: Verify all consumers of the API change.The
inputsandoutputsfields have changed from dictionary-based to list-based structures. This is a breaking change that affects how downstream code accesses these values. Ensure all consumers ofDeviceBaseline(including evaluators, serialization logic, and test utilities) have been updated to use positional indexing instead of named access.tests/compile/test_runnable.py (2)
10-22: LGTM! Well-structured test helper.The
_make_definition()helper provides a clear, reusable definition for testing core Runnable functionality.
83-239: Excellent test coverage for DPS/VR functionality.These test classes comprehensively cover destination-passing style (DPS) and value-returning (VR) scenarios, including:
- Native DPS/VR callables
- Conversions between styles
- Edge cases (no outputs, multiple outputs)
The test structure is clear and well-organized.
flashinfer_bench/apply/table.py (2)
113-113: LGTM! Idiomatic Pydantic usage.Migrating from custom
from_encoded()to Pydantic's standardmodel_validate_json()is the correct approach for deserializing JSON to a Pydantic model.
140-140: LGTM! Idiomatic Pydantic usage.Using
model_dump_json()instead of a customencode()method is the standard Pydantic approach for JSON serialization.flashinfer_bench/data/solution.py (2)
79-79: LGTM! Best practice for mutable defaults.Changing from
default=[]todefault_factory=listis the correct approach for mutable field defaults in Pydantic. This prevents unintended sharing of the default list instance across model instances.
82-82: LGTM! New field for DPS feature.The
destination_passing_stylefield addition supports the PR's core feature of enabling destination-passing style for solutions.tests/bench/test_persistent_runner.py (1)
89-89: LGTM! Consistent variable naming.The rename from
wltoworkloadthroughout the test suite improves readability and aligns with the API changes in the runner modules.Also applies to: 114-114, 162-162, 225-225, 274-274, 323-325, 364-364, 418-420
tests/bench/test_isolated_runner.py (3)
109-118: LGTM! Clear documentation of list-based API.The comment explaining that
gen_inputsreturns a list in definition order is helpful for understanding the new API. The positional access pattern (e.g.,out[0]for X,out[1]for Y) is correctly aligned with the list-based input structure introduced in this PR.
131-143: LGTM! Consistent with API changes.The updates to use
safe_tensors(instead ofstensors) and list-based indexing are consistent with the broader API refactoring.
193-193: LGTM! Consistent variable naming.The rename from
wltoworkloadimproves readability.Also applies to: 222-222
flashinfer_bench/tracing/builtin/policies.py (1)
18-20: LGTM! Proper use of TYPE_CHECKING.Wrapping type-only imports in a
TYPE_CHECKINGguard is the standard pattern for avoiding runtime import overhead and potential circular dependencies while maintaining type checking capabilities.flashinfer_bench/bench/runner/persistent_runner.py (1)
508-583: LGTM! Consistent parameter rename fromwltoworkload.The renaming is applied consistently across the method signature, docstring, and all internal usages including
run_refcalls and error logging. This aligns with the broader codebase terminology updates in this PR.flashinfer_bench/compile/builders/python_builder.py (1)
141-152: LGTM! Well-structured metadata updates and signature validation.The
RunnableMetadataconstruction correctly uses the new field names (definition_name,solution_name) and includesdestination_passing_stylefrom the solution spec. The signature validation call before returning theRunnableensures early detection of interface mismatches. The inclusion of the fulldefinitionobject enables richer runtime context for downstream components.tests/compile/test_builder.py (2)
38-44: LGTM! Test updated to match newRunnableMetadataAPI.The field names are correctly updated from
definition/solutiontodefinition_name/solution_name, aligning with the updatedRunnableMetadatamodel.
75-79: LGTM! Mock builder updated for consistency.The mock
RunnableMetadataconstruction correctly uses the new field names with empty strings, which is appropriate for testing dispatch logic where the metadata content is not the focus.flashinfer_bench/bench/utils.py (3)
168-203: LGTM! Consistent rename toworkloadandsafe_tensors.The parameter and variable renames are applied consistently throughout
load_safetensors. The logic remains unchanged.
206-239: Return type changed fromDicttoListto support positional arguments.The change to return a
List[Any]in definition order aligns with the DPS calling convention where inputs are passed positionally. The docstring clearly documents this behavior.However, the same potential issue exists at line 232 where
shapes[name]is accessed - verify thatget_input_shapesreturns a structure that supports string key access.
176-192: Potential type mismatch:get_input_shapesreturn type requires verification.The code at lines 176-192 accesses
expected[name]using dictionary-style key access, whereexpectedis assigned fromdefinition.get_input_shapes(workload.axes). The review claims this method returnsList[Optional[Tuple[int, ...]]]based on definition.py lines 363-383, which would cause aTypeErrorat runtime since list indices must be integers, not string keys.This claim requires verification of the actual return type annotation and implementation of
get_input_shapes.flashinfer_bench/compile/builder.py (3)
33-38: LGTM! Class attributes documented.Adding explicit class attribute declarations with docstrings improves code clarity and IDE support.
129-154: LGTM! Graceful handling of unavailable signatures.The method correctly catches
ValueErrorandTypeErrorwheninspect.signaturefails (e.g., for built-in functions or C extensions), allowing builds to proceed without validation in such cases.
156-168: LGTM! DPS and VR parameter count validation.The parameter count checks are correct:
- DPS: expects
len(inputs) + len(outputs)parameters- VR: expects
flashinfer_bench/bench/evaluators/utils.py (1)
44-90: LGTM! Well-structured result normalization.
The function correctly handles various return types and provides clear error messages when validation fails. The descriptive exception messages are appropriate for helping users diagnose issues.
flashinfer_bench/data/__init__.py (1)
12-14: LGTM! Clean module reorganization.
The refactoring moves workload-related types to a dedicated module while maintaining the same public API. The updated grouping in __all__ improves clarity.

Also applies to: 28-34
flashinfer_bench/compile/builders/torch_builder.py (1)
158-171: LGTM! Enhanced metadata and signature validation.
The updated metadata structure with explicit field names (definition_name, solution_name, destination_passing_style) improves clarity. Adding signature validation via _try_validate_signature catches mismatches early, and using the callable directly simplifies the implementation.

tests/tracing/test_tracing_config.py (1)
323-406: LGTM! Comprehensive test coverage for the new API.
The TestGetInputsToDump class thoroughly exercises the updated get_inputs_to_dump(names, values) signature with tests covering static lists, callable policies, validation errors, and edge cases. The test structure is clear and well-organized.

tests/tracing/test_tracing_runtime.py (1)
113-221: LGTM! Thorough testing of args/kwargs handling and calling conventions.
The new test classes TestCollectArgsKwargs and TestCollectCallingConvention provide excellent coverage of the updated API. Tests validate kwargs merging, argument validation, and correct behavior for both value-returning and destination-passing styles.

flashinfer_bench/compile/builders/tvm_ffi_builder.py (1)
6-6: LGTM! Consistent updates aligned with torch_builder.
The metadata structure updates and signature validation addition mirror the changes in torch_builder.py, maintaining consistency across builders. The addition of the inspect import supports the new validation functionality.

Also applies to: 280-297
tests/bench/test_evaluator.py (2)
61-80: Excellent mock design for testing both calling conventions.
The _make_dps_mock and _make_vr_mock helper functions clearly separate the behavior of destination-passing and value-returning styles. The DPS mock's side effect correctly writes to the output argument, while the VR mock simply returns the result.
87-488: LGTM! Comprehensive test coverage for all evaluator variants.
The test suite thoroughly exercises all evaluator types (Default, Sampling, LowBit) with both DPS and VR styles. Each test class validates:
- Success cases with correct outputs
- Shape mismatch detection
- Numerical error detection
- Runtime error handling
- Evaluator-specific metrics (e.g., matched_ratio for LowBit)
The organization into separate test classes for each evaluator and style combination makes the test suite clear and maintainable.
tests/apply/test_runtime.py (2)
438-447: Nice mock design for DPS testing.
The FakeTensorWithFill class elegantly simulates tensor mutation for destination-passing style tests. The fill_ method allows verification that DPS functions correctly modify their output arguments in place.
296-533: LGTM! Comprehensive test coverage for dispatch with both calling conventions.
The test suite thoroughly validates:
- Args/kwargs handling and merging
- Value-returning style dispatch and results
- Destination-passing style with output mutation
- Fallback behavior for cache misses
- Policy-based routing for both styles
The organization into separate test classes for different aspects (args/kwargs, calling conventions, DPS style) makes the test suite easy to navigate and maintain.
flashinfer_bench/apply/runtime.py (2)
157-171: Argument validation logic looks correct.
The DPS detection based on argument count comparison with num_inputs and num_inputs + num_outputs is well-designed. The early extraction of input_args for key building correctly separates inputs from outputs in DPS mode.
204-209: LGTM! The dual calling convention handling is correctly implemented. DPS uses the full args including pre-allocated outputs and returns None, while value-returning uses only input_args and returns the result.
134-180: LGTM!The
_dispatch_apply_or_tracinghelper cleanly centralizes the dispatch logic for both decorator and function modes. The priority order (apply → tracing → fallback) and the resolver invocation pattern are well-designed.
123-131: LGTM!The decorator mode correctly uses
@wraps(fallback)to preserve function metadata and cleanly delegates to the centralized dispatch helper.flashinfer_bench/utils.py (2)
29-46: LGTM!The lazy dtype mapping with
@lru_cache(maxsize=1)is a good pattern for deferring torch import until first use while ensuring the mapping is built only once. This improves import-time performance for modules that don't immediately need torch.
67-79: LGTM!Local torch imports in
is_cuda_availableandlist_cuda_devicesare consistent with the lazy import strategy established in this file.flashinfer_bench/bench/evaluators/evaluator.py (2)
41-63: LGTM!The signature updates for
check_correctnessandeval_performanceto useList[List[Any]]for inputs andList[List[torch.Tensor]]for ref_outputs are consistent with the PR's refactoring to list-based data structures. Addingdefinition: Definitiontoeval_performanceenables DPS-aware output allocation in implementations.
77-97: LGTM!The
evaluatemethod correctly passes thedefinitionparameter to bothcheck_correctnessandeval_performance, maintaining consistency with the updated abstract method signatures.flashinfer_bench/bench/evaluators/lowbit.py (2)
46-62: LGTM!The DPS and value-returning execution paths are correctly implemented. The DPS path properly allocates outputs first using
allocate_outputs, while the VR path normalizes the result afterward. Both paths correctly usetorch.no_grad()and synchronize before proceeding.
77-103: LGTM!The per-tensor validation logic is thorough: it checks for non-finite values first, then computes error statistics. The aggregation of max errors and minimum matched ratio across all tensors provides a comprehensive correctness assessment.
flashinfer_bench/tracing/config.py (1)
24-60: Well-structured FilterPolicy Protocol.
The Protocol clearly defines the contract for filter policies with comprehensive docstrings. This aligns well with the existing implementations in flashinfer_bench/tracing/builtin/policies.py.

flashinfer_bench/tracing/runtime.py (1)
92-124: Clean refactor of collect() to support both calling conventions.
The signature change from runtime_args: Dict[str, Any] to (args, kwargs) with proper validation for VR vs DPS calling conventions is well-implemented. The error handling with early returns prevents runtime crashes during tracing.

flashinfer_bench/apply/key.py (2)
11-22: Clean migration to a Pydantic frozen model.
Using ConfigDict(frozen=True) provides immutability for use as dict keys and in sets, replacing the manual __hash__/__eq__ implementations. The Pydantic model also provides built-in JSON serialization via model_dump_json() and model_validate_json().
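A small self-contained illustration of the pattern (field names invented for the example):

```python
from typing import Tuple
from pydantic import BaseModel, ConfigDict

class ApplyKey(BaseModel):
    model_config = ConfigDict(frozen=True)  # immutable, hence hashable
    def_name: str
    axes: Tuple[int, ...]

k1 = ApplyKey(def_name="rmsnorm", axes=(4096,))
k2 = ApplyKey.model_validate_json(k1.model_dump_json())  # JSON round-trip
table = {k1: "solution-a"}  # usable as a dict key
assert table[k2] == "solution-a"  # frozen models hash/compare by value
```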
48-58: Builder methods updated for the positional args interface.
The build_from_args and features methods now accept Tuple[Any, ...] instead of dict-based runtime kwargs. The unused args parameter in features() at line 57 is intentional: it maintains the abstract interface contract even when the implementation doesn't need the value.

flashinfer_bench/compile/runnable.py (2)
27-36: Updated metadata fields support the DPS/VR distinction.
The rename from definition/solution to definition_name/solution_name improves clarity, and the new destination_passing_style flag enables runtime branching between calling conventions.
98-121: Clean return value normalization logic.
The _revise_return_value helper properly handles edge cases: pass-through for non-tuples, unpacking single-element tuples, and converting empty tuples to None.
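A minimal standalone sketch of that behavior, reconstructed from the description (not copied from the source):

```python
from typing import Any, Optional

def revise_return_value(ret: Any) -> Optional[Any]:
    # Non-tuples pass through; 1-tuples are unwrapped; empty tuples
    # become None (a pure DPS call has nothing to return).
    if not isinstance(ret, tuple):
        return ret
    if len(ret) == 0:
        return None
    if len(ret) == 1:
        return ret[0]
    return ret

assert revise_return_value(5) == 5
assert revise_return_value((5,)) == 5
assert revise_return_value(()) is None
```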
46-64: Baseline generation correctly uses value-returning style.
The reference implementation is always value-returning, and the new normalize_result helper properly converts outputs to a consistent List[torch.Tensor] format for comparison.
98-118: Clean DPS vs VR branching in correctness checking.
The conditional logic properly handles both calling conventions: allocating outputs for DPS and normalizing results for VR. The structure is clear and maintainable.
186-197: Performance evaluation correctly handles DPS output allocation.
The DPS path properly allocates outputs and includes them in the timing arguments, ensuring an accurate latency measurement that includes any output tensor initialization overhead.
flashinfer_bench/bench/evaluators/sampling.py (4)
26-39: LGTM! Clear helper utilities for name-based input access.
The helper functions provide a clean abstraction for accessing inputs by name from the list-based structure, aligning with the DPS refactoring pattern used across evaluators.
66-81: LGTM! Clean list-based baseline construction.
The baseline correctly stores inputs and outputs as List[List[Any]] and List[List[torch.Tensor]] respectively, with expected_probs stored as the first tensor in the outputs list, aligning with the DPS-oriented structure.
130-147: DPS branch correctly integrated.
The DPS handling matches the pattern in DefaultEvaluator.check_correctness (see default.py lines 90-103), properly allocating outputs via allocate_outputs and invoking the runnable with combined input/output tensors.

The broad Exception catch at line 140 is consistent with the error handling pattern used across all evaluators in this codebase to capture runtime failures.
362-373: LGTM! DPS sampling loop correctly implemented.
The DPS branch properly allocates outputs for the padded inputs and invokes the runnable with combined arguments, while the non-DPS branch normalizes the returned result. Both paths correctly extract the samples tensor from out[0].

flashinfer_bench/data/definition.py (4)
15-16: LGTM! Proper use of TYPE_CHECKING for the torch import.
The TYPE_CHECKING guard correctly avoids importing torch at runtime while still enabling type hints for the torch.dtype return types in the cached properties.
158-169: LGTM! Important validation for DPS support.
This validator ensures input and output names don't overlap, which is critical for the DPS calling convention where outputs are passed as additional arguments and need unique identifiers.
450-457: Silent break on missing kwargs may be unexpected.
The loop breaks at line 457 when a parameter name is not found in kwargs. This silently stops merging if kwargs don't form a contiguous sequence after the positional args. This behavior might be intentional for partial application, but could also mask errors where expected kwargs are accidentally omitted.

Consider whether raising an error or documenting this behavior would be more appropriate.
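A tiny worked example of the contiguity behavior, following the merge logic quoted earlier in this review (parameter names hypothetical):

```python
# Suppose the parameter order is (a, b, c). Starting from one positional
# argument, the merge walks b, then c; since "b" is absent from kwargs,
# the loop breaks and "c" is silently dropped.
args, kwargs = (1,), {"c": 3}
param_names = ["a", "b", "c"]
result = list(args)
for name in param_names[len(args):]:
    if name in kwargs:
        result.append(kwargs[name])
    else:
        break
print(tuple(result))  # (1,) — the "c" keyword was ignored
```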
408-428: Unable to independently verify the DType enum conversion claims.
Repository access failed, preventing verification of: (1) whether DType inherits from (str, Enum) at line 49, (2) the dtype_str_to_torch_dtype function signature and implementation, and (3) whether the code correctly handles spec.dtype as a DType enum. The review comment's conclusion that the code is correct cannot be confirmed without access to the actual codebase. Web search provided only generic flashinfer API patterns, not the specific implementation details needed.
```diff
 if fallback is None:
     raise RuntimeError(f"Definition '{def_name}' not found and no fallback provided")
-return fallback(**runtime_kwargs)
+return fallback(*args, **kwargs)
```
Inconsistent fallback invocation: kwargs dropped on miss path.
When the definition is not found (line 155), fallback is called with *args, **kwargs. However, when no runnable is available after miss policy (line 202), fallback is called with only *args, dropping kwargs. This inconsistency could cause unexpected behavior when fallback functions expect keyword arguments.
Apply this diff to make fallback invocation consistent:
```diff
 if runnable is None:
     if fallback is None:
         raise RuntimeError(f"Apply miss for '{def_name}' and no fallback provided")
-    return fallback(*args)
+    return fallback(*args, **kwargs)
```

Also applies to: 202-202
```python
def _quantile(a: List[float], q: Sequence[float]) -> List[float]:
    """Compute quantiles of a list of values."""
    n = len(a)
    a = sorted(a)

    def get_quantile(q: float) -> float:
        if not (0 <= q <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = q * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        return (1 - t) * a[lower] + t * a[upper]

    return [get_quantile(qi) for qi in q]
```
Handle empty input list edge case.
If `a` is empty, accessing `a[lower]` and `a[upper]` will raise an `IndexError`. Consider adding an early guard.
```diff
 def _quantile(a: List[float], q: Sequence[float]) -> List[float]:
     """Compute quantiles of a list of values."""
     n = len(a)
+    if n == 0:
+        raise ValueError("Cannot compute quantiles of an empty list")
     a = sorted(a)
```

📝 Committable suggestion
```python
def _quantile(a: List[float], q: Sequence[float]) -> List[float]:
    """Compute quantiles of a list of values."""
    n = len(a)
    if n == 0:
        raise ValueError("Cannot compute quantiles of an empty list")
    a = sorted(a)

    def get_quantile(q: float) -> float:
        if not (0 <= q <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = q * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        return (1 - t) * a[lower] + t * a[upper]

    return [get_quantile(qi) for qi in q]
```
🧰 Tools
🪛 Ruff (0.14.7)
79-79: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In flashinfer_bench/bench/timing.py around lines 72-86, the _quantile function
will IndexError on an empty input list; add an early guard that checks if n == 0
and raise a clear ValueError (e.g. "Cannot compute quantiles of an empty list")
before sorting/processing so callers get a meaningful error instead of an
IndexError; keep the rest of the quantile logic unchanged.
```python
def _get_empty_cache_for_benchmark() -> torch.Tensor:
    """Create a buffer for clearing L2 cache before benchmark runs.

    We maintain a buffer of 256 MB that we clear before each kernel call
    to make sure that the L2 cache doesn't contain any input data before the run.
    """
    cache_size = 256 * 1024 * 1024
    return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda")
```
Cache buffer uses hardcoded `"cuda"` device.

The cache tensor is always allocated on the default CUDA device (`"cuda"`), but `time_runnable` accepts a specific `device` parameter (e.g., `"cuda:1"`). If the benchmark runs on a non-default GPU, the cache-clearing operation may not effectively clear that GPU's L2 cache, and cross-device synchronization could introduce overhead.
Consider parameterizing the device:
```diff
-def _get_empty_cache_for_benchmark() -> torch.Tensor:
+def _get_empty_cache_for_benchmark(device: str = "cuda") -> torch.Tensor:
     """Create a buffer for clearing L2 cache before benchmark runs.

     We maintain a buffer of 256 MB that we clear before each kernel call
     to make sure that the L2 cache doesn't contain any input data before the run.
     """
     cache_size = 256 * 1024 * 1024
-    return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda")
+    return torch.empty(int(cache_size // 4), dtype=torch.int, device=device)
```
+ return torch.empty(int(cache_size // 4), dtype=torch.int, device=device)This would require propagating the device through do_bench as well.
🤖 Prompt for AI Agents
In flashinfer_bench/bench/timing.py around lines 113-120 the cache buffer is
always allocated on the default "cuda" device which breaks when benchmarking on
a non-default GPU (e.g., "cuda:1"); change _get_empty_cache_for_benchmark to
accept a device argument and allocate the tensor on that device, then propagate
that device through the benchmarking call path (update do_bench and any callers
such as time_runnable to pass the selected device) so the cache-clear buffer is
created on the same CUDA device as the workload and avoids cross-device
sync/overhead.
```python
# Check return annotation
num_outputs = len(definition.outputs)
ret_ann = signature.return_annotation
if ret_ann is not inspect.Signature.empty:
    origin = getattr(ret_ann, "__origin__", None)
    if origin is tuple:
        # If returning tuple, length must match num_outputs
        args = getattr(ret_ann, "__args__", None)
        if args is not None and len(args) != num_outputs:
            raise BuildError(
                f"Value-returning style callable with {num_outputs} outputs must "
                f"return a {num_outputs}-element tuple, got {ret_ann}"
            )
    elif origin is None:
        # If return None, num_outputs must be 0
        if num_outputs != 0:
            raise BuildError(
                f"Value-returning style callable with {num_outputs} outputs must "
                f"return None, got {ret_ann}"
            )
    elif num_outputs != 1:
        # If returning non-tuple, num_outputs must be 1
        raise BuildError(
            f"Value-returning style callable returning non-tuple must "
            f"only have one output, got {num_outputs}"
        )
```
Logic bug: `origin is None` conflates `-> None` with non-generic types like `-> int`.

The condition at line 182 (`elif origin is None`) is true both when:

- The return annotation is `None` (i.e., `-> None`)
- The return annotation is a non-generic type like `int`, `float`, `torch.Tensor`

For a function with signature `def foo(x) -> torch.Tensor` and 1 output, `origin` is None, so the code enters line 182 and raises an error at line 184 because num_outputs (1) != 0.
Consider checking for None type explicitly:
```diff
-    elif origin is None:
-        # If return None, num_outputs must be 0
-        if num_outputs != 0:
-            raise BuildError(
-                f"Value-returning style callable with {num_outputs} outputs must "
-                f"return None, got {ret_ann}"
-            )
-    elif num_outputs != 1:
+    elif ret_ann is type(None):
+        # If return None, num_outputs must be 0
+        if num_outputs != 0:
+            raise BuildError(
+                f"Value-returning style callable with {num_outputs} outputs must "
+                f"not return None when outputs are expected"
+            )
+    elif origin is None and num_outputs != 1:
         # If returning non-tuple, num_outputs must be 1
         raise BuildError(
             f"Value-returning style callable returning non-tuple must "
             f"only have one output, got {num_outputs}"
         )
```
🪛 Ruff (0.14.7)
178-181: Avoid specifying long messages outside the exception class
(TRY003)
185-188: Avoid specifying long messages outside the exception class
(TRY003)
191-194: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In flashinfer_bench/compile/builder.py around lines 169-194, the code wrongly
treats any non-generic return annotation (origin is None) as if it were the
explicit None return; change the branch so you explicitly check for an exact
None return (e.g., ret_ann is None or ret_ann is type(None)) and handle the
"num_outputs must be 0" error only in that case, otherwise fall through to the
non-tuple handling for non-generic types (int, float, torch.Tensor) so a
single-output non-tuple return is allowed when num_outputs == 1.
```python
for result, output in zip(result, args_output):
    if not isinstance(result, torch.Tensor) or not isinstance(output, torch.Tensor):
        raise ValueError(
            "Destination-passing style callable must return a tuple of tensors, got "
            f"{type(result)} and {type(output)}"
        )
    output.copy_(result)
```
Critical: Loop variable `result` shadows the iterable.

The loop `for result, output in zip(result, args_output)` reuses `result` as both the iterable and the loop variable. This works by accident because Python evaluates the iterable before binding loop variables, but it's confusing and error-prone. Additionally, `zip()` should use `strict=True` since mismatched lengths indicate a bug.
```diff
-for result, output in zip(result, args_output):
-    if not isinstance(result, torch.Tensor) or not isinstance(output, torch.Tensor):
+for res_tensor, out_tensor in zip(result, args_output, strict=True):
+    if not isinstance(res_tensor, torch.Tensor) or not isinstance(out_tensor, torch.Tensor):
         raise ValueError(
             "Destination-passing style callable must return a tuple of tensors, got "
-            f"{type(result)} and {type(output)}"
+            f"{type(res_tensor)} and {type(out_tensor)}"
         )
-    output.copy_(result)
+    out_tensor.copy_(res_tensor)
```
+ out_tensor.copy_(res_tensor)📝 Committable suggestion
```python
for res_tensor, out_tensor in zip(result, args_output, strict=True):
    if not isinstance(res_tensor, torch.Tensor) or not isinstance(out_tensor, torch.Tensor):
        raise ValueError(
            "Destination-passing style callable must return a tuple of tensors, got "
            f"{type(res_tensor)} and {type(out_tensor)}"
        )
    out_tensor.copy_(res_tensor)
```
🧰 Tools
🪛 Ruff (0.14.7)
196-196: Loop control variable result overrides iterable it iterates
(B020)
196-196: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
198-201: Prefer TypeError exception for invalid type
(TRY004)
198-201: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In flashinfer_bench/compile/runnable.py around lines 196-202, the for-loop
reuses the name `result` for both the iterable and the loop variable which is
confusing and error-prone; rename the iterable (e.g., `results` or
`result_iter`) and iterate with a distinct loop variable (e.g., `result_item`)
and use zip(..., strict=True) to catch length mismatches; keep the existing type
checks and the output.copy_(result_item) call but replace references to the
renamed variables accordingly.
```python
InputDumpPolicyFunction = Callable[Dict[str, Any], List[str]]
"""Function that selects which inputs to dump from input names and values."""
```
Critical: Invalid `Callable` type alias syntax.

The type alias uses incorrect bracket syntax for `Callable`: the parameter types must be wrapped in a list. As written, this raises a runtime `TypeError` when Python evaluates the type hint.
```diff
-InputDumpPolicyFunction = Callable[Dict[str, Any], List[str]]
+InputDumpPolicyFunction = Callable[[Dict[str, Any]], List[str]]
```

📝 Committable suggestion
```python
InputDumpPolicyFunction = Callable[[Dict[str, Any]], List[str]]
"""Function that selects which inputs to dump from input names and values."""
```
🤖 Prompt for AI Agents
In flashinfer_bench/tracing/config.py around lines 16-17, the type alias uses
incorrect Callable syntax; change the alias to use a list for the argument
types, e.g. InputDumpPolicyFunction = Callable[[Dict[str, Any]], List[str]], and
ensure typing.Callable is imported (or adjust imports) so the evaluated type
hint is valid.
This PR supports DPS for solutions. It also allows the substituted function in apply() and tracing() to be DPS. It refactors the apply, tracing, timing, and benchmark systems to support these changes.
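For illustration, the two calling conventions look roughly like this (a hedged sketch; the actual flashinfer-bench signatures may differ):

```python
import torch


def add_value_returning(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Value-returning style: the kernel allocates and returns its output.
    return x + y


def add_dps(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor) -> None:
    # Destination-passing style: the caller pre-allocates `out` and the
    # kernel writes in place, keeping allocation off the hot path.
    torch.add(x, y, out=out)
```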
Signed-off-by: Ubospica ubospica@gmail.com