diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml new file mode 100644 index 00000000..72e593c5 --- /dev/null +++ b/.github/workflows/build-release.yml @@ -0,0 +1,141 @@ +name: Build & Release Wheels + +on: + push: + tags: + - "v*" + workflow_dispatch: + +concurrency: + group: build-release-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-wheel: + name: "wheel / ${{ matrix.cuda }} / cp312 / ${{ matrix.arch }}" + runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + defaults: + run: + shell: bash + strategy: + fail-fast: false + matrix: + cuda: + - cu129 + - cu130 + arch: + - x86_64 + - aarch64 + container: + # UBI 8 provides the glibc 2.28 baseline required by manylinux_2_28. + image: "nvidia/cuda:${{ matrix.cuda == 'cu129' && '12.9.0' || '13.0.0' }}-devel-ubi8" + + steps: + - name: Free disk space + run: | + rm -rf /opt/hostedtoolcache /usr/local/lib/android /usr/share/dotnet \ + /usr/local/share/boost /opt/ghc 2>/dev/null || true + dnf clean all 2>/dev/null || true + df -h / || true + + - name: Install system dependencies + run: | + dnf install -y \ + git \ + gcc-toolset-13-gcc \ + gcc-toolset-13-gcc-c++ \ + python3.12 \ + python3.12-devel \ + python3.12-pip + dnf clean all + + - name: Checkout + uses: actions/checkout@v5 + with: + fetch-depth: 0 + submodules: recursive + + - name: Configure git safe directory + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Install Python dependencies + run: | + python3.12 -m pip install --no-cache-dir --upgrade pip + python3.12 -m pip install --no-cache-dir torch --index-url ${{ matrix.cuda == 'cu129' && 'https://download.pytorch.org/whl/cu129' || 'https://download.pytorch.org/whl/cu130' }} + python3.12 -m pip install --no-cache-dir setuptools wheel "setuptools_scm>=6.0" build ninja auditwheel patchelf + + - name: Compute version + id: version + run: | + if [[ "$GITHUB_REF" == refs/tags/v* ]]; then + BASE="${GITHUB_REF#refs/tags/v}" + else + # Strip any local segment (+gXXX) so we get a clean base + BASE=$(python3.12 -c "from setuptools_scm import get_version; print(get_version().split('+')[0])") + fi + echo "version=${BASE}+${{ matrix.cuda }}" >> "$GITHUB_OUTPUT" + + - name: Build fat-binary wheel + env: + CC: /opt/rh/gcc-toolset-13/root/usr/bin/gcc + CXX: /opt/rh/gcc-toolset-13/root/usr/bin/g++ + CUDAHOSTCXX: /opt/rh/gcc-toolset-13/root/usr/bin/g++ + CULA_BUILD_ALL_ARCHS: "1" + SETUPTOOLS_SCM_PRETEND_VERSION: "${{ steps.version.outputs.version }}" + NVCC_THREADS: "4" + MAX_JOBS: "4" + run: | + "$CC" --version + "$CXX" --version + python3.12 -m build --wheel --no-isolation --outdir dist-raw + + - name: Repair wheel for manylinux_2_28 + run: | + # These libraries are supplied by the NVIDIA driver, PyTorch, or + # PyTorch's CUDA runtime dependency and must remain external. + python3.12 -m auditwheel repair \ + --plat manylinux_2_28_${{ matrix.arch }} \ + --exclude libcuda.so.1 \ + --exclude 'libcudart.so.*' \ + --exclude 'libc10*.so' \ + --exclude 'libtorch*.so' \ + --wheel-dir dist \ + dist-raw/*.whl + + - name: Verify wheel + run: | + echo "Built wheel:" + ls -lh dist/*.whl + ls dist/*.whl | grep -q "+${{ matrix.cuda }}" \ + || { echo "ERROR: wheel name missing +${{ matrix.cuda }} suffix"; exit 1; } + ls dist/*.whl | grep -q "manylinux_2_28_${{ matrix.arch }}" \ + || { echo "ERROR: wheel is not tagged manylinux_2_28_${{ matrix.arch }}"; exit 1; } + python3.12 -m auditwheel show dist/*.whl + + - name: Upload wheel artifact + uses: actions/upload-artifact@v6 + with: + name: wheel-${{ matrix.cuda }}-${{ matrix.arch }} + path: dist/*.whl + + release: + name: Create GitHub Release + needs: [build-wheel] + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/v') + permissions: + contents: write + steps: + - name: Download all artifacts + uses: actions/download-artifact@v6 + with: + path: artifacts/ + + - name: Create release + uses: softprops/action-gh-release@v3 + with: + files: | + artifacts/wheel-*/*.whl + generate_release_notes: true + draft: true + prerelease: ${{ contains(github.ref, 'rc') || contains(github.ref, 'beta') || contains(github.ref, 'alpha') }} diff --git a/README.md b/README.md index 7bed61e2..09418811 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,16 @@ cuLA supports both **Hopper (SM90)** and **Blackwell (SM10X)** GPUs. > **Note:** The PyTorch CUDA version must match your system CUDA Toolkit version. Check with `nvcc --version` and `python -c "import torch; print(torch.version.cuda)"`. +### Pre-built Wheels + +Pre-built fat-binary wheels (SM90 + SM100 + SM103) are available on [GitHub Releases](https://github.com/inclusionAI/cuLA/releases). Linux wheels target `manylinux_2_28` and require glibc 2.28 or newer: + + pip install "cuda-linear-attention==+" -f https://github.com/inclusionAI/cuLA/releases/expanded_assets/ + +Replace `` with the release tag (e.g., `v0.2.0`), `` with the base version (e.g., `0.2.0`), and `` with your PyTorch CUDA build tag (e.g., `cu129` or `cu130`). Or download the `.whl` file directly from the [Releases page](https://github.com/inclusionAI/cuLA/releases) and install it with `pip install .whl`. + +### Build from Source + **Clone cuLA & dependencies:** ```bash @@ -47,6 +57,12 @@ pip install -e third_party/flash-linear-attention pip install -e . --no-build-isolation ``` +**Build fat wheel (SM90 + SM100 + SM103):** + +```bash +CULA_BUILD_ALL_ARCHS=1 python -m build --wheel --no-isolation +``` + ## Quick Start ### KDA (Kimi Delta Attention) — Blackwell (SM10X) @@ -239,4 +255,4 @@ No CUDA experience is required as long as you're a quick learner. For Q&A and discussion, you can join us through: - **Slack:** [cuLA Slack Community](https://join.slack.com/t/cula-hq/shared_invite/zt-3uaacvm9y-xJwZyGueeKxZRYQlj7~hxw) -- **WeChat:** The WeChat group has exceeded 200 members and can no longer be joined via QR code. To join, please send your WeChat ID to any of the following emails and we'll invite you: **chaofanyu@gmail.com** / **kevinzz08@foxmail.com** / **yzpag@gmail.com** / **haoc80996@gmail.com**. You can also ask someone already in the group to invite you directly. \ No newline at end of file +- **WeChat:** The WeChat group has exceeded 200 members and can no longer be joined via QR code. To join, please send your WeChat ID to any of the following emails and we'll invite you: **chaofanyu@gmail.com** / **kevinzz08@foxmail.com** / **yzpag@gmail.com** / **haoc80996@gmail.com**. You can also ask someone already in the group to invite you directly. diff --git a/csrc/api/kda_sm100.cu b/csrc/api/kda_sm100.cu index 7edca370..020d90ca 100644 --- a/csrc/api/kda_sm100.cu +++ b/csrc/api/kda_sm100.cu @@ -188,4 +188,10 @@ ChunkKDAFwdRecompWU( StaticPersistentTileScheduler::Params{tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr}; kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream()); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "cuLA SM100/SM103 kernels"; + m.def("chunk_kda_fwd_intra_cuda", &ChunkKDAFwdIntra); + m.def("recompute_w_u_cuda", &ChunkKDAFwdRecompWU); } \ No newline at end of file diff --git a/csrc/api/kda_sm90.cu b/csrc/api/kda_sm90.cu index 9e016eb1..d80df7cc 100644 --- a/csrc/api/kda_sm90.cu +++ b/csrc/api/kda_sm90.cu @@ -191,3 +191,8 @@ kda_fwd_prefill( return {output, output_state}; } + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "cuLA SM90 kernels"; + m.def("kda_fwd_prefill", &kda_fwd_prefill); +} diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu deleted file mode 100644 index d14a41c5..00000000 --- a/csrc/api/pybind.cu +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2025-2026 Ant Group Co., Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) -void -ChunkKDAFwdIntra( - at::Tensor q, - at::Tensor k, - at::Tensor g, - at::Tensor beta, - at::Tensor cu_seqlens, - at::Tensor chunk_indices, - at::Tensor Aqk_out, - at::Tensor Akk_out, - at::Tensor tile_counter, - float scale, - int chunk_size, - bool use_tf32_inverse, - bool unified_gref); -void -ChunkKDAFwdRecompWU( - at::Tensor k, - at::Tensor v, - at::Tensor beta, - at::Tensor A, - at::Tensor g, - at::Tensor cu_seqlens, - at::Tensor chunk_indices, - at::Tensor w_out, - at::Tensor u_out, - at::Tensor kg_out, - int chunk_size, - std::optional q, - std::optional qg_out); -#endif - -#if defined(CULA_SM90A_ENABLED) -std::tuple> -kda_fwd_prefill( - std::optional output_, - std::optional output_state_, - torch::Tensor const& q, - torch::Tensor const& k, - torch::Tensor const& v, - std::optional input_state_, - std::optional alpha_, - std::optional beta_, - torch::Tensor const& cu_seqlens, - torch::Tensor workspace_buffer, - float scale, - bool output_final_state, - bool safe_gate); -#endif - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "cuLA"; -#if defined(CULA_SM100_ENABLED) || defined(CULA_SM103_ENABLED) - m.def("chunk_kda_fwd_intra_cuda", &ChunkKDAFwdIntra); - m.def("recompute_w_u_cuda", &ChunkKDAFwdRecompWU); -#endif -#if defined(CULA_SM90A_ENABLED) - m.def("kda_fwd_prefill", &kda_fwd_prefill); -#endif -} diff --git a/cula/__init__.py b/cula/__init__.py index 7272e289..6e13aa13 100644 --- a/cula/__init__.py +++ b/cula/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.0" +try: + from cula._version import version as __version__ +except ImportError: + __version__ = "0.1.0" from cula.ops.lightning_attn_sm100 import LinearAttentionChunkwiseDecay diff --git a/cula/cudac.py b/cula/cudac.py new file mode 100644 index 00000000..28fb5f38 --- /dev/null +++ b/cula/cudac.py @@ -0,0 +1,102 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified interface to per-architecture CUDA extensions. + +Downstream code can continue to use ``import cula.cudac as cula_cuda`` +and call ``cula_cuda.kda_fwd_prefill(...)`` or +``cula_cuda.chunk_kda_fwd_intra_cuda(...)`` without knowing which +extension provides the function. + +Loading is **once per process**: the first attribute access checks the +currently active CUDA device, imports the matching ``cula._cudac_sm*`` +extension, and caches the discovered callables on the module instance. +Changing the active CUDA device to a different architecture after a +process has already loaded ``cula.cudac`` will therefore not be picked +up -- callers that need a different extension must restart Python. +""" + +import importlib +import sys +import threading +from types import ModuleType + + +def _current_device_extension() -> tuple[str, str]: + try: + import torch + except ImportError as exc: + raise ImportError("cuLA CUDA extensions require PyTorch to detect the current GPU.") from exc + + if not torch.cuda.is_available(): + raise RuntimeError("cuLA CUDA extensions require a visible CUDA GPU, but torch.cuda.is_available() is False.") + + device = torch.cuda.current_device() + prop = torch.cuda.get_device_properties(device) + sm_label = f"sm_{prop.major}{prop.minor}" + if prop.major == 10 and prop.minor in (0, 3): + return "cula._cudac_sm100", sm_label + if prop.major == 9 and prop.minor == 0: + return "cula._cudac_sm90", sm_label + raise RuntimeError(f"Unsupported CUDA compute capability {sm_label}. Supported architectures: sm_100, sm_103, sm_90.") + + +class _CudacProxy(ModuleType): + """Lazy proxy that exposes functions from the current GPU arch extension.""" + + def __init__(self): + super().__init__(__name__) + self.__path__ = [] + self._modules_loaded = False + self._funcs: dict[str, object] = {} + self._lock = threading.Lock() + + def _load(self): + if self._modules_loaded: + return + with self._lock: + if self._modules_loaded: + return + ext_name, sm_label = _current_device_extension() + try: + mod = importlib.import_module(ext_name) + for attr in dir(mod): + if not attr.startswith("_"): + self._funcs[attr] = getattr(mod, attr) + except (ImportError, AttributeError, OSError) as exc: + raise ImportError( + f"The cuLA CUDA extension for the current GPU ({sm_label}) could not be imported. " + f"Extension {ext_name} failed with: {exc}. " + "Please make sure cuLA is compiled correctly." + ) from exc + self.__dict__.update(self._funcs) + self._modules_loaded = True + + def __getattr__(self, name: str): + if name.startswith("_"): + raise AttributeError(name) + self._load() + try: + return self._funcs[name] + except KeyError: + raise AttributeError(f"module 'cula.cudac' has no attribute '{name}'") from None + + def __dir__(self): + self._load() + return list(self._funcs.keys()) + + +_proxy = _CudacProxy() +_proxy.__dict__.update({k: globals().get(k) for k in ("__spec__", "__file__", "__package__", "__loader__")}) +sys.modules[__name__] = _proxy diff --git a/pyproject.toml b/pyproject.toml index ef1a531b..fe93e562 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,9 +84,6 @@ force-sort-within-sections = false "cula/kda/blackwell_fused_fwd.py" = ["F821"] [tool.setuptools_scm] -# write generated version into package for runtime access write_to = "cula/_version.py" -# add a date-based local suffix when needed -local_scheme = "node-and-date" -# fallback for non-git sources +local_scheme = "no-local-version" fallback_version = "0.1.0" diff --git a/scripts/build_wheel.sh b/scripts/build_wheel.sh index 42b35665..79ac3305 100755 --- a/scripts/build_wheel.sh +++ b/scripts/build_wheel.sh @@ -18,10 +18,19 @@ cd "$REPO_ROOT" # Parse args ISOLATION_FLAG="--no-isolation" -if [[ "${1:-}" == "--isolated" ]]; then - ISOLATION_FLAG="" - echo "[build_wheel] Using isolated build environment" -else +for arg in "$@"; do + case "$arg" in + --isolated) + ISOLATION_FLAG="" + echo "[build_wheel] Using isolated build environment" + ;; + --fat) + export CULA_BUILD_ALL_ARCHS=1 + echo "[build_wheel] Fat binary: building for all SM architectures" + ;; + esac +done +if [[ "$ISOLATION_FLAG" == "--no-isolation" ]]; then echo "[build_wheel] Using current environment (--no-isolation)" fi @@ -33,6 +42,7 @@ rm -rf dist build *.egg-info echo "[build_wheel] Python: $(python -V 2>&1)" echo "[build_wheel] torch: $(python -c 'import torch; print(torch.__version__)' 2>/dev/null || echo 'not installed')" echo "[build_wheel] CUDA: $(nvcc --version 2>/dev/null | grep 'release' | sed 's/.*release //' | sed 's/,.*//' || echo 'not found')" +echo "[build_wheel] Fat binary: ${CULA_BUILD_ALL_ARCHS:-0}" # Build wheel echo "[build_wheel] Building wheel..." diff --git a/setup.py b/setup.py index f7b11b95..78c61e5c 100644 --- a/setup.py +++ b/setup.py @@ -46,13 +46,15 @@ def detect_gpu_archs() -> tuple[bool, bool, bool]: def resolve_disable_flag(env_name: str, detected: bool) -> bool: """ Resolve whether to disable a given SM target. + - If CULA_BUILD_ALL_ARCHS is set, all targets are enabled unconditionally. - If the environment variable is explicitly set, honour it. - Otherwise, disable the target when no matching GPU is detected. """ + if os.getenv("CULA_BUILD_ALL_ARCHS", "0") == "1": + return False env_val = os.getenv(env_name) if env_val is not None: return env_val.lower() in ["true", "1", "y", "yes"] - # Auto-detect: disable if no matching device found disable = not detected if disable: print(f" No matching GPU detected; auto-setting {env_name}=1 (disable). Set {env_name}=0 to override.") @@ -66,7 +68,11 @@ def get_features_args(): USE_FAST_MATH = os.getenv("CULA_USE_FAST_MATH", "1") == "1" -print("Detecting GPU architectures...") +if os.getenv("CULA_BUILD_ALL_ARCHS", "0") == "1": + print("CULA_BUILD_ALL_ARCHS=1: enabling all SM targets (sm90a, sm100a, sm103a)") +else: + print("Detecting GPU architectures...") + _has_sm100, _has_sm103, _has_sm90 = detect_gpu_archs() DISABLE_SM100 = resolve_disable_flag("CULA_DISABLE_SM100", _has_sm100) DISABLE_SM103 = resolve_disable_flag("CULA_DISABLE_SM103", _has_sm103) @@ -111,26 +117,6 @@ def assert_blackwell_build_env() -> None: ) -def get_arch_flags(): - major, minor = get_nvcc_version() - print(f"Compiling using NVCC {major}.{minor}") - - # Validate Blackwell build environment - assert_blackwell_build_env() - - arch_flags = [] - if not DISABLE_SM100: - arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) - arch_flags.extend(["-DCULA_SM100_ENABLED"]) - if not DISABLE_SM103: - arch_flags.extend(["-gencode", "arch=compute_103a,code=sm_103a"]) - arch_flags.extend(["-DCULA_SM103_ENABLED"]) - if not DISABLE_SM90: - arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"]) - arch_flags.extend(["-DCULA_SM90A_ENABLED"]) - return arch_flags - - def get_nvcc_thread_args(): nvcc_threads = os.getenv("NVCC_THREADS") or "32" return ["--threads", nvcc_threads] @@ -145,61 +131,84 @@ def get_nvcc_thread_args(): else: cxx_args = ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"] -cuda_sources = [ - "csrc/api/pybind.cu", +nvcc_common_args = [ + "-O3", + "-std=c++20", + "-DNDEBUG", + # "-D_USE_MATH_DEFINES", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "-lineinfo", + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", + "-diag-suppress=3189", ] + +include_dirs = [ + Path(this_dir) / "csrc", + Path(this_dir) / "csrc" / "kerutils" / "include", + Path(this_dir) / "csrc" / "cutlass" / "include", + Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", +] + +major, minor = get_nvcc_version() +print(f"Compiling using NVCC {major}.{minor}") +assert_blackwell_build_env() + +ext_modules = [] + if not DISABLE_SM100 or not DISABLE_SM103: - cuda_sources.extend( - [ - "csrc/api/kda_sm100.cu", - "csrc/kda/sm100/kda_fwd_sm100.cu", - ] - ) -if not DISABLE_SM90: - cuda_sources.extend( - [ - "csrc/api/kda_sm90.cu", - "csrc/kda/sm90/kda_fwd_sm90.cu", - "csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu", - ] + sm100_arch_flags = [] + if not DISABLE_SM100: + sm100_arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"]) + if not DISABLE_SM103: + sm100_arch_flags.extend(["-gencode", "arch=compute_103a,code=sm_103a"]) + + ext_modules.append( + CUDAExtension( + name="cula._cudac_sm100", + sources=[ + "csrc/api/kda_sm100.cu", + "csrc/kda/sm100/kda_fwd_sm100.cu", + ], + extra_compile_args={ + "cxx": cxx_args + get_features_args(), + "nvcc": nvcc_common_args + + get_features_args() + + sm100_arch_flags + + get_nvcc_thread_args() + + (["--use_fast_math"] if USE_FAST_MATH else []), + }, + include_dirs=include_dirs, + ) ) -ext_modules = [] -ext_modules.append( - CUDAExtension( - name="cula.cudac", - sources=cuda_sources, - extra_compile_args={ - "cxx": cxx_args + get_features_args(), - "nvcc": [ - "-O3", - "-std=c++20", - "-DNDEBUG", - # "-D_USE_MATH_DEFINES", - "-Wno-deprecated-declarations", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "-lineinfo", - "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", - "-diag-suppress=3189", # suppress the warning of torch in C++ 20 - ] - + get_features_args() - + get_arch_flags() - + get_nvcc_thread_args() - + (["--use_fast_math"] if USE_FAST_MATH else []), - }, - include_dirs=[ - Path(this_dir) / "csrc", - Path(this_dir) / "csrc" / "kerutils" / "include", - Path(this_dir) / "csrc" / "cutlass" / "include", - Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include", - ], +if not DISABLE_SM90: + sm90_arch_flags = ["-gencode", "arch=compute_90a,code=sm_90a", "-DCULA_SM90A_ENABLED"] + + ext_modules.append( + CUDAExtension( + name="cula._cudac_sm90", + sources=[ + "csrc/api/kda_sm90.cu", + "csrc/kda/sm90/kda_fwd_sm90.cu", + "csrc/kda/sm90/kda_fwd_sm90_safe_gate.cu", + ], + extra_compile_args={ + "cxx": cxx_args + get_features_args(), + "nvcc": nvcc_common_args + + get_features_args() + + sm90_arch_flags + + get_nvcc_thread_args() + + (["--use_fast_math"] if USE_FAST_MATH else []), + }, + include_dirs=include_dirs, + ) ) -) setup( name="cuda-linear-attention", diff --git a/tests/conftest.py b/tests/conftest.py index f144c10b..a9338aca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import re + import pytest import torch @@ -56,9 +57,5 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_slow) continue callspec = getattr(item, "callspec", None) - if ( - callspec is not None - and callspec.params.get("disable_recompute") - and "kda_fast_norecomp" not in item.keywords - ): + if callspec is not None and callspec.params.get("disable_recompute") and "kda_fast_norecomp" not in item.keywords: item.add_marker(skip_fast_norecomp)