Conversation

@Ubospica (Collaborator) commented Dec 4, 2025

This PR adds destination-passing style (DPS) support for solutions. It also allows the substituted function in apply() and tracing() to be DPS, and it refactors the apply, tracing, timing, and benchmark systems to support these changes.

Signed-off-by: Ubospica ubospica@gmail.com
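For readers new to the convention, a minimal illustration of the two calling styles (a hypothetical rmsnorm kernel; the names are illustrative and not taken from this PR):

import torch

def rmsnorm_value_returning(x: torch.Tensor) -> torch.Tensor:
    # value-returning: the kernel allocates and returns its own output
    return x / x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt()

def rmsnorm_destination_passing(x: torch.Tensor, out: torch.Tensor) -> None:
    # destination-passing style (DPS): the caller pre-allocates `out`
    # and the kernel writes into it, avoiding an allocation per call
    torch.div(x, x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt(), out=out)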

Summary by CodeRabbit

Release Notes

  • Refactor
    • Refactored core APIs to support both value-returning and destination-passing calling conventions for improved flexibility.
    • Converted input/output handling from dictionary-based to positional list-based structures for clearer argument passing.
    • Enhanced benchmarking framework with improved statistical analysis and per-device synchronization.
    • Reorganized internal data structures and validation workflows for better maintainability.


@coderabbitai bot commented Dec 4, 2025

Walkthrough

This PR refactors the apply, benchmark, and compile subsystems to support dual calling conventions (value-returning and destination-passing styles). It reorganizes input/output data structures from dictionaries to lists, migrates ApplyKey from dataclass to Pydantic-based frozen model, updates tracing runtime to accept explicit args/kwargs, and renames parameters across runners and utilities for consistency.

Changes

Cohort / File(s) | Summary
Apply API Refactoring
flashinfer_bench/apply/apply_api.py, flashinfer_bench/apply/key.py, flashinfer_bench/apply/runtime.py, flashinfer_bench/apply/table.py
Reworks apply API from runtime_kwargs to positional args/kwargs; replaces ApplyKey dataclass with Pydantic frozen model; updates ApplyKeyBuilder to use args instead of runtime_kwargs; refactors dispatch to handle both DPS and value-returning calling conventions; updates key serialization to use Pydantic model methods.
Benchmark Evaluator Updates
flashinfer_bench/bench/evaluators/evaluator.py, flashinfer_bench/bench/evaluators/default.py, flashinfer_bench/bench/evaluators/lowbit.py, flashinfer_bench/bench/evaluators/sampling.py
Changes input/output representations from dictionaries to lists; adds destination_passing_style branching for DPS vs. value-returning paths; introduces normalize_result and allocate_outputs utilities; updates correctness checks and performance timing to handle both calling conventions.
New Evaluator Utilities
flashinfer_bench/bench/evaluators/utils.py
Adds allocate_outputs and normalize_result functions to provision and normalize output tensors for both calling conventions.
Benchmark Utilities & Timing
flashinfer_bench/bench/utils.py, flashinfer_bench/bench/timing.py
Renames wl to workload; updates gen_inputs to return List instead of Dict; introduces quantile-based statistics and do_bench function with per-device locking; refactors time_runnable to accept List args instead of dict inputs.
Runner & Benchmark
flashinfer_bench/bench/benchmark.py, flashinfer_bench/bench/runner/runner.py, flashinfer_bench/bench/runner/isolated_runner.py, flashinfer_bench/bench/runner/persistent_runner.py
Renames workload parameter from wl to workload; updates DeviceBaseline to use List-based inputs/outputs instead of dicts.
Data Models & Definitions
flashinfer_bench/data/definition.py, flashinfer_bench/data/solution.py, flashinfer_bench/data/trace.py, flashinfer_bench/data/workload.py, flashinfer_bench/data/__init__.py
Extends Definition with torch dtype utilities, axes inference, and kwargs merging; adds BuildSpec.destination_passing_style flag; introduces new Workload module with RandomInput, ScalarInput, SafetensorsInput; relocates workload-related exports.
Compile & Runnable Refactoring
flashinfer_bench/compile/runnable.py, flashinfer_bench/compile/builder.py, flashinfer_bench/compile/builders/python_builder.py, flashinfer_bench/compile/builders/torch_builder.py, flashinfer_bench/compile/builders/tvm_ffi_builder.py
Updates RunnableMetadata with definition_name/solution_name and destination_passing_style flag; refactors Runnable.call to use positional args; adds call_destination_passing, call_value_returning, and signature validation; removes kwarg adapters.
Tracing Infrastructure
flashinfer_bench/tracing/config.py, flashinfer_bench/tracing/runtime.py, flashinfer_bench/tracing/builtin/policies.py, flashinfer_bench/tracing/__init__.py, flashinfer_bench/tracing/filter_policy.py
Introduces FilterPolicy protocol and InputDumpPolicyFunction in config; updates get_inputs_to_dump signature to accept names/values lists; refactors TracingRuntime.collect to use args/kwargs; simplifies argument-to-tensor conversion.
Utility Updates
flashinfer_bench/utils.py
Introduces lazy dtype mapping via _get_dtype_str_to_torch_dtype with lru_cache; defers torch imports to function bodies.
Extensive Test Updates
tests/apply/test_runtime.py, tests/bench/test_evaluator.py, tests/bench/test_isolated_runner.py, tests/bench/test_persistent_runner.py, tests/compile/test_builder.py, tests/compile/test_runnable.py, tests/data/test_definition.py, tests/data/test_load_dump_schema.py, tests/data/test_trace.py, tests/tracing/test_tracing_config.py, tests/tracing/test_tracing_runtime.py
Adds DPS/value-returning style test variants; updates mock factories; changes test inputs from dicts to lists; adds new test suites for args/kwargs merging, calling conventions, and axes inference.
Example Updates
examples/win_at_p.py
Renames variable wl to workload in collect_runs.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:

  • flashinfer_bench/apply/runtime.py — Dense dispatch logic with argument validation, DPS vs. value-returning branching, and fallback handling; critical for apply correctness
  • flashinfer_bench/compile/runnable.py — Significant refactoring of call signatures and internal argument resolution; interactions with Definition and calling conventions
  • flashinfer_bench/data/definition.py — New validation workflows, axes inference helpers, and utilities (merge_kwargs_to_args, get_axes_values_from_inputs); impacts downstream code
  • flashinfer_bench/bench/evaluators/default.py & lowbit.py — Per-evaluator DPS/VR branching and output handling logic across multiple code paths
  • flashinfer_bench/tracing/runtime.py — Updated collect signature and argument-to-tensor conversion; cascading impact on tracing pipeline
  • flashinfer_bench/bench/timing.py — New statistical summarization and per-device locking; potential concurrency implications

Possibly related PRs

  • [tvm-ffi] TVMFFIBuilder #111 — Modifies TVM FFI builder and related metadata handling; shares apply/compile builder infrastructure changes
  • refactor: Builder System #120 — Updates compile/build subsystem (Runnable, RunnableMetadata, builders); overlapping refactor of builder registration and metadata
  • Fix adapter integration #65 — Touches apply/tracing integration and adapter call paths; related to changes in how apply dispatches and tracing flushes data

Poem

🐰 From dicts to lists, our args now flow,
Destination-passing styles steal the show!
Frozen pydantic keys and axes align,
Refactored builders make runnables shine.
A rabbit's delight: calling conventions so fine! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 57.80%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title 'feat: Destination-passing Style' accurately captures the primary change, adding DPS support across the apply, tracing, timing, and benchmark systems.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment


@gemini-code-assist bot commented

Summary of Changes

Hello @Ubospica, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive support for Destination-Passing Style (DPS) across the FlashInfer-Bench framework. It refactors the core apply and tracing mechanisms, updates data models for definitions and workloads, and adapts the benchmarking and compilation systems to seamlessly handle both value-returning and destination-passing function conventions. This enhancement allows for greater flexibility in defining and evaluating kernel solutions, ensuring compatibility with various optimization techniques.

Highlights

  • Destination-Passing Style (DPS) Support: Solutions can now be defined and executed using Destination-Passing Style, where output tensors are pre-allocated and passed as arguments, alongside the existing value-returning style.
  • Refactored apply() and tracing() Systems: The core apply() and tracing() mechanisms have been significantly refactored to seamlessly support both value-returning and destination-passing styles for substituted functions.
  • Unified Argument Handling: The apply function and tracing runtime now accept positional arguments (args) and keyword arguments (kwargs), simplifying the API and enabling flexible argument passing. Keyword arguments are merged into positional arguments based on definition order.
  • Updated Data Models and Utilities: The Definition model now explicitly manages input/output order and provides new utilities for inferring axis values from inputs and merging arguments. Workload-related data models have been reorganized into a dedicated workload.py file.
  • Benchmarking System Adaptation: The benchmarking system, including evaluators, runners, and timing utilities, has been updated to correctly handle both calling conventions (value-returning and DPS) during input generation, correctness checks, and performance evaluation.
  • Enhanced Runnable Interface: The Runnable class now explicitly tracks the calling convention (DPS or value-returning) and provides dedicated methods (call_destination_passing, call_value_returning, call_kwargs) to execute compiled solutions in the appropriate style or convert between them.
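For illustration, a minimal sketch of the conversion such methods imply (assumed semantics; the actual Runnable resolves output shapes and dtypes through the Definition):

import torch

def as_value_returning(dps_fn, out_shapes, out_dtypes, device="cpu"):
    # Wrap a destination-passing callable so it can be used value-returning:
    # allocate the output buffers, let the kernel fill them, then return them.
    def vr_fn(*inputs):
        outs = [torch.empty(shape, dtype=dtype, device=device)
                for shape, dtype in zip(out_shapes, out_dtypes, strict=True)]
        dps_fn(*inputs, *outs)
        return outs[0] if len(outs) == 1 else tuple(outs)
    return vr_fn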

@gemini-code-assist bot left a comment

Code Review

This pull request introduces significant refactoring to support Destination-Passing Style (DPS) for solutions, which is a great enhancement for performance-oriented code. The changes span the apply, tracing, timing, and benchmark systems, and the overall design for handling different calling conventions is clean and well-abstracted, especially within the Runnable class. I've identified a few issues related to argument handling and documentation consistency that should be addressed to improve the robustness and clarity of the new APIs. My main concerns are around ensuring that duplicate arguments are handled correctly and that fallback functions are called in a consistent and flexible manner.

Comment on lines +158 to +159
if kwargs:
    args = definition.merge_kwargs_to_args(args, kwargs)
Severity: high

The args tuple is being reassigned here, which causes the original positional arguments to be lost. This becomes an issue later when calling the fallback function, as it can no longer be called with the original *args and **kwargs, breaking consistency and flexibility. It's better to store the result of the merge in a new variable, e.g., merged_args, and use that for subsequent logic related to the specialized kernel call, while preserving the original args and kwargs for the fallback call. You will need to update subsequent uses of args to merged_args.

Suggested change
-if kwargs:
-    args = definition.merge_kwargs_to_args(args, kwargs)
+merged_args = args
+if kwargs:
+    merged_args = definition.merge_kwargs_to_args(args, kwargs)
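Put together, the dispatch pattern being suggested looks roughly like this (a sketch only; lookup_kernel is a hypothetical stand-in for the apply-table lookup, not a real API):

def dispatch(definition, def_name, fallback, lookup_kernel, *args, **kwargs):
    # Merge kwargs for the specialized-kernel path only.
    merged_args = args
    if kwargs:
        merged_args = definition.merge_kwargs_to_args(args, kwargs)

    runnable = lookup_kernel(def_name, merged_args)
    if runnable is not None:
        return runnable(*merged_args)
    # The fallback still sees the caller's original calling convention.
    return fallback(*args, **kwargs)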

Comment on lines +447 to +458
if not kwargs:
    return args

param_names = list(self.inputs.keys()) + list(self.outputs.keys())
result = list(args)
for i in range(len(args), len(param_names)):
    name = param_names[i]
    if name in kwargs:
        result.append(kwargs[name])
    else:
        break
return tuple(result)
Severity: high

The merge_kwargs_to_args method does not check for duplicate arguments that are provided both positionally and as a keyword. This can lead to surprising behavior where a keyword argument is silently ignored if it corresponds to a parameter already provided positionally. A robust implementation should raise a TypeError in this case, similar to how Python's function calls behave.

        if not kwargs:
            return args

        param_names = list(self.inputs.keys()) + list(self.outputs.keys())

        if len(args) > len(param_names):
            raise TypeError(f"Too many positional arguments: got {len(args)}, expected at most {len(param_names)}")

        # Check for duplicate arguments
        positional_arg_names = set(param_names[:len(args)])
        for name in kwargs:
            if name in positional_arg_names:
                raise TypeError(f"Got multiple values for argument '{name}'")

        result = list(args)
        for i in range(len(args), len(param_names)):
            name = param_names[i]
            if name in kwargs:
                result.append(kwargs[name])
            else:
                break
        return tuple(result)
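For comparison, plain Python function calls raise in exactly this situation, which is the behavior the suggestion mirrors:

def f(a, b):
    return a + b

f(1, a=2)  # TypeError: f() got multiple values for argument 'a'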

Comment on lines +58 to 59
The kernel name, or a resolver ``fn(*args) -> str`` that maps runtime
arguments to a kernel name (definition name).
Severity: medium

The docstring for def_name_or_resolver states its signature is fn(*args) -> str, but the implementation in _dispatch_apply_or_tracing calls it with *args, **kwargs. The docstring should be updated to match the more flexible implementation.

Suggested change
-    The kernel name, or a resolver ``fn(*args) -> str`` that maps runtime
-    arguments to a kernel name (definition name).
+    The kernel name, or a resolver ``fn(*args, **kwargs) -> str`` that maps runtime
+    arguments to a kernel name (definition name).
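For illustration, a resolver matching the corrected signature might look like this (hypothetical; the definition name format is made up):

def resolve_def_name(*args, **kwargs) -> str:
    # Map runtime arguments to a definition name, e.g. by the hidden
    # dimension of the first tensor argument.
    hidden = args[0].shape[-1]
    return f"rmsnorm_d{hidden}"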

@coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
flashinfer_bench/apply/apply_api.py (1)

208-208: Update docstring examples to use new API.

The docstring examples in enable_apply still reference the old runtime_kwargs parameter, which no longer exists in the apply() function signature. Update these to use the new args/kwargs parameters.

     >>> # Direct usage
     >>> enable_apply("/path/to/traceset", cfg)
     >>> # Apply is now enabled
-    >>> out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
+    >>> out = apply("rmsnorm_d4096", args=(...,), fallback=ref_fn)
     >>> disable_apply()
     >>> # Apply is now disabled.

     >>> # Context manager usage
     >>> with enable_apply("/path/to/traceset", cfg):
-    ...     out = apply("rmsnorm_d4096", runtime_kwargs={...}, fallback=ref_fn)
+    ...     out = apply("rmsnorm_d4096", args=(...,), fallback=ref_fn)
     >>> # Apply is now disabled.

Also applies to: 214-214

flashinfer_bench/tracing/runtime.py (1)

196-224: Stale docstring: parameters no longer match signature.

The docstring still references definition, axes, and name parameters that were removed. Update to reflect the new val and dtype parameters.

     def _convert_arg_to_tensor(
         self, val: Union[int, float, bool, list, tuple, torch.Tensor], dtype: str
     ) -> Optional[torch.Tensor]:
         """Convert a runtime argument to a tensor for further dumping. If the conversion fails,
         log an error and return None.

         Parameters
         ----------
-        definition : Definition
-            The workload definition containing axis specifications.
-        axes : Dict[str, int]
-            Runtime axis values provided during tracing.
-        name : str
-            Name of the argument to convert.
-        val : Any
-            The runtime argument to convert.
+        val : Union[int, float, bool, list, tuple, torch.Tensor]
+            The runtime argument to convert (scalar, list, or tensor).
+        dtype : str
+            Target dtype string for scalar/list conversion.

         Returns
         -------
-        Optional[torch.Tensor]
-            The converted tensor. None if conversion fails.
+        torch.Tensor
+            The converted tensor.
+
+        Raises
+        ------
+        ValueError
+            If the value type is unsupported.
         """
🧹 Nitpick comments (15)
flashinfer_bench/bench/timing.py (2)

163-163: Consider validating rep > 0.

If rep=0 is passed, times will be empty, causing failures in _summarize_statistics (e.g., min() on empty sequence). Adding a validation ensures clearer error messages.

     assert return_mode in ["min", "max", "mean", "median", "all"]
+    assert rep > 0, "rep must be at least 1"

202-202: Add strict=True to zip for safety.

Both start_events and end_events are guaranteed to have the same length here, but adding strict=True is a good defensive practice and satisfies the static analysis hint.

-    times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
+    times = [s.elapsed_time(e) for s, e in zip(start_events, end_events, strict=True)]
flashinfer_bench/bench/benchmark.py (1)

119-155: LGTM - Consistent workload naming throughout benchmark execution.

The variable rename from wl to workload improves code readability and aligns with the project-wide standardization. All references to workload.uuid are consistent, and the functional logic remains unchanged.

Optional: Consider using logging.exception for better error context.

On line 136, you could use logger.exception instead of logger.error to automatically include the exception traceback:

-                    logger.error(f"Failed to run workload {workload.uuid}: {e}")
+                    logger.exception(f"Failed to run workload {workload.uuid}")

This would provide more debugging context when workload execution fails.

flashinfer_bench/bench/evaluators/utils.py (1)

38-41: Consider adding strict=True to zip for defensive programming.

The zip() call pairs output_shapes with dtypes. Adding strict=True ensures both lists have the same length and catches potential inconsistencies early.

Apply this diff:

     dtypes = definition.torch_output_dtypes
     return [
         torch.empty(shape, dtype=dtype, device=device)
-        for shape, dtype in zip(output_shapes, dtypes)
+        for shape, dtype in zip(output_shapes, dtypes, strict=True)
     ]
flashinfer_bench/bench/evaluators/lowbit.py (1)

66-66: Add strict=True to zip to catch length mismatches.

The zip(out, ref_out) without strict=True could silently ignore cases where the number of output tensors differs between the solution and reference. This could mask bugs where a solution produces a different number of outputs.

-            for sol_tensor, ref_tensor in zip(out, ref_out):
+            for sol_tensor, ref_tensor in zip(out, ref_out, strict=True):
flashinfer_bench/tracing/config.py (2)

141-141: Add strict=True to zip() to catch length mismatches.

If names and values have different lengths, zip() will silently truncate. Since this indicates a programming error, it should fail explicitly.

-        name_to_value = dict(zip(names, values))
+        name_to_value = dict(zip(names, values, strict=True))

148-151: Consider using TypeError for type validation errors.

When the input_dump_policy is neither a list nor callable, or when the callable returns a non-list, these are type errors rather than value errors. Static analysis suggests TypeError is more appropriate here.

         else:
-            raise ValueError("input_dump_policy must be a list of strings or a callable")
+            raise TypeError("input_dump_policy must be a list of strings or a callable")
 
         if not isinstance(names_to_dump, list):
-            raise ValueError("input_dump_policy callable must return a list of strings")
+            raise TypeError("input_dump_policy callable must return a list of strings")
flashinfer_bench/tracing/runtime.py (1)

159-161: Use logger.exception to preserve stack trace.

When catching exceptions, logger.exception automatically includes the stack trace, which aids debugging. The same applies to line 173.

         except ValueError as e:
-            logger.error(f"Error getting axis values for {def_name}: {e}")
+            logger.exception(f"Error getting axis values for {def_name}")
             return

Similarly for line 173:

         except ValueError as e:
-            logger.error(f"Error converting argument '{name}' to tensor for {def_name}: {e}")
+            logger.exception(f"Error converting argument '{name}' to tensor for {def_name}")
             return
flashinfer_bench/compile/runnable.py (2)

224-226: Add strict=True to zip() for shape/dtype pairing.

If output_shapes and dtype_list have different lengths (indicating a definition inconsistency), this should fail explicitly rather than silently truncating.

-        for shape, dtype in zip(output_shapes, dtype_list):
+        for shape, dtype in zip(output_shapes, dtype_list, strict=True):
             shape = shape if shape is not None else ()
             output_tensors.append(torch.empty(shape, dtype=dtype, device=device))

127-129: Minor: Unnecessary string split in error message.

The string concatenation "...must " "be set." can be a single string.

             raise ValueError(
-                "When calling in keyword passing style, metadata.full_definition must " "be set."
+                "When calling in keyword passing style, metadata.definition must be set."
             )

Note: The error message also references full_definition but the attribute is definition.

flashinfer_bench/bench/evaluators/default.py (1)

122-122: Add strict=True to zip() for output comparison.

If out and ref_out have different lengths, this indicates a mismatch in the number of outputs between the solution and reference, which should be caught explicitly rather than silently skipping outputs.

-            for sol_tensor, ref_tensor in zip(out, ref_out):
+            for sol_tensor, ref_tensor in zip(out, ref_out, strict=True):
flashinfer_bench/bench/evaluators/sampling.py (1)

328-353: Consider adding strict=True to zip() for safety.

The zip(input_names, inputs) at line 330 assumes both iterables have the same length. While this should always be true by construction (both derived from the same definition), adding strict=True would catch any inconsistency early.

-    for name, value in zip(input_names, inputs):
+    for name, value in zip(input_names, inputs, strict=True):
flashinfer_bench/data/definition.py (3)

279-296: Add strict=True to outer zip for safety.

At line 279, the zip between self.inputs.items() and input_shapes should have matching lengths. Adding strict=True would catch mismatches early rather than silently truncating.

The inner zip at line 287 is protected by the dimension check at lines 282-286, so it's less critical.

-        for (inp_name, inp_spec), inp_shape in zip(self.inputs.items(), input_shapes):
+        for (inp_name, inp_spec), inp_shape in zip(self.inputs.items(), input_shapes, strict=True):

305-321: Minor: Docstring references old method name.

Line 308 mentions "get_var_axes_values" but the actual method is get_axes_values. This appears to be a documentation artifact from renaming.

-        Convenience method that combines extract_shapes and get_var_axes_values.
+        Convenience method that combines shape extraction and get_axes_values.

180-186: Redundant None check on constraints.

Since constraints is defined with default_factory=list (line 132), it will never be None - it will be an empty list at minimum. The if self.constraints is not None: check is redundant.

-        if self.constraints is not None:
-            for constraint in self.constraints:
-                try:
-                    ast.parse(constraint, mode="eval")
-                except SyntaxError as e:
-                    raise ValueError(f"Constraints must be valid Python expressions: {e}") from e
+        for constraint in self.constraints:
+            try:
+                ast.parse(constraint, mode="eval")
+            except SyntaxError as e:
+                raise ValueError(f"Constraints must be valid Python expressions: {e}") from e
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 25666b8 and 7609788.

📒 Files selected for processing (43)
  • examples/win_at_p.py (1 hunks)
  • flashinfer_bench/apply/apply_api.py (4 hunks)
  • flashinfer_bench/apply/key.py (2 hunks)
  • flashinfer_bench/apply/runtime.py (4 hunks)
  • flashinfer_bench/apply/table.py (3 hunks)
  • flashinfer_bench/bench/benchmark.py (1 hunks)
  • flashinfer_bench/bench/evaluators/default.py (5 hunks)
  • flashinfer_bench/bench/evaluators/evaluator.py (5 hunks)
  • flashinfer_bench/bench/evaluators/lowbit.py (3 hunks)
  • flashinfer_bench/bench/evaluators/sampling.py (6 hunks)
  • flashinfer_bench/bench/evaluators/utils.py (1 hunks)
  • flashinfer_bench/bench/runner/isolated_runner.py (4 hunks)
  • flashinfer_bench/bench/runner/persistent_runner.py (5 hunks)
  • flashinfer_bench/bench/runner/runner.py (2 hunks)
  • flashinfer_bench/bench/timing.py (3 hunks)
  • flashinfer_bench/bench/utils.py (2 hunks)
  • flashinfer_bench/compile/builder.py (3 hunks)
  • flashinfer_bench/compile/builders/python_builder.py (1 hunks)
  • flashinfer_bench/compile/builders/torch_builder.py (1 hunks)
  • flashinfer_bench/compile/builders/tvm_ffi_builder.py (2 hunks)
  • flashinfer_bench/compile/runnable.py (4 hunks)
  • flashinfer_bench/data/__init__.py (2 hunks)
  • flashinfer_bench/data/definition.py (10 hunks)
  • flashinfer_bench/data/solution.py (1 hunks)
  • flashinfer_bench/data/trace.py (1 hunks)
  • flashinfer_bench/data/workload.py (1 hunks)
  • flashinfer_bench/tracing/__init__.py (1 hunks)
  • flashinfer_bench/tracing/builtin/policies.py (1 hunks)
  • flashinfer_bench/tracing/config.py (2 hunks)
  • flashinfer_bench/tracing/filter_policy.py (0 hunks)
  • flashinfer_bench/tracing/runtime.py (6 hunks)
  • flashinfer_bench/utils.py (4 hunks)
  • tests/apply/test_runtime.py (6 hunks)
  • tests/bench/test_evaluator.py (7 hunks)
  • tests/bench/test_isolated_runner.py (7 hunks)
  • tests/bench/test_persistent_runner.py (14 hunks)
  • tests/compile/test_builder.py (2 hunks)
  • tests/compile/test_runnable.py (2 hunks)
  • tests/data/test_definition.py (1 hunks)
  • tests/data/test_load_dump_schema.py (3 hunks)
  • tests/data/test_trace.py (2 hunks)
  • tests/tracing/test_tracing_config.py (4 hunks)
  • tests/tracing/test_tracing_runtime.py (2 hunks)
💤 Files with no reviewable changes (1)
  • flashinfer_bench/tracing/filter_policy.py
🧰 Additional context used
🧬 Code graph analysis (31)
tests/data/test_trace.py (2)
flashinfer_bench/data/workload.py (1)
  • Workload (52-66)
flashinfer_bench/data/trace.py (1)
  • Trace (153-193)
flashinfer_bench/bench/evaluators/utils.py (2)
flashinfer_bench/data/definition.py (3)
  • get_axes_values_from_inputs (305-321)
  • get_output_shapes (386-406)
  • torch_output_dtypes (420-428)
flashinfer_bench/bench/utils.py (1)
  • to_tensor (55-60)
flashinfer_bench/bench/runner/isolated_runner.py (3)
flashinfer_bench/data/workload.py (1)
  • Workload (52-66)
web/apps/web/lib/schemas/trace.ts (1)
  • Workload (142-142)
flashinfer_bench/bench/runner/persistent_runner.py (1)
  • run_ref (246-262)
flashinfer_bench/data/trace.py (3)
flashinfer_bench/data/utils.py (1)
  • BaseModelWithDocstrings (12-15)
flashinfer_bench/data/workload.py (1)
  • Workload (52-66)
web/apps/web/lib/schemas/trace.ts (1)
  • Workload (142-142)
tests/data/test_load_dump_schema.py (3)
flashinfer_bench/data/workload.py (1)
  • Workload (52-66)
web/apps/web/lib/schemas/trace.ts (2)
  • Workload (142-142)
  • Trace (144-144)
flashinfer_bench/data/trace.py (1)
  • Trace (153-193)
flashinfer_bench/bench/benchmark.py (3)
flashinfer_bench/bench/runner/runner.py (1)
  • run_workload (39-46)
flashinfer_bench/data/trace.py (2)
  • Trace (153-193)
  • EvaluationStatus (67-87)
web/apps/web/lib/schemas/trace.ts (1)
  • Trace (144-144)
flashinfer_bench/apply/apply_api.py (3)
tests/apply/test_runtime.py (1)
  • fallback (507-509)
flashinfer_bench/apply/runtime.py (2)
  • get_apply_runtime (39-50)
  • dispatch (108-209)
flashinfer_bench/tracing/runtime.py (2)
  • get_tracing_runtime (405-413)
  • collect (92-194)
flashinfer_bench/bench/utils.py (3)
flashinfer_bench/data/definition.py (1)
  • get_input_shapes (364-384)
flashinfer_bench/data/workload.py (1)
  • Workload (52-66)
flashinfer_bench/utils.py (1)
  • dtype_str_to_torch_dtype (49-55)
flashinfer_bench/tracing/config.py (2)
flashinfer_bench/tracing/workload_entry.py (1)
  • WorkloadEntry (6-23)
flashinfer_bench/tracing/builtin/policies.py (15)
  • submit (65-67)
  • submit (89-92)
  • submit (124-130)
  • submit (192-202)
  • submit (222-224)
  • drain (69-73)
  • drain (94-98)
  • drain (132-136)
  • drain (204-208)
  • drain (226-228)
  • reset (75-77)
  • reset (100-102)
  • reset (138-141)
  • reset (210-213)
  • reset (230-232)
flashinfer_bench/apply/table.py (2)
flashinfer_bench/env.py (1)
  • get_fib_cache_path (46-57)
flashinfer_bench/apply/key.py (1)
  • ApplyKey (11-22)
flashinfer_bench/compile/builders/torch_builder.py (3)
flashinfer_bench/compile/builders/python_builder.py (2)
  • cleaner (63-77)
  • _get_cleaner (44-79)
flashinfer_bench/compile/builder.py (1)
  • _try_validate_signature (129-194)
flashinfer_bench/compile/runnable.py (1)
  • Runnable (40-279)
flashinfer_bench/apply/key.py (2)
flashinfer_bench/data/definition.py (1)
  • get_axes_values_from_inputs (305-321)
flashinfer_bench/data/workload.py (1)
  • Workload (52-66)
flashinfer_bench/bench/evaluators/lowbit.py (3)
flashinfer_bench/compile/runnable.py (1)
  • Runnable (40-279)
flashinfer_bench/data/definition.py (1)
  • Definition (101-458)
flashinfer_bench/bench/evaluators/utils.py (2)
  • allocate_outputs (14-41)
  • normalize_result (44-90)
flashinfer_bench/apply/runtime.py (3)
tests/apply/test_runtime.py (1)
  • fallback (507-509)
flashinfer_bench/data/definition.py (1)
  • merge_kwargs_to_args (430-458)
flashinfer_bench/apply/key.py (4)
  • ApplyKeyFactory (74-89)
  • specialize (87-89)
  • build_from_args (30-32)
  • build_from_args (49-51)
tests/data/test_definition.py (1)
flashinfer_bench/data/definition.py (6)
  • AxisVar (35-46)
  • TensorSpec (81-94)
  • merge_kwargs_to_args (430-458)
  • AxisConst (19-32)
  • get_axes_values_from_inputs (305-321)
  • get_axes_values (259-303)
flashinfer_bench/data/workload.py (1)
flashinfer_bench/data/utils.py (1)
  • BaseModelWithDocstrings (12-15)
tests/bench/test_evaluator.py (5)
tests/bench/test_persistent_runner.py (1)
  • _simple_def (42-50)
flashinfer_bench/bench/evaluators/default.py (1)
  • DefaultEvaluator (32-218)
flashinfer_bench/bench/evaluators/evaluator.py (1)
  • evaluate (66-108)
flashinfer_bench/bench/evaluators/sampling.py (1)
  • SamplingEvaluator (42-232)
flashinfer_bench/bench/evaluators/lowbit.py (1)
  • LowBitEvaluator (19-119)
tests/compile/test_builder.py (1)
flashinfer_bench/compile/runnable.py (1)
  • RunnableMetadata (17-37)
flashinfer_bench/bench/runner/runner.py (1)
flashinfer_bench/data/workload.py (1)
  • Workload (52-66)
flashinfer_bench/compile/runnable.py (2)
flashinfer_bench/data/definition.py (3)
  • Definition (101-458)
  • get_axes_values_from_inputs (305-321)
  • get_output_shapes (386-406)
flashinfer_bench/utils.py (1)
  • dtype_str_to_torch_dtype (49-55)
tests/compile/test_runnable.py (2)
flashinfer_bench/compile/runnable.py (6)
  • Runnable (40-279)
  • RunnableMetadata (17-37)
  • call_destination_passing (156-202)
  • call_value_returning (230-266)
  • call_kwargs (136-144)
  • _allocate_output_tensors (204-228)
flashinfer_bench/data/definition.py (3)
  • AxisConst (19-32)
  • AxisVar (35-46)
  • Definition (101-458)
tests/bench/test_isolated_runner.py (5)
flashinfer_bench/data/workload.py (3)
  • Workload (52-66)
  • RandomInput (8-16)
  • ScalarInput (19-29)
web/apps/web/lib/schemas/trace.ts (1)
  • Workload (142-142)
flashinfer_bench/bench/utils.py (2)
  • gen_inputs (206-239)
  • load_safetensors (168-203)
flashinfer_bench/bench/runner/isolated_runner.py (1)
  • run_ref (44-60)
flashinfer_bench/bench/runner/persistent_runner.py (1)
  • run_ref (246-262)
flashinfer_bench/bench/evaluators/evaluator.py (1)
flashinfer_bench/bench/runner/runner.py (1)
  • DeviceBaseline (26-32)
flashinfer_bench/compile/builder.py (2)
flashinfer_bench/data/definition.py (1)
  • Definition (101-458)
flashinfer_bench/data/solution.py (1)
  • Solution (101-214)
flashinfer_bench/compile/builders/python_builder.py (1)
flashinfer_bench/compile/builder.py (1)
  • _try_validate_signature (129-194)
flashinfer_bench/bench/evaluators/sampling.py (4)
flashinfer_bench/data/trace.py (3)
  • Correctness (13-36)
  • Evaluation (90-150)
  • EvaluationStatus (67-87)
flashinfer_bench/data/definition.py (1)
  • Definition (101-458)
flashinfer_bench/bench/evaluators/default.py (1)
  • DefaultEvaluator (32-218)
flashinfer_bench/bench/evaluators/utils.py (2)
  • allocate_outputs (14-41)
  • normalize_result (44-90)
flashinfer_bench/utils.py (1)
flashinfer_bench/data/trace.py (1)
  • Environment (54-64)
flashinfer_bench/bench/timing.py (1)
flashinfer_bench/compile/runnable.py (1)
  • Runnable (40-279)
tests/tracing/test_tracing_runtime.py (5)
flashinfer_bench/tracing/runtime.py (2)
  • collect (92-194)
  • TracingRuntime (29-387)
flashinfer_bench/data/trace_set.py (1)
  • TraceSet (22-476)
flashinfer_bench/tracing/config.py (1)
  • TracingConfig (67-158)
flashinfer_bench/tracing/builtin/policies.py (1)
  • KeepFirstKPolicy (80-102)
flashinfer_bench/data/definition.py (3)
  • Definition (101-458)
  • AxisVar (35-46)
  • TensorSpec (81-94)
flashinfer_bench/tracing/__init__.py (1)
flashinfer_bench/tracing/config.py (2)
  • FilterPolicy (24-59)
  • TracingConfig (67-158)
flashinfer_bench/data/definition.py (4)
flashinfer_bench/data/utils.py (1)
  • BaseModelWithDocstrings (12-15)
flashinfer_bench/utils.py (1)
  • dtype_str_to_torch_dtype (49-55)
web/apps/web/lib/schemas/trace.ts (1)
  • Definition (139-139)
web/apps/web/scripts/prebuild.mjs (2)
  • result (27-31)
  • i (83-83)
🪛 Ruff (0.14.7)
flashinfer_bench/bench/evaluators/utils.py

40-40: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


81-83: Avoid specifying long messages outside the exception class

(TRY003)


88-88: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/bench/benchmark.py

136-136: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

flashinfer_bench/apply/apply_api.py

179-179: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/bench/utils.py

174-174: Avoid specifying long messages outside the exception class

(TRY003)


226-226: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/tracing/config.py

141-141: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


148-148: Prefer TypeError exception for invalid type

(TRY004)


148-148: Avoid specifying long messages outside the exception class

(TRY003)


151-151: Prefer TypeError exception for invalid type

(TRY004)


151-151: Avoid specifying long messages outside the exception class

(TRY003)


156-156: Avoid specifying long messages outside the exception class

(TRY003)

tests/apply/test_runtime.py

178-178: Unused lambda argument: args

(ARG005)


178-178: Unused lambda argument: kwargs

(ARG005)


305-305: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


323-323: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


340-340: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


357-357: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


458-458: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


479-479: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


498-498: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


507-507: Unused function argument: args

(ARG001)


522-522: Unpacked variable d is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

flashinfer_bench/compile/builders/tvm_ffi_builder.py

292-292: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/apply/key.py

57-57: Unused method argument: args

(ARG002)

flashinfer_bench/bench/evaluators/lowbit.py

58-58: Do not catch blind exception: Exception

(BLE001)


66-66: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

flashinfer_bench/apply/runtime.py

154-154: Avoid specifying long messages outside the exception class

(TRY003)


167-171: Avoid specifying long messages outside the exception class

(TRY003)


201-201: Avoid specifying long messages outside the exception class

(TRY003)

tests/tracing/test_tracing_config.py

370-370: Unused function argument: inputs

(ARG001)


381-381: Unused function argument: inputs

(ARG001)


399-399: Unused function argument: inputs

(ARG001)

tests/data/test_definition.py

204-204: Unused method argument: definition

(ARG002)

tests/bench/test_evaluator.py

149-149: Unused function argument: args

(ARG001)


149-149: Unused function argument: kwargs

(ARG001)


150-150: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/tracing/runtime.py

160-160: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


173-173: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


224-224: Prefer TypeError exception for invalid type

(TRY004)


224-224: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/bench/evaluators/default.py

114-114: Do not catch blind exception: Exception

(BLE001)


122-122: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

flashinfer_bench/compile/runnable.py

127-129: Avoid specifying long messages outside the exception class

(TRY003)


150-152: Avoid specifying long messages outside the exception class

(TRY003)


175-178: Avoid specifying long messages outside the exception class

(TRY003)


191-194: Avoid specifying long messages outside the exception class

(TRY003)


196-196: Loop control variable result overrides iterable it iterates

(B020)


196-196: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


198-201: Prefer TypeError exception for invalid type

(TRY004)


198-201: Avoid specifying long messages outside the exception class

(TRY003)


209-212: Avoid specifying long messages outside the exception class

(TRY003)


224-224: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

tests/compile/test_runnable.py

28-28: Unused function argument: args

(ARG001)


143-143: Unused function argument: A

(ARG001)


304-304: Unused lambda argument: args

(ARG005)


335-335: Unused lambda argument: args

(ARG005)

flashinfer_bench/compile/builder.py

158-162: Avoid specifying long messages outside the exception class

(TRY003)


165-168: Avoid specifying long messages outside the exception class

(TRY003)


178-181: Avoid specifying long messages outside the exception class

(TRY003)


185-188: Avoid specifying long messages outside the exception class

(TRY003)


191-194: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/bench/evaluators/sampling.py

30-30: Avoid specifying long messages outside the exception class

(TRY003)


140-140: Do not catch blind exception: Exception

(BLE001)


330-330: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

flashinfer_bench/utils.py

51-51: Avoid specifying long messages outside the exception class

(TRY003)

flashinfer_bench/bench/timing.py

79-79: Avoid specifying long messages outside the exception class

(TRY003)


110-110: Avoid specifying long messages outside the exception class

(TRY003)


202-202: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

tests/tracing/test_tracing_runtime.py

100-100: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


142-142: Unused method argument: minimal_traceset

(ARG002)

flashinfer_bench/data/definition.py

168-168: Avoid specifying long messages outside the exception class

(TRY003)


279-279: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


283-286: Avoid specifying long messages outside the exception class

(TRY003)


287-287: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


291-294: Avoid specifying long messages outside the exception class

(TRY003)


299-302: Avoid specifying long messages outside the exception class

(TRY003)


358-358: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.13
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.12
  • GitHub Check: Run unit tests on ubuntu-latest and Python 3.11
🔇 Additional comments (79)
flashinfer_bench/bench/timing.py (4)

42-69: Verify multiprocessing lock semantics match intended use case.

The multiprocessing.Lock objects stored in _device_locks will only provide cross-process synchronization if processes are forked (inheriting the locks) rather than spawned. On spawn-based platforms (Windows, macOS with Python 3.8+), each process gets independent copies, defeating serialization.

If cross-process serialization is required with spawn, consider using multiprocessing.Manager().Lock() instead. If this is only used within a single process (multithreading) or with fork, the current approach is fine.
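For illustration, the difference comes down to where the lock object lives (a sketch; the _device_locks registry pattern is assumed from timing.py):

import multiprocessing as mp

_device_locks = {}

def get_device_lock(device: str):
    # Under fork, children inherit this dict and its locks, so the lock is
    # genuinely shared. Under spawn, each child re-imports the module and
    # builds its own empty dict, leaving every process with a private lock.
    if device not in _device_locks:
        _device_locks[device] = mp.Lock()
    return _device_locks[device]

# A Manager-backed variant keeps one real lock in a server process and
# hands out picklable proxies, which also works under spawn:
#     manager = mp.Manager()
#     _device_locks[device] = manager.Lock()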


89-110: LGTM!

The statistical summarization logic is correct and covers the documented return modes appropriately.
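For reference, return-mode handling over a list of timings typically reduces to something like this (a sketch based on the modes asserted at line 163; not the actual implementation):

import statistics

def summarize(times, return_mode):
    # "all" passes the raw samples through; the rest reduce to one number.
    if return_mode == "all":
        return times
    return {
        "min": min,
        "max": max,
        "mean": statistics.fmean,
        "median": statistics.median,
    }[return_mode](times)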


206-212: LGTM!

The cloning logic correctly handles tensor arguments while passing through non-tensor values. This ensures each benchmark iteration operates on independent data.


215-250: LGTM!

The time_runnable function correctly:

  • Serializes access per-device via locking
  • Uses the setup mechanism to exclude argument cloning from timing
  • Properly unpacks the cloned arguments when invoking the Runnable

This aligns well with the Runnable contract described in the relevant code snippets.

examples/win_at_p.py (1)

112-113: LGTM - Variable naming improved for clarity.

The rename from wl to workload enhances readability and aligns with the broader codebase standardization introduced in this PR.

flashinfer_bench/data/workload.py (4)

8-17: LGTM - RandomInput model is well-defined.

The model correctly uses Pydantic validation with a literal type discriminator, enabling proper union type discrimination for InputSpec.


19-30: LGTM - ScalarInput model properly validates scalar values.

The Union[int, float, bool] type for the value field appropriately covers the scalar input types mentioned in the docstring.


32-46: LGTM - SafetensorsInput model is well-structured.

The use of NonEmptyString for path and tensor_key ensures validation at the model level, preventing empty values from being accepted.


52-66: LGTM - Workload model provides comprehensive configuration.

The model appropriately uses Dict types for axes and inputs, with proper validation constraints (NonNegativeInt, NonEmptyString). The uuid field enables workload tracking across the benchmark system.

tests/data/test_trace.py (1)

99-116: LGTM - Test updated with improved variable naming.

The rename from wl to workload in the test maintains consistency with the broader refactoring while preserving all test logic and assertions.

flashinfer_bench/data/trace.py (1)

5-10: LGTM - Import reorganization supports modularization.

The changes correctly import Workload from the new workload module, improving code organization by separating workload specifications from trace definitions.

tests/data/test_load_dump_schema.py (2)

50-59: LGTM - Consistent variable naming in test fixtures.

The variable rename enhances test readability and aligns with the project-wide standardization.


74-75: LGTM - Test updated with improved naming.

The workload variable naming is consistent with the rest of the test suite.

flashinfer_bench/tracing/__init__.py (1)

14-14: LGTM - Import source updated to reflect new module organization.

The InputDumpPolicyFunction is now imported from the config module, which appears to centralize policy-related types. The public API via __all__ remains unchanged, maintaining backward compatibility.

tests/data/test_definition.py (4)

121-132: LGTM - Important validation test added.

This test ensures that input and output parameter names remain distinct, preventing ambiguity in parameter resolution. This is a critical constraint for the DPS/VR calling convention support.


134-176: LGTM - Comprehensive test coverage for argument merging.

The test class thoroughly validates the merge_kwargs_to_args behavior across multiple scenarios:

  • Empty kwargs
  • Kwargs-only
  • Mixed args and kwargs
  • Output parameters as kwargs
  • Partial kwargs with missing parameters

This coverage is essential for ensuring correct parameter handling in the new calling conventions.


178-224: LGTM - Axis extraction tests cover key scenarios.

The tests validate axis value extraction from both tensor inputs (with .shape attribute) and scalar inputs (without .shape). The test correctly handles the case where scalar inputs are skipped during axis extraction.

Note: The static analysis warning about the unused definition parameter in line 204 is a false positive - the definition parameter is the pytest fixture, and scalar_def is intentionally created within the test to demonstrate scalar input handling.


226-254: LGTM - Shape-based axis extraction properly tested.

The tests validate both successful axis extraction from shape tuples and proper error handling when axis values are inconsistent across inputs. This is crucial for catching shape mismatches early.

flashinfer_bench/bench/runner/isolated_runner.py (1)

414-421: LGTM! Consistent parameter naming.

The rename from wl to workload improves readability and aligns with the broader API standardization across the codebase.

flashinfer_bench/bench/runner/runner.py (2)

42-42: LGTM! Consistent parameter naming.

The rename from wl to workload improves code clarity.


30-31: Verify all consumers of the API change.

The inputs and outputs fields have changed from dictionary-based to list-based structures. This is a breaking change that affects how downstream code accesses these values. Ensure all consumers of DeviceBaseline (including evaluators, serialization logic, and test utilities) have been updated to use positional indexing instead of named access.

tests/compile/test_runnable.py (2)

10-22: LGTM! Well-structured test helper.

The _make_definition() helper provides a clear, reusable definition for testing core Runnable functionality.


83-239: Excellent test coverage for DPS/VR functionality.

These test classes comprehensively cover destination-passing style (DPS) and value-returning (VR) scenarios, including:

  • Native DPS/VR callables
  • Conversions between styles
  • Edge cases (no outputs, multiple outputs)

The test structure is clear and well-organized.

flashinfer_bench/apply/table.py (2)

113-113: LGTM! Idiomatic Pydantic usage.

Migrating from custom from_encoded() to Pydantic's standard model_validate_json() is the correct approach for deserializing JSON to a Pydantic model.


140-140: LGTM! Idiomatic Pydantic usage.

Using model_dump_json() instead of a custom encode() method is the standard Pydantic approach for JSON serialization.

flashinfer_bench/data/solution.py (2)

79-79: LGTM! Best practice for mutable defaults.

Changing from default=[] to default_factory=list is the correct approach for mutable field defaults in Pydantic. This prevents unintended sharing of the default list instance across model instances.


82-82: LGTM! New field for DPS feature.

The destination_passing_style field addition supports the PR's core feature of enabling destination-passing style for solutions.

tests/bench/test_persistent_runner.py (1)

89-89: LGTM! Consistent variable naming.

The rename from wl to workload throughout the test suite improves readability and aligns with the API changes in the runner modules.

Also applies to: 114-114, 162-162, 225-225, 274-274, 323-325, 364-364, 418-420

tests/bench/test_isolated_runner.py (3)

109-118: LGTM! Clear documentation of list-based API.

The comment explaining that gen_inputs returns a list in definition order is helpful for understanding the new API. The positional access pattern (e.g., out[0] for X, out[1] for Y) is correctly aligned with the list-based input structure introduced in this PR.


131-143: LGTM! Consistent with API changes.

The updates to use safe_tensors (instead of stensors) and list-based indexing are consistent with the broader API refactoring.


193-193: LGTM! Consistent variable naming.

The rename from wl to workload improves readability.

Also applies to: 222-222

flashinfer_bench/tracing/builtin/policies.py (1)

18-20: LGTM! Proper use of TYPE_CHECKING.

Wrapping type-only imports in a TYPE_CHECKING guard is the standard pattern for avoiding runtime import overhead and potential circular dependencies while maintaining type checking capabilities.

flashinfer_bench/bench/runner/persistent_runner.py (1)

508-583: LGTM! Consistent parameter rename from wl to workload.

The renaming is applied consistently across the method signature, docstring, and all internal usages including run_ref calls and error logging. This aligns with the broader codebase terminology updates in this PR.

flashinfer_bench/compile/builders/python_builder.py (1)

141-152: LGTM! Well-structured metadata updates and signature validation.

The RunnableMetadata construction correctly uses the new field names (definition_name, solution_name) and includes destination_passing_style from the solution spec. The signature validation call before returning the Runnable ensures early detection of interface mismatches. The inclusion of the full definition object enables richer runtime context for downstream components.

tests/compile/test_builder.py (2)

38-44: LGTM! Test updated to match new RunnableMetadata API.

The field names are correctly updated from definition/solution to definition_name/solution_name, aligning with the updated RunnableMetadata model.


75-79: LGTM! Mock builder updated for consistency.

The mock RunnableMetadata construction correctly uses the new field names with empty strings, which is appropriate for testing dispatch logic where the metadata content is not the focus.

flashinfer_bench/bench/utils.py (3)

168-203: LGTM! Consistent rename to workload and safe_tensors.

The parameter and variable renames are applied consistently throughout load_safetensors. The logic remains unchanged.


206-239: Return type changed from Dict to List to support positional arguments.

The change to return a List[Any] in definition order aligns with the DPS calling convention where inputs are passed positionally. The docstring clearly documents this behavior.

However, the same potential issue exists at line 232 where shapes[name] is accessed - verify that get_input_shapes returns a structure that supports string key access.


176-192: Potential type mismatch: get_input_shapes return type requires verification.

The code at lines 176-192 accesses expected[name] using dictionary-style key access, where expected is assigned from definition.get_input_shapes(workload.axes). The review claims this method returns List[Optional[Tuple[int, ...]]] based on definition.py lines 363-383, which would cause a TypeError at runtime since list indices must be integers, not string keys.

This claim requires verification of the actual return type annotation and implementation of get_input_shapes.
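The suspected failure mode, for reference (shapes are made up):

shapes = [(4, 8), None]  # list return, per the claimed annotation
shapes["X"]              # TypeError: list indices must be integers or slices, not str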

flashinfer_bench/compile/builder.py (3)

33-38: LGTM! Class attributes documented.

Adding explicit class attribute declarations with docstrings improves code clarity and IDE support.


129-154: LGTM! Graceful handling of unavailable signatures.

The method correctly catches ValueError and TypeError when inspect.signature fails (e.g., for built-in functions or C extensions), allowing builds to proceed without validation in such cases.


156-168: LGTM! DPS and VR parameter count validation.

The parameter count checks are correct:

  • DPS: expects len(inputs) + len(outputs) parameters
  • VR: expects len(inputs) parameters
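For illustration, the check reduces to something like this (a sketch approximating _try_validate_signature; the graceful skip mirrors the behavior noted above):

import inspect

def validate_param_count(fn, n_inputs, n_outputs, destination_passing_style):
    try:
        sig = inspect.signature(fn)
    except (ValueError, TypeError):
        return  # signature unavailable (e.g., builtin/C extension): skip
    expected = n_inputs + n_outputs if destination_passing_style else n_inputs
    if len(sig.parameters) != expected:
        raise ValueError(
            f"expected {expected} parameters, found {len(sig.parameters)}"
        )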
flashinfer_bench/bench/evaluators/utils.py (1)

44-90: LGTM! Well-structured result normalization.

The function correctly handles various return types and provides clear error messages when validation fails. The descriptive exception messages are appropriate for helping users diagnose issues.
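The essence of such a normalization step, as a sketch (assumed behavior, not the actual implementation):

def normalize_result(result, n_outputs):
    # Coerce a solution's return value into a list of output tensors.
    if result is None:
        raise ValueError("solution returned None; expected tensor output(s)")
    outs = list(result) if isinstance(result, (tuple, list)) else [result]
    if len(outs) != n_outputs:
        raise ValueError(f"expected {n_outputs} outputs, got {len(outs)}")
    return outs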

flashinfer_bench/data/__init__.py (1)

12-14: LGTM! Clean module reorganization.

The refactoring moves workload-related types to a dedicated module while maintaining the same public API. The updated grouping in __all__ improves clarity.

Also applies to: 28-34

flashinfer_bench/compile/builders/torch_builder.py (1)

158-171: LGTM! Enhanced metadata and signature validation.

The updated metadata structure with explicit field names (definition_name, solution_name, destination_passing_style) improves clarity. Adding signature validation via _try_validate_signature catches mismatches early, and using the callable directly simplifies the implementation.

tests/tracing/test_tracing_config.py (1)

323-406: LGTM! Comprehensive test coverage for the new API.

The TestGetInputsToDump class thoroughly exercises the updated get_inputs_to_dump(names, values) signature with tests covering static lists, callable policies, validation errors, and edge cases. The test structure is clear and well-organized.

tests/tracing/test_tracing_runtime.py (1)

113-221: LGTM! Thorough testing of args/kwargs handling and calling conventions.

The new test classes TestCollectArgsKwargs and TestCollectCallingConvention provide excellent coverage of the updated API. Tests validate kwargs merging, argument validation, and correct behavior for both value-returning and destination-passing styles.

flashinfer_bench/compile/builders/tvm_ffi_builder.py (1)

6-6: LGTM! Consistent updates aligned with torch_builder.

The metadata structure updates and signature validation addition mirror the changes in torch_builder.py, maintaining consistency across builders. The addition of the inspect import supports the new validation functionality.

Also applies to: 280-297

tests/bench/test_evaluator.py (2)

61-80: Excellent mock design for testing both calling conventions.

The _make_dps_mock and _make_vr_mock helper functions clearly separate the behavior of destination-passing and value-returning styles. The DPS mock's side effect correctly writes to the output argument, while the VR mock simply returns the result.
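Based on that description, the two mocks might look roughly like this (a sketch, not the test code itself):

from unittest.mock import MagicMock

import torch

def make_dps_mock(expected: torch.Tensor) -> MagicMock:
    # DPS: write the expected result into the pre-allocated output in place.
    def side_effect(inp: torch.Tensor, out: torch.Tensor) -> None:
        out.copy_(expected)

    return MagicMock(side_effect=side_effect)

def make_vr_mock(expected: torch.Tensor) -> MagicMock:
    # Value-returning: simply hand back the result tensor.
    return MagicMock(return_value=expected)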


87-488: LGTM! Comprehensive test coverage for all evaluator variants.

The test suite thoroughly exercises all evaluator types (Default, Sampling, LowBit) with both DPS and VR styles. Each test class validates:

  • Success cases with correct outputs
  • Shape mismatch detection
  • Numerical error detection
  • Runtime error handling
  • Evaluator-specific metrics (e.g., matched_ratio for LowBit)

The organization into separate test classes for each evaluator and style combination makes the test suite clear and maintainable.

tests/apply/test_runtime.py (2)

438-447: Nice mock design for DPS testing.

The FakeTensorWithFill class elegantly simulates tensor mutation for destination-passing style tests. The fill_ method allows verification that DPS functions correctly modify their output arguments in place.
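A sketch of such a fake (the actual test class may differ):

class FakeTensorWithFill:
    # Records in-place mutation so a test can assert that a DPS
    # callable wrote to its output argument.
    def __init__(self) -> None:
        self.value = None

    def fill_(self, value):
        # Mirrors torch.Tensor.fill_: mutate in place, return self.
        self.value = value
        return self

out = FakeTensorWithFill()

def dps_fn(inp, out):
    out.fill_(42)  # a DPS callable mutates its trailing output args

dps_fn(None, out)
assert out.value == 42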


296-533: LGTM! Comprehensive test coverage for dispatch with both calling conventions.

The test suite thoroughly validates:

  • Args/kwargs handling and merging
  • Value-returning style dispatch and results
  • Destination-passing style with output mutation
  • Fallback behavior for cache misses
  • Policy-based routing for both styles

The organization into separate test classes for different aspects (args/kwargs, calling conventions, DPS style) makes the test suite easy to navigate and maintain.

flashinfer_bench/apply/runtime.py (2)

157-171: Argument validation logic looks correct.

The DPS detection based on argument count comparison with num_inputs and num_inputs + num_outputs is well-designed. The early extraction of input_args for key building correctly separates inputs from outputs in DPS mode.
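A sketch of arity-based detection under those assumptions (helper name hypothetical):

from typing import Any, Tuple

def split_input_args(args: Tuple[Any, ...], num_inputs: int, num_outputs: int):
    # Returns (input_args, is_dps), classifying the call by argument count.
    if len(args) == num_inputs:
        return args, False  # value-returning: every arg is an input
    if len(args) == num_inputs + num_outputs:
        return args[:num_inputs], True  # DPS: trailing args are outputs
    raise TypeError(
        f"expected {num_inputs} or {num_inputs + num_outputs} positional "
        f"args, got {len(args)}"
    )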


204-209: LGTM!

The dual calling convention handling is correctly implemented. DPS uses full args including pre-allocated outputs and returns None, while value-returning uses only input_args and returns the result.

flashinfer_bench/apply/apply_api.py (2)

134-180: LGTM!

The _dispatch_apply_or_tracing helper cleanly centralizes the dispatch logic for both decorator and function modes. The priority order (apply → tracing → fallback) and the resolver invocation pattern are well-designed.


123-131: LGTM!

The decorator mode correctly uses @wraps(fallback) to preserve function metadata and cleanly delegates to the centralized dispatch helper.

flashinfer_bench/utils.py (2)

29-46: LGTM!

The lazy dtype mapping with @lru_cache(maxsize=1) is a good pattern for deferring torch import until first use while ensuring the mapping is built only once. This improves import-time performance for modules that don't immediately need torch.
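The pattern, sketched with a handful of dtypes (the real mapping is larger):

from functools import lru_cache
from typing import Dict

@lru_cache(maxsize=1)
def _torch_dtype_map() -> Dict[str, "torch.dtype"]:
    # torch is imported on first call only; the cached result means
    # the mapping is built exactly once per process.
    import torch

    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "int32": torch.int32,
    }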


67-79: LGTM!

Local torch imports in is_cuda_available and list_cuda_devices are consistent with the lazy import strategy established in this file.

flashinfer_bench/bench/evaluators/evaluator.py (2)

41-63: LGTM!

The signature updates for check_correctness and eval_performance to use List[List[Any]] for inputs and List[List[torch.Tensor]] for ref_outputs are consistent with the PR's refactoring to list-based data structures. Adding definition: Definition to eval_performance enables DPS-aware output allocation in implementations.


77-97: LGTM!

The evaluate method correctly passes the definition parameter to both check_correctness and eval_performance, maintaining consistency with the updated abstract method signatures.

flashinfer_bench/bench/evaluators/lowbit.py (2)

46-62: LGTM!

The DPS and value-returning execution paths are correctly implemented. The DPS path properly allocates outputs first using allocate_outputs, while the VR path normalizes the result afterward. Both paths correctly use torch.no_grad() and synchronize before proceeding.


77-103: LGTM!

The per-tensor validation logic is thorough: it checks for non-finite values first, then computes error statistics. The aggregation of max errors and minimum matched ratio across all tensors provides a comprehensive correctness assessment.

flashinfer_bench/tracing/config.py (1)

24-60: Well-structured FilterPolicy Protocol.

The Protocol clearly defines the contract for filter policies with comprehensive docstrings. This aligns well with the existing implementations in flashinfer_bench/tracing/builtin/policies.py.

flashinfer_bench/tracing/runtime.py (1)

92-124: Clean refactor of collect() to support both calling conventions.

The signature change from runtime_args: Dict[str, Any] to (args, kwargs) with proper validation for VR vs DPS calling conventions is well-implemented. The error handling with early returns prevents runtime crashes during tracing.

flashinfer_bench/apply/key.py (2)

11-22: Clean migration to Pydantic frozen model.

Using ConfigDict(frozen=True) provides immutability for use as dict keys and in sets, replacing the manual __hash__/__eq__ implementations. The Pydantic model also provides built-in JSON serialization via model_dump_json() and model_validate_json().
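A sketch of the idea with hypothetical fields:

from pydantic import BaseModel, ConfigDict

class Key(BaseModel):
    model_config = ConfigDict(frozen=True)

    def_name: str
    axes: tuple

k1 = Key(def_name="gemm", axes=(128, 256))
k2 = Key(def_name="gemm", axes=(128, 256))
assert k1 == k2 and hash(k1) == hash(k2)  # usable as a dict/set key
assert Key.model_validate_json(k1.model_dump_json()) == k1  # JSON round-trip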


48-58: Builder methods updated for positional args interface.

The build_from_args and features methods now accept Tuple[Any, ...] instead of dict-based runtime kwargs. The unused args parameter in features() at line 57 is intentional—it maintains the abstract interface contract even when the implementation doesn't need the value.

flashinfer_bench/compile/runnable.py (2)

27-36: Updated metadata fields support DPS/VR distinction.

The rename from definition/solution to definition_name/solution_name improves clarity, and the new destination_passing_style flag enables runtime branching between calling conventions.


98-121: Clean return value normalization logic.

The _revise_return_value helper properly handles edge cases: pass-through for non-tuples, unpacking single-element tuples, and converting empty tuples to None.
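Under those stated semantics, the helper reduces to something like this (a sketch; the name follows the review, not verified code):

from typing import Any

def revise_return_value(value: Any) -> Any:
    if not isinstance(value, tuple):
        return value  # pass-through for non-tuples
    if len(value) == 0:
        return None  # empty tuple becomes None
    if len(value) == 1:
        return value[0]  # unpack single-element tuples
    return value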

flashinfer_bench/bench/evaluators/default.py (3)

46-64: Baseline generation correctly uses value-returning style.

The reference implementation is always value-returning, and the new normalize_result helper properly converts outputs to a consistent List[torch.Tensor] format for comparison.


98-118: Clean DPS vs VR branching in correctness checking.

The conditional logic properly handles both calling conventions: allocating outputs for DPS and normalizing results for VR. The structure is clear and maintainable.


186-197: Performance evaluation correctly handles DPS output allocation.

The DPS path properly allocates outputs and includes them in the timing arguments, ensuring accurate latency measurement that includes any output tensor initialization overhead.

flashinfer_bench/bench/evaluators/sampling.py (4)

26-39: LGTM! Clear helper utilities for name-based input access.

The helper functions provide a clean abstraction for accessing inputs by name from the list-based structure, aligning with the DPS refactoring pattern used across evaluators.


66-81: LGTM! Clean list-based baseline construction.

The baseline correctly stores inputs and outputs as List[List[Any]] and List[List[torch.Tensor]] respectively, with expected_probs stored as the first tensor in the outputs list, aligning with the DPS-oriented structure.


130-147: DPS branch correctly integrated.

The DPS handling matches the pattern in DefaultEvaluator.check_correctness (see default.py lines 90-103), properly allocating outputs via allocate_outputs and invoking the runnable with combined input/output tensors.

The broad Exception catch at line 140 is consistent with the error handling pattern used across all evaluators in this codebase to capture runtime failures.


362-373: LGTM! DPS sampling loop correctly implemented.

The DPS branch properly allocates outputs for the padded inputs and invokes the runnable with combined arguments, while the non-DPS branch normalizes the returned result. Both paths correctly extract the samples tensor from out[0].

flashinfer_bench/data/definition.py (4)

15-16: LGTM! Proper use of TYPE_CHECKING for torch import.

The TYPE_CHECKING guard correctly avoids importing torch at runtime while still enabling type hints for the torch.dtype return types in the cached properties.


158-169: LGTM! Important validation for DPS support.

This validator ensures input and output names don't overlap, which is critical for the DPS calling convention where outputs are passed as additional arguments and need unique identifiers.
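A minimal sketch of such a validator in Pydantic v2 (fields illustrative only):

from typing import List

from pydantic import BaseModel, model_validator

class DefinitionSketch(BaseModel):
    input_names: List[str]
    output_names: List[str]

    @model_validator(mode="after")
    def _no_name_overlap(self) -> "DefinitionSketch":
        overlap = set(self.input_names) & set(self.output_names)
        if overlap:
            raise ValueError(f"input/output names overlap: {sorted(overlap)}")
        return self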


450-457: Silent break on missing kwargs may be unexpected.

The loop breaks at line 457 when a parameter name is not found in kwargs. This silently stops merging if kwargs don't form a contiguous sequence after the positional args. This behavior might be intentional for partial application, but could also mask errors where expected kwargs are accidentally omitted.

Consider whether raising an error or documenting this behavior would be more appropriate.
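For comparison, a strict variant that raises instead of silently stopping might look like this (a sketch with hypothetical names, not a drop-in replacement):

from typing import Any, Dict, List, Tuple

def merge_args_strict(
    param_names: List[str], args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> List[Any]:
    # Every parameter not covered by positional args must appear in kwargs.
    merged = list(args)
    for name in param_names[len(args):]:
        if name not in kwargs:
            raise TypeError(f"missing required argument: '{name}'")
        merged.append(kwargs[name])
    return merged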


408-428: Unable to independently verify DType enum conversion claims.

Repository access failed, preventing verification of: (1) whether DType inherits from (str, Enum) at line 49, (2) the dtype_str_to_torch_dtype function signature and implementation, and (3) whether the code correctly handles spec.dtype as a DType enum. The review comment's conclusion that the code is correct cannot be confirmed without access to the actual codebase. Web search provided only generic flashinfer API patterns, not the specific implementation details needed.

         if fallback is None:
             raise RuntimeError(f"Definition '{def_name}' not found and no fallback provided")
-        return fallback(**runtime_kwargs)
+        return fallback(*args, **kwargs)

⚠️ Potential issue | 🟠 Major

Inconsistent fallback invocation: kwargs dropped on miss path.

When the definition is not found (line 155), fallback is called with *args, **kwargs. However, when no runnable is available after miss policy (line 202), fallback is called with only *args, dropping kwargs. This inconsistency could cause unexpected behavior when fallback functions expect keyword arguments.

Apply this diff to make fallback invocation consistent:

         if runnable is None:
             if fallback is None:
                 raise RuntimeError(f"Apply miss for '{def_name}' and no fallback provided")
-            return fallback(*args)
+            return fallback(*args, **kwargs)

Also applies to: 202-202

🤖 Prompt for AI Agents
In flashinfer_bench/apply/runtime.py around lines 155 and 202, the fallback
invocation is inconsistent: line 155 calls fallback(*args, **kwargs) but line
202 calls fallback(*args) which drops keyword arguments; update the miss-path at
line 202 to call fallback with the same signature (fallback(*args, **kwargs)) so
kwargs are preserved on both miss paths, and run tests to ensure no callers rely
on the older behavior.

Comment on lines +72 to +86
def _quantile(a: List[float], q: Sequence[float]) -> List[float]:
    """Compute quantiles of a list of values."""
    n = len(a)
    a = sorted(a)

    def get_quantile(q: float) -> float:
        if not (0 <= q <= 1):
            raise ValueError("Quantiles must be in the range [0, 1]")
        point = q * (n - 1)
        lower = math.floor(point)
        upper = math.ceil(point)
        t = point - lower
        return (1 - t) * a[lower] + t * a[upper]

    return [get_quantile(qi) for qi in q]

⚠️ Potential issue | 🟡 Minor

Handle empty input list edge case.

If a is empty, accessing a[lower] and a[upper] will raise an IndexError. Consider adding an early guard.

 def _quantile(a: List[float], q: Sequence[float]) -> List[float]:
     """Compute quantiles of a list of values."""
     n = len(a)
+    if n == 0:
+        raise ValueError("Cannot compute quantiles of an empty list")
     a = sorted(a)
🧰 Tools
🪛 Ruff (0.14.7)

79-79: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer_bench/bench/timing.py around lines 72-86, the _quantile function
will IndexError on an empty input list; add an early guard that checks if n == 0
and raise a clear ValueError (e.g. "Cannot compute quantiles of an empty list")
before sorting/processing so callers get a meaningful error instead of an
IndexError; keep the rest of the quantile logic unchanged.

Comment on lines +113 to +120
def _get_empty_cache_for_benchmark() -> torch.Tensor:
    """Create a buffer for clearing L2 cache before benchmark runs.

    We maintain a buffer of 256 MB that we clear before each kernel call
    to make sure that the L2 cache doesn't contain any input data before the run.
    """
    cache_size = 256 * 1024 * 1024
    return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda")

⚠️ Potential issue | 🟡 Minor

Cache buffer uses hardcoded "cuda" device.

The cache tensor is always allocated on the default CUDA device ("cuda"), but time_runnable accepts a specific device parameter (e.g., "cuda:1"). If the benchmark runs on a non-default GPU, the cache-clearing operation may not effectively clear that GPU's L2 cache, and cross-device synchronization could introduce overhead.

Consider parameterizing the device:

-def _get_empty_cache_for_benchmark() -> torch.Tensor:
+def _get_empty_cache_for_benchmark(device: str = "cuda") -> torch.Tensor:
     """Create a buffer for clearing L2 cache before benchmark runs.
 
     We maintain a buffer of 256 MB that we clear before each kernel call
     to make sure that the L2 cache doesn't contain any input data before the run.
     """
     cache_size = 256 * 1024 * 1024
-    return torch.empty(int(cache_size // 4), dtype=torch.int, device="cuda")
+    return torch.empty(int(cache_size // 4), dtype=torch.int, device=device)

This would require propagating the device through do_bench as well.

🤖 Prompt for AI Agents
In flashinfer_bench/bench/timing.py around lines 113-120 the cache buffer is
always allocated on the default "cuda" device which breaks when benchmarking on
a non-default GPU (e.g., "cuda:1"); change _get_empty_cache_for_benchmark to
accept a device argument and allocate the tensor on that device, then propagate
that device through the benchmarking call path (update do_bench and any callers
such as time_runnable to pass the selected device) so the cache-clear buffer is
created on the same CUDA device as the workload and avoids cross-device
sync/overhead.

Comment on lines +169 to +194
# Check return annotation
num_outputs = len(definition.outputs)
ret_ann = signature.return_annotation
if ret_ann is not inspect.Signature.empty:
    origin = getattr(ret_ann, "__origin__", None)
    if origin is tuple:
        # If returning tuple, length must match num_outputs
        args = getattr(ret_ann, "__args__", None)
        if args is not None and len(args) != num_outputs:
            raise BuildError(
                f"Value-returning style callable with {num_outputs} outputs must "
                f"return a {num_outputs}-element tuple, got {ret_ann}"
            )
    elif origin is None:
        # If return None, num_outputs must be 0
        if num_outputs != 0:
            raise BuildError(
                f"Value-returning style callable with {num_outputs} outputs must "
                f"return None, got {ret_ann}"
            )
    elif num_outputs != 1:
        # If returning non-tuple, num_outputs must be 1
        raise BuildError(
            f"Value-returning style callable returning non-tuple must "
            f"only have one output, got {num_outputs}"
        )

⚠️ Potential issue | 🟠 Major

Logic bug: origin is None conflates -> None with non-generic types like -> int.

The condition at line 182 (elif origin is None) is true both when:

  1. The return annotation is None (i.e., -> None)
  2. The return annotation is a non-generic type like int, float, torch.Tensor

For a function with signature def foo(x) -> torch.Tensor and 1 output, origin is None, so the code enters line 182 and raises an error at line 184 because num_outputs (1) != 0.
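A quick demonstration of the conflation:

import torch

# getattr(ann, "__origin__", None) is None for plain classes and for the
# literal None alike, so it cannot tell `-> torch.Tensor` from `-> None`:
for ann in (None, int, float, torch.Tensor):
    assert getattr(ann, "__origin__", None) is None

# Only subscripted generics carry __origin__:
from typing import Tuple
assert Tuple[int, int].__origin__ is tuple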

Consider checking for None type explicitly:

-            elif origin is None:
-                # If return None, num_outputs must be 0
-                if num_outputs != 0:
-                    raise BuildError(
-                        f"Value-returning style callable with {num_outputs} outputs must "
-                        f"return None, got {ret_ann}"
-                    )
-            elif num_outputs != 1:
+            elif ret_ann is type(None):
+                # If return None, num_outputs must be 0
+                if num_outputs != 0:
+                    raise BuildError(
+                        f"Value-returning style callable with {num_outputs} outputs must "
+                        f"not return None when outputs are expected"
+                    )
+            elif origin is None and num_outputs != 1:
                 # If returning non-tuple, num_outputs must be 1
                 raise BuildError(
                     f"Value-returning style callable returning non-tuple must "
                     f"only have one output, got {num_outputs}"
                 )
🧰 Tools
🪛 Ruff (0.14.7)

178-181: Avoid specifying long messages outside the exception class

(TRY003)


185-188: Avoid specifying long messages outside the exception class

(TRY003)


191-194: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer_bench/compile/builder.py around lines 169-194, the code wrongly
treats any non-generic return annotation (origin is None) as if it were the
explicit None return; change the branch so you explicitly check for an exact
None return (e.g., ret_ann is None or ret_ann is type(None)) and handle the
"num_outputs must be 0" error only in that case, otherwise fall through to the
non-tuple handling for non-generic types (int, float, torch.Tensor) so a
single-output non-tuple return is allowed when num_outputs == 1.

Comment on lines +196 to +202
for result, output in zip(result, args_output):
    if not isinstance(result, torch.Tensor) or not isinstance(output, torch.Tensor):
        raise ValueError(
            "Destination-passing style callable must return a tuple of tensors, got "
            f"{type(result)} and {type(output)}"
        )
    output.copy_(result)

⚠️ Potential issue | 🔴 Critical

Critical: Loop variable result shadows the iterable.

The loop for result, output in zip(result, args_output) reuses result as both the iterable and the loop variable. This works by accident because Python evaluates the iterable before binding loop variables, but it's confusing and error-prone. Additionally, zip() should use strict=True since mismatched lengths indicate a bug.

-        for result, output in zip(result, args_output):
-            if not isinstance(result, torch.Tensor) or not isinstance(output, torch.Tensor):
+        for res_tensor, out_tensor in zip(result, args_output, strict=True):
+            if not isinstance(res_tensor, torch.Tensor) or not isinstance(out_tensor, torch.Tensor):
                 raise ValueError(
                     "Destination-passing style callable must return a tuple of tensors, got "
-                    f"{type(result)} and {type(output)}"
+                    f"{type(res_tensor)} and {type(out_tensor)}"
                 )
-            output.copy_(result)
+            out_tensor.copy_(res_tensor)
🧰 Tools
🪛 Ruff (0.14.7)

196-196: Loop control variable result overrides iterable it iterates

(B020)


196-196: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


198-201: Prefer TypeError exception for invalid type

(TRY004)


198-201: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In flashinfer_bench/compile/runnable.py around lines 196-202, the for-loop
reuses the name `result` for both the iterable and the loop variable which is
confusing and error-prone; rename the iterable (e.g., `results` or
`result_iter`) and iterate with a distinct loop variable (e.g., `result_item`)
and use zip(..., strict=True) to catch length mismatches; keep the existing type
checks and the output.copy_(result_item) call but replace references to the
renamed variables accordingly.

Comment on lines +16 to +17
InputDumpPolicyFunction = Callable[Dict[str, Any], List[str]]
"""Function that selects which inputs to dump from input names and values."""

⚠️ Potential issue | 🔴 Critical

Critical: Invalid Callable type alias syntax.

The type alias uses incorrect bracket syntax for Callable. This will cause a runtime TypeError when Python evaluates the type hint.

-InputDumpPolicyFunction = Callable[Dict[str, Any], List[str]]
+InputDumpPolicyFunction = Callable[[Dict[str, Any]], List[str]]
🤖 Prompt for AI Agents
In flashinfer_bench/tracing/config.py around lines 16-17, the type alias uses
incorrect Callable syntax; change the alias to use a list for the argument
types, e.g. InputDumpPolicyFunction = Callable[[Dict[str, Any]], List[str]], and
ensure typing.Callable is imported (or adjust imports) so the evaluated type
hint is valid.

@Ubospica Ubospica marked this pull request as draft December 5, 2025 19:59