Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/win_at_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def collect_runs(paths: Iterable[str], author_map: Dict[str, str]) -> List[Run]:
try:
definition = obj["definition"]
solution = obj["solution"]
wl = obj["workload"]
uuid = wl["uuid"]
workload = obj["workload"]
uuid = workload["uuid"]
evalo = obj["evaluation"]
status = evalo["status"]
perf = evalo.get("performance")
Expand Down
185 changes: 105 additions & 80 deletions flashinfer_bench/apply/apply_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import inspect
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union, overload

from flashinfer_bench.data import TraceSet
Expand All @@ -17,19 +17,21 @@ def apply(
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...


# Function mode
# Function mode with positional args
@overload
def apply(
def_name_or_resolver: Union[str, Callable[..., str]],
*,
runtime_kwargs: Dict[str, Any],
fallback: Optional[Callable[..., Any]],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
fallback: Optional[Callable[..., Any]] = None,
) -> Any: ...


def apply(
def_name_or_resolver: Union[str, Callable[..., str]],
runtime_kwargs: Optional[Dict[str, Any]] = None,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
fallback: Optional[Callable[..., Any]] = None,
):
"""
Expand All @@ -41,17 +43,26 @@ def apply(
1) **Decorator mode** (only ``def_name_or_resolver`` provided): returns a decorator
that wraps a kernel function with a router. The router selects the best-performing
candidate according to the function's runtime arguments.
2) **Function mode** (``runtime_kwargs`` provided, optionally ``fallback``):
2) **Function mode** (``args`` or ``kwargs`` provided, optionally ``fallback``):
immediately resolves and calls the best-performing kernel and returns its result.

The calling convention (value-returning vs destination-passing) is determined by the
number of arguments:
- If len(args) == len(inputs): value-returning style, solution returns outputs
- If len(args) == len(inputs) + len(outputs): destination-passing style, outputs are
pre-allocated and passed as arguments

Parameters
----------
def_name_or_resolver : Union[str, Callable[..., str]]
The kernel name, or a resolver ``fn(*args, **kwargs) -> str`` that maps runtime
The kernel name, or a resolver ``fn(*args) -> str`` that maps runtime
arguments to a kernel name (definition name).
Comment on lines +58 to 59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The docstring for def_name_or_resolver states its signature is fn(*args) -> str, but the implementation in _dispatch_apply_or_tracing calls it with *args, **kwargs. The docstring should be updated to match the more flexible implementation.

Suggested change
The kernel name, or a resolver ``fn(*args) -> str`` that maps runtime
arguments to a kernel name (definition name).
The kernel name, or a resolver ``fn(*args, **kwargs) -> str`` that maps runtime
arguments to a kernel name (definition name).

runtime_kwargs : Dict[str, Any], optional
Only used in **function mode**. The runtime arguments to feed into the selected
kernel. Use this to call the kernel immediately instead of returning a decorator.
args : Tuple[Any, ...], optional
Only used in **function mode**. The positional runtime arguments to feed into
the selected kernel. The number of arguments determines the calling convention.
kwargs : Dict[str, Any], optional
Only used in **function mode**. The keyword runtime arguments to feed into
the selected kernel. The number of arguments determines the calling convention.
fallback : Optional[Callable[..., Any]], optional
Only used in **function mode**. A fallback function to invoke when no matching
kernel is found in the Trace database.
Expand All @@ -62,83 +73,113 @@ def apply(
- **Decorator mode**: a decorator that transforms the target kernel function into
a routed version.
- **Function mode**: the return value produced by the selected (or fallback) kernel.
For destination-passing style, returns None.

Examples
--------
Decorator mode with a fixed name
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
>>> @apply("gemm_bf16")
... def gemm_bf16(A, B):
... return torch.nn.functional.linear(A, B)
... return A @ B.T

Decorator mode with a resolver
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
>>> @apply(lambda A, B: f"gemm_n{B.shape[0]}_k{B.shape[1]}")
... def gemm_bf16(A, B):
... return torch.nn.functional.linear(A, B)
... return A @ B.T

Function mode
~~~~~~~~~~~~~
Function mode (value-returning)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
>>> out = apply(
... "gemm_bf16",
... runtime_kwargs={"A": A, "B": B, "bias": None},
... fallback=lambda **kw: torch.nn.functional.linear(**kw),
... args=(A, B),
... fallback=lambda A, B: A @ B.T,
... )
"""
# Imperative
if runtime_kwargs is not None:
kwargs = dict(runtime_kwargs)
def_name = (
def_name_or_resolver
if isinstance(def_name_or_resolver, str)
else def_name_or_resolver(**kwargs)
)

tracing_rt = get_tracing_runtime()
if tracing_rt is not None:
tracing_rt.collect(def_name, kwargs)
tracing_rt.flush()

apply_rt = get_apply_runtime()
if apply_rt is None:
if fallback is None:
raise RuntimeError("Apply is not enabled and no fallback provided")
return fallback(**kwargs)

return apply_rt.dispatch(def_name, kwargs, fallback)

# Decorator
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
# Inspect once
sig = inspect.signature(fn)
param_names = tuple(sig.parameters.keys())

Function mode (destination-passing)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
>>> C = torch.empty(M, N, device=A.device, dtype=A.dtype)
>>> apply(
... "gemm_bf16",
... args=(A, B, C), # C is pre-allocated output
... fallback=lambda *args: my_gemm_dps(*args),
... )

Function mode with kwargs
~~~~~~~~~~~~~~~~~~~~~~~~~
>>> out = apply(
... "gemm_bf16",
... kwargs={"A": A, "B": B},
... fallback=lambda A, B: A @ B.T,
... )
"""
# Imperative / Function mode
if args is not None or kwargs is not None:
args = args if args is not None else ()
kwargs = kwargs if kwargs is not None else {}
return _dispatch_apply_or_tracing(def_name_or_resolver, args, kwargs, fallback)

# Decorator mode
def decorator(fallback: Callable[..., Any]) -> Callable[..., Any]:
@wraps(fallback)
def wrapped(*args: Any, **kwargs: Any):
tracing_rt = get_tracing_runtime()
apply_rt = get_apply_runtime()
if tracing_rt is None and apply_rt is None:
return fn(*args, **kwargs)

bound = _merge_args_to_kwargs(param_names, args, kwargs)
def_name = (
def_name_or_resolver
if isinstance(def_name_or_resolver, str)
else def_name_or_resolver(**bound)
)
if tracing_rt is not None:
tracing_rt.collect(def_name, bound)
if apply_rt is None:
return fn(*args, **kwargs)
return apply_rt.dispatch(def_name, bound, fn)

wrapped.__name__ = fn.__name__
wrapped.__doc__ = fn.__doc__
wrapped.__wrapped__ = fn
return _dispatch_apply_or_tracing(def_name_or_resolver, args, kwargs, fallback)

return wrapped

return decorator


def _dispatch_apply_or_tracing(
    def_name_or_resolver: Union[str, Callable[..., str]],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    fallback: Optional[Callable[..., Any]],
) -> Any:
    """Route a call through the apply runtime, or trace it and run the fallback.

    The definition name is resolved first. If an apply runtime is active, the
    call is handed to it and its result is returned; the tracing runtime is not
    consulted in that case. Otherwise, an active tracing runtime (if any)
    records the call, and the fallback is invoked with the original arguments.

    Parameters
    ----------
    def_name_or_resolver : Union[str, Callable[..., str]]
        Either the definition name itself, or a callable invoked as
        ``resolver(*args, **kwargs)`` that returns the definition name.
    args : Tuple[Any, ...]
        Positional runtime arguments, forwarded unchanged.
    kwargs : Dict[str, Any]
        Keyword runtime arguments, forwarded unchanged.
    fallback : Optional[Callable[..., Any]]
        Callable used when no apply runtime is active. It is also passed to
        the apply runtime's ``dispatch``.

    Returns
    -------
    Any
        Whatever the apply runtime's ``dispatch`` or the fallback returns.

    Raises
    ------
    RuntimeError
        If no apply runtime is active and ``fallback`` is ``None``. Note that
        this is raised after the tracing runtime (if any) has collected the
        call.
    """
    # A plain string is already the definition name; anything else is a resolver.
    if isinstance(def_name_or_resolver, str):
        resolved_name = def_name_or_resolver
    else:
        resolved_name = def_name_or_resolver(*args, **kwargs)

    # The apply runtime takes precedence: when present it fully handles the call.
    runtime = get_apply_runtime()
    if runtime is not None:
        return runtime.dispatch(resolved_name, args, kwargs, fallback)

    # Without an apply runtime, record the call if tracing is active.
    tracer = get_tracing_runtime()
    if tracer is not None:
        tracer.collect(resolved_name, args, kwargs)

    # With or without tracing, the actual work is done by the fallback.
    if fallback is not None:
        return fallback(*args, **kwargs)
    raise RuntimeError("Apply or tracing is not enabled and no fallback provided")


def enable_apply(
dataset_path: Optional[str] = None, apply_config: Optional[ApplyConfig] = None
) -> ApplyRuntime:
Expand Down Expand Up @@ -190,19 +231,3 @@ def disable_apply() -> None:
Check out the `enable_apply` function for examples.
"""
set_apply_runtime(None)


def _merge_args_to_kwargs(
param_names: Tuple[str], args: Tuple[Any], kwargs: Dict[str, Any]
) -> Dict[str, Any]:
if len(args) > len(param_names):
raise TypeError("Too many positional arguments")
merged: Dict[str, Any] = {}
for i, val in enumerate(args):
merged[param_names[i]] = val
# Merge kwargs with conflict detection
for k, v in kwargs.items():
if k in merged:
raise TypeError(f"Multiple values for argument '{k}'")
merged[k] = v
return merged
Loading
Loading