diff --git a/flashinfer_bench/apply/runtime.py b/flashinfer_bench/apply/runtime.py index a027dd6d..4d2f5d8d 100644 --- a/flashinfer_bench/apply/runtime.py +++ b/flashinfer_bench/apply/runtime.py @@ -14,7 +14,6 @@ from flashinfer_bench.data import TraceSet from flashinfer_bench.env import get_fib_dataset_path, get_fib_enable_apply from flashinfer_bench.logging import get_logger -from flashinfer_bench.tracing import get_tracing_runtime from .config import ApplyConfig from .key import ApplyKeyBuilder, ApplyKeyFactory @@ -23,46 +22,6 @@ logger = get_logger("ApplyRuntime") -def _init_apply_runtime_from_env() -> Optional["ApplyRuntime"]: - """Initialize the global runtime from environment variables if configured.""" - fib_enable_apply = get_fib_enable_apply() - if not fib_enable_apply: - return - fib_dataset_path = get_fib_dataset_path() - trace_set = TraceSet.from_path(fib_dataset_path) - apply_config = ApplyConfig() - return ApplyRuntime(trace_set, apply_config, None) - - -_global_apply_runtime: Optional["ApplyRuntime"] = _init_apply_runtime_from_env() - - -def get_apply_runtime() -> Optional["ApplyRuntime"]: - """Get the global ApplyRuntime instance. - - Returns the singleton runtime instance, initializing it from environment - variables if it hasn't been created yet. - - Returns - ------- - Optional[ApplyRuntime] - The global runtime instance, or None if not initialized. - """ - return _global_apply_runtime - - -def set_apply_runtime(rt: Optional["ApplyRuntime"]) -> None: - """Set the global ApplyRuntime instance. - - Parameters - ---------- - rt : Optional[ApplyRuntime] - The runtime instance to set, or None to clear the global runtime. - """ - global _global_apply_runtime - _global_apply_runtime = rt - - class ApplyRuntime: """Runtime system for dispatching optimized implementations based on trace data. @@ -146,16 +105,6 @@ def dispatch( If the definition is not found and no fallback is provided, or if no suitable implementation is available and no fallback is provided. """ - # First try to run tracing logic in case tracing is enabled - tracing_runtime = get_tracing_runtime() - if tracing_runtime is not None: - try: - tracing_runtime.collect(def_name, runtime_kwargs) - except Exception as e: - logger.error(f"Error collecting trace for {def_name}: {e}") - pass - - # Then try to run apply logic defn = self._trace_set.definitions.get(def_name) if defn is None: if fallback is None: @@ -198,3 +147,43 @@ def __enter__(self) -> None: def __exit__(self, exc_type, exc, tb) -> bool: set_apply_runtime(self._prev_runtime) return False + + +def _init_apply_runtime_from_env() -> Optional["ApplyRuntime"]: + """Initialize the global runtime from environment variables if configured.""" + fib_enable_apply = get_fib_enable_apply() + if not fib_enable_apply: + return + fib_dataset_path = get_fib_dataset_path() + trace_set = TraceSet.from_path(fib_dataset_path) + apply_config = ApplyConfig() + return ApplyRuntime(trace_set, apply_config, None) + + +_global_apply_runtime: Optional["ApplyRuntime"] = _init_apply_runtime_from_env() + + +def get_apply_runtime() -> Optional["ApplyRuntime"]: + """Get the global ApplyRuntime instance. + + Returns the singleton runtime instance, initializing it from environment + variables if it hasn't been created yet. + + Returns + ------- + Optional[ApplyRuntime] + The global runtime instance, or None if not initialized. + """ + return _global_apply_runtime + + +def set_apply_runtime(rt: Optional["ApplyRuntime"]) -> None: + """Set the global ApplyRuntime instance. + + Parameters + ---------- + rt : Optional[ApplyRuntime] + The runtime instance to set, or None to clear the global runtime. + """ + global _global_apply_runtime + _global_apply_runtime = rt diff --git a/flashinfer_bench/compile/builder.py b/flashinfer_bench/compile/builder.py index 5f4ebc40..fd956a42 100644 --- a/flashinfer_bench/compile/builder.py +++ b/flashinfer_bench/compile/builder.py @@ -1,6 +1,5 @@ from __future__ import annotations -import hashlib import os import re import tempfile @@ -31,18 +30,10 @@ def write_sources_to_temp(base: str, sources: list[SourceFile], pkg: Optional[st def create_pkg_name(sol: Solution, prefix: str = "") -> str: - # Normalize the solution name s = re.sub(r"[^0-9a-zA-Z_]", "_", sol.name) if not s or s[0].isdigit(): s = "_" + s - - # Hash the sources - h = hashlib.sha1() - for src in sol.sources: - h.update(src.path.encode()) - h.update(src.content.encode()) - - return prefix + s + "_" + h.hexdigest()[:6] + return prefix + s class BuildError(RuntimeError): diff --git a/flashinfer_bench/compile/builders/cuda_builder.py b/flashinfer_bench/compile/builders/cuda_builder.py index 84bc0dab..0189dc5a 100644 --- a/flashinfer_bench/compile/builders/cuda_builder.py +++ b/flashinfer_bench/compile/builders/cuda_builder.py @@ -70,7 +70,7 @@ def _get_package_paths(pkg_name: str, lib_names: Optional[List[str]] = None): CUDA_DEPS = { "cublas": ("nvidia.cublas", ["cublas", "cublasLt"]), "cudnn": ("nvidia.cudnn", ["cudnn"]), - "cutlass": ("flashinfer_bench._deps.cutlass", None), # Header-only + "cutlass": ("flashinfer_bench.thirdparty.cutlass", None), # Header-only } diff --git a/flashinfer_bench/compile/registry.py b/flashinfer_bench/compile/registry.py index 0e108e77..d2fb4d7a 100644 --- a/flashinfer_bench/compile/registry.py +++ b/flashinfer_bench/compile/registry.py @@ -15,6 +15,8 @@ def __init__(self, builders: Tuple[Builder, ...]) -> None: if not builders: raise ValueError("BuilderRegistry requires at least one builder") self._builders: Tuple[Builder, ...] = builders + # Cache: solution_name -> builder to avoid repeated can_build checks + self._solution_to_builder: dict[str, Builder] = {} def clear(self) -> None: for b in self._builders: @@ -22,11 +24,16 @@ def clear(self) -> None: b.clear_cache() except Exception: pass + self._solution_to_builder.clear() def build(self, defn: Definition, sol: Solution) -> Runnable: + builder = self._solution_to_builder.get(sol.name) + if builder is not None: + return builder.build(defn, sol) + for builder in self._builders: - # Choose the first if builder.can_build(sol): + self._solution_to_builder[sol.name] = builder return builder.build(defn, sol) raise BuildError(f"No registered builder can build solution '{sol.name}'") diff --git a/flashinfer_bench/data/trace_set.py b/flashinfer_bench/data/trace_set.py index 0da875c5..21759da7 100644 --- a/flashinfer_bench/data/trace_set.py +++ b/flashinfer_bench/data/trace_set.py @@ -44,6 +44,16 @@ class TraceSet: definition.""" traces: Dict[str, List[Trace]] = field(default_factory=dict) """The traces in the database. Map from definition name to all traces for that definition.""" + _solution_by_name: Dict[str, Solution] = field(default_factory=dict, init=False, repr=False) + """Fast lookup index: solution name -> Solution object. Automatically maintained.""" + + def __post_init__(self): + """Initialize the _solution_by_name index from existing solutions.""" + for solutions_list in self.solutions.values(): + for solution in solutions_list: + if solution.name in self._solution_by_name: + raise ValueError(f"Duplicate solution name found: {solution.name}") + self._solution_by_name[solution.name] = solution @property def definitions_path(self) -> Path: @@ -131,6 +141,7 @@ def from_path(cls: type["TraceSet"], path: Optional[str] = None) -> "TraceSet": raise ValueError(f"Duplicate solution name: {s.name}") seen_solutions.add(s.name) trace_set.solutions.setdefault(s.definition, []).append(s) + trace_set._solution_by_name[s.name] = s for p in sorted((trace_set.workloads_path.rglob("*.jsonl"))): for t in load_jsonl_file(Trace, p): @@ -174,9 +185,8 @@ def to_dict(self) -> Dict[str, Any]: def get_solution(self, name: str) -> Optional[Solution]: """Get a solution by name from all loaded solutions. - Searches across all solutions in the TraceSet to find one with the specified name. - Since solution names are unique across the entire dataset, this returns at most - one solution. + Uses an O(1) index lookup for fast retrieval. Since solution names are unique + across the entire dataset, this returns at most one solution. Parameters ---------- @@ -188,11 +198,7 @@ def get_solution(self, name: str) -> Optional[Solution]: Optional[Solution] The solution with the given name, or None if not found. """ - for solution_list in self.solutions.values(): - for solution in solution_list: - if solution.name == name: - return solution - return None + return self._solution_by_name.get(name) def filter_traces(self, def_name: str, atol: float = 1e-2, rtol: float = 1e-2) -> List[Trace]: """Filter traces for a definition based on error bounds. diff --git a/pyproject.toml b/pyproject.toml index b68fe093..88c2257e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ where = ["."] include = ["flashinfer_bench*"] [tool.setuptools.package-data] "flashinfer_bench" = ["py.typed"] -"flashinfer_bench._deps.cutlass" = ["include/**"] +"flashinfer_bench.thirdparty.cutlass" = ["include/**"] [tool.black] line-length = 100