Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions .github/workflows/build-release.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Builds CUDA wheels for cuLA and, for version tags, stages a GitHub release.
name: Build & Release Wheels

# Triggers: pushes of tags matching "v*", plus manual runs via workflow_dispatch.
on:
push:
tags:
- "v*"
workflow_dispatch:

# Runs are grouped per git ref; starting a new run for the same ref cancels
# the one that is still in progress.
concurrency:
group: build-release-${{ github.ref }}
cancel-in-progress: true

jobs:
# Builds one cp312 wheel per (CUDA tag, CPU architecture) matrix entry inside
# an NVIDIA CUDA "devel" UBI 8 container and uploads it as a run artifact.
build-wheel:
name: "wheel / ${{ matrix.cuda }} / cp312 / ${{ matrix.arch }}"
# aarch64 entries use the ubuntu-24.04-arm runner; all others use ubuntu-latest.
runs-on: ${{ matrix.arch == 'aarch64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }}
defaults:
run:
shell: bash
# 2 CUDA tags x 2 architectures; fail-fast is disabled so one failing entry
# does not cancel the remaining ones.
strategy:
fail-fast: false
matrix:
cuda:
- cu129
- cu130
arch:
- x86_64
- aarch64
container:
# UBI 8 provides the glibc 2.28 baseline required by manylinux_2_28.
# cu129 selects the 12.9.0 image; any other matrix value selects 13.0.0.
image: "nvidia/cuda:${{ matrix.cuda == 'cu129' && '12.9.0' || '13.0.0' }}-devel-ubi8"

steps:
# Best-effort cleanup: every command tolerates failure.
# NOTE(review): this job runs inside a container -- confirm that the listed
# runner paths are actually visible from there, otherwise the rm is a no-op.
- name: Free disk space
run: |
rm -rf /opt/hostedtoolcache /usr/local/lib/android /usr/share/dotnet \
/usr/local/share/boost /opt/ghc 2>/dev/null || true
dnf clean all 2>/dev/null || true
df -h / || true

# Installs git, GCC toolset 13 (C and C++) and Python 3.12 with headers and pip.
- name: Install system dependencies
run: |
dnf install -y \
git \
gcc-toolset-13-gcc \
gcc-toolset-13-gcc-c++ \
python3.12 \
python3.12-devel \
python3.12-pip
dnf clean all

# Full history plus recursive submodules.
# NOTE(review): fetch-depth 0 presumably lets setuptools_scm see the tags in
# the "Compute version" step -- confirm.
- name: Checkout
uses: actions/checkout@v5
with:
fetch-depth: 0
submodules: recursive

# Registers the workspace as a git safe.directory for later git invocations.
- name: Configure git safe directory
run: git config --global --add safe.directory "$GITHUB_WORKSPACE"

# torch comes from the PyTorch wheel index that matches the matrix CUDA tag;
# the build and wheel-repair tooling is installed afterwards.
- name: Install Python dependencies
run: |
python3.12 -m pip install --no-cache-dir --upgrade pip
python3.12 -m pip install --no-cache-dir torch --index-url ${{ matrix.cuda == 'cu129' && 'https://download.pytorch.org/whl/cu129' || 'https://download.pytorch.org/whl/cu130' }}
python3.12 -m pip install --no-cache-dir setuptools wheel "setuptools_scm>=6.0" build ninja auditwheel patchelf

# Tag builds use the tag name without its "refs/tags/v" prefix; other refs
# take the setuptools_scm version with its local segment removed. The CUDA
# tag is then appended as the local version, e.g. "<base>+cu129".
- name: Compute version
id: version
run: |
if [[ "$GITHUB_REF" == refs/tags/v* ]]; then
BASE="${GITHUB_REF#refs/tags/v}"
else
# Strip any local segment (+gXXX) so we get a clean base
BASE=$(python3.12 -c "from setuptools_scm import get_version; print(get_version().split('+')[0])")
fi
echo "version=${BASE}+${{ matrix.cuda }}" >> "$GITHUB_OUTPUT"

# Builds the wheel into dist-raw/ with the GCC toolset 13 compilers and the
# version computed above (passed via SETUPTOOLS_SCM_PRETEND_VERSION).
# NOTE(review): CULA_BUILD_ALL_ARCHS, NVCC_THREADS and MAX_JOBS are consumed
# by the project's build scripts, which are not visible here.
- name: Build fat-binary wheel
env:
CC: /opt/rh/gcc-toolset-13/root/usr/bin/gcc
CXX: /opt/rh/gcc-toolset-13/root/usr/bin/g++
CUDAHOSTCXX: /opt/rh/gcc-toolset-13/root/usr/bin/g++
CULA_BUILD_ALL_ARCHS: "1"
SETUPTOOLS_SCM_PRETEND_VERSION: "${{ steps.version.outputs.version }}"
NVCC_THREADS: "4"
MAX_JOBS: "4"
run: |
"$CC" --version
"$CXX" --version
python3.12 -m build --wheel --no-isolation --outdir dist-raw

# auditwheel repairs the raw wheel for the manylinux_2_28_<arch> platform
# and writes the result to dist/.
- name: Repair wheel for manylinux_2_28
run: |
# These libraries are supplied by the NVIDIA driver, PyTorch, or
# PyTorch's CUDA runtime dependency and must remain external.
python3.12 -m auditwheel repair \
--plat manylinux_2_28_${{ matrix.arch }} \
--exclude libcuda.so.1 \
--exclude 'libcudart.so.*' \
--exclude 'libc10*.so' \
--exclude 'libtorch*.so' \
--wheel-dir dist \
dist-raw/*.whl

# Fails the job unless the wheel filename carries both the "+<cuda>" suffix
# and the manylinux_2_28_<arch> platform tag.
- name: Verify wheel
run: |
echo "Built wheel:"
ls -lh dist/*.whl
ls dist/*.whl | grep -q "+${{ matrix.cuda }}" \
|| { echo "ERROR: wheel name missing +${{ matrix.cuda }} suffix"; exit 1; }
ls dist/*.whl | grep -q "manylinux_2_28_${{ matrix.arch }}" \
|| { echo "ERROR: wheel is not tagged manylinux_2_28_${{ matrix.arch }}"; exit 1; }
python3.12 -m auditwheel show dist/*.whl

# One artifact per matrix entry, named wheel-<cuda>-<arch>.
- name: Upload wheel artifact
uses: actions/upload-artifact@v6
with:
name: wheel-${{ matrix.cuda }}-${{ matrix.arch }}
path: dist/*.whl

# Collects the wheels produced by build-wheel and attaches them to a draft
# GitHub release. Runs only when the ref starts with "refs/tags/v".
release:
name: Create GitHub Release
needs: [build-wheel]
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/v')
# Write access to repository contents is granted to this job only.
permissions:
contents: write
steps:
# No artifact name is given; the release step below expects each wheel
# under artifacts/wheel-<cuda>-<arch>/.
- name: Download all artifacts
uses: actions/download-artifact@v6
with:
path: artifacts/

# The release is created as a draft with generated notes. It is flagged as
# a prerelease when the ref contains "rc", "beta" or "alpha".
- name: Create release
uses: softprops/action-gh-release@v3
with:
files: |
artifacts/wheel-*/*.whl
generate_release_notes: true
draft: true
prerelease: ${{ contains(github.ref, 'rc') || contains(github.ref, 'beta') || contains(github.ref, 'alpha') }}
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ cuLA supports both **Hopper (SM90)** and **Blackwell (SM10X)** GPUs.

> **Note:** The PyTorch CUDA version must match your system CUDA Toolkit version. Check with `nvcc --version` and `python -c "import torch; print(torch.version.cuda)"`.

### Pre-built Wheels

Pre-built fat-binary wheels (SM90 + SM100 + SM103) are available on [GitHub Releases](https://github.com/inclusionAI/cuLA/releases). Linux wheels target `manylinux_2_28` and require glibc 2.28 or newer:

pip install "cuda-linear-attention==<VERSION>+<CUDA_TAG>" -f https://github.com/inclusionAI/cuLA/releases/expanded_assets/<TAG>

Replace `<TAG>` with the release tag (e.g., `v0.2.0`), `<VERSION>` with the base version (e.g., `0.2.0`), and `<CUDA_TAG>` with your PyTorch CUDA build tag (e.g., `cu129` or `cu130`). Or download the `.whl` file directly from the [Releases page](https://github.com/inclusionAI/cuLA/releases) and install it with `pip install <filename>.whl`.

### Build from Source

**Clone cuLA & dependencies:**

```bash
Expand All @@ -47,6 +57,12 @@ pip install -e third_party/flash-linear-attention
pip install -e . --no-build-isolation
```

**Build fat wheel (SM90 + SM100 + SM103):**

```bash
CULA_BUILD_ALL_ARCHS=1 python -m build --wheel --no-isolation
```

## Quick Start

### KDA (Kimi Delta Attention) — Blackwell (SM10X)
Expand Down Expand Up @@ -239,4 +255,4 @@ No CUDA experience is required as long as you're a quick learner.
For Q&A and discussion, you can join us through:

- **Slack:** [cuLA Slack Community](https://join.slack.com/t/cula-hq/shared_invite/zt-3uaacvm9y-xJwZyGueeKxZRYQlj7~hxw)
- **WeChat:** The WeChat group has exceeded 200 members and can no longer be joined via QR code. To join, please send your WeChat ID to any of the following emails and we'll invite you: **chaofanyu@gmail.com** / **kevinzz08@foxmail.com** / **yzpag@gmail.com** / **haoc80996@gmail.com**. You can also ask someone already in the group to invite you directly.
- **WeChat:** The WeChat group has exceeded 200 members and can no longer be joined via QR code. To join, please send your WeChat ID to any of the following emails and we'll invite you: **chaofanyu@gmail.com** / **kevinzz08@foxmail.com** / **yzpag@gmail.com** / **haoc80996@gmail.com**. You can also ask someone already in the group to invite you directly.
6 changes: 6 additions & 0 deletions csrc/api/kda_sm100.cu
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,10 @@ ChunkKDAFwdRecompWU(
StaticPersistentTileScheduler::Params{tile_num, params.h_v, params.heads_per_group, params.num_sm, nullptr};

kda::sm100::run_kda_fwd_recomp_w_u_sm100(params, at::cuda::getCurrentCUDAStream());
}

// Python bindings for the SM100/SM103 kernels defined in this file. The
// module name is taken from the TORCH_EXTENSION_NAME macro, which is defined
// outside this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Module-level docstring.
m.doc() = "cuLA SM100/SM103 kernels";
// Expose ChunkKDAFwdIntra to Python as "chunk_kda_fwd_intra_cuda".
m.def("chunk_kda_fwd_intra_cuda", &ChunkKDAFwdIntra);
// Expose ChunkKDAFwdRecompWU to Python as "recompute_w_u_cuda".
m.def("recompute_w_u_cuda", &ChunkKDAFwdRecompWU);
}
5 changes: 5 additions & 0 deletions csrc/api/kda_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,8 @@ kda_fwd_prefill(

return {output, output_state};
}

// Python bindings for the SM90 kernels defined in this file. The module name
// is taken from the TORCH_EXTENSION_NAME macro, which is defined outside this
// file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Module-level docstring.
m.doc() = "cuLA SM90 kernels";
// Expose kda_fwd_prefill to Python under the same name.
m.def("kda_fwd_prefill", &kda_fwd_prefill);
}
80 changes: 0 additions & 80 deletions csrc/api/pybind.cu

This file was deleted.

5 changes: 4 additions & 1 deletion cula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.1.0"
# Prefer the version exported by cula._version; if that module cannot be
# imported, fall back to a static version string.
# NOTE(review): cula._version is presumably generated at build time by
# setuptools_scm -- confirm against the project's build configuration.
try:
    from cula._version import version as __version__
except ImportError:
    __version__ = "0.1.0"

from cula.ops.lightning_attn_sm100 import LinearAttentionChunkwiseDecay

Expand Down
102 changes: 102 additions & 0 deletions cula/cudac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2025-2026 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified interface to per-architecture CUDA extensions.

Downstream code can continue to use ``import cula.cudac as cula_cuda``
and call ``cula_cuda.kda_fwd_prefill(...)`` or
``cula_cuda.chunk_kda_fwd_intra_cuda(...)`` without knowing which
extension provides the function.

Loading is **once per process**: the first attribute access checks the
currently active CUDA device, imports the matching ``cula._cudac_sm*``
extension, and caches the discovered callables on the module instance.
Changing the active CUDA device to a different architecture after a
process has already loaded ``cula.cudac`` will therefore not be picked
up -- callers that need a different extension must restart Python.
"""

import importlib
import sys
import threading
from types import ModuleType


def _current_device_extension() -> tuple[str, str]:
try:
import torch
except ImportError as exc:
raise ImportError("cuLA CUDA extensions require PyTorch to detect the current GPU.") from exc

if not torch.cuda.is_available():
raise RuntimeError("cuLA CUDA extensions require a visible CUDA GPU, but torch.cuda.is_available() is False.")

device = torch.cuda.current_device()
prop = torch.cuda.get_device_properties(device)
sm_label = f"sm_{prop.major}{prop.minor}"
if prop.major == 10 and prop.minor in (0, 3):
return "cula._cudac_sm100", sm_label
if prop.major == 9 and prop.minor == 0:
return "cula._cudac_sm90", sm_label
raise RuntimeError(f"Unsupported CUDA compute capability {sm_label}. Supported architectures: sm_100, sm_103, sm_90.")


class _CudacProxy(ModuleType):
"""Lazy proxy that exposes functions from the current GPU arch extension."""

def __init__(self):
super().__init__(__name__)
self.__path__ = []
self._modules_loaded = False
self._funcs: dict[str, object] = {}
self._lock = threading.Lock()

def _load(self):
if self._modules_loaded:
return
with self._lock:
if self._modules_loaded:
return
ext_name, sm_label = _current_device_extension()
try:
mod = importlib.import_module(ext_name)
for attr in dir(mod):
if not attr.startswith("_"):
self._funcs[attr] = getattr(mod, attr)
except (ImportError, AttributeError, OSError) as exc:
raise ImportError(
f"The cuLA CUDA extension for the current GPU ({sm_label}) could not be imported. "
f"Extension {ext_name} failed with: {exc}. "
"Please make sure cuLA is compiled correctly."
) from exc
self.__dict__.update(self._funcs)
self._modules_loaded = True

def __getattr__(self, name: str):
if name.startswith("_"):
raise AttributeError(name)
self._load()
try:
return self._funcs[name]
except KeyError:
raise AttributeError(f"module 'cula.cudac' has no attribute '{name}'") from None

def __dir__(self):
self._load()
return list(self._funcs.keys())


# Replace this module in sys.modules with a lazy proxy, so that attribute
# access on ``cula.cudac`` is resolved by ``_CudacProxy``.
_proxy = _CudacProxy()
# A freshly constructed ModuleType does not carry this module's metadata, and
# its ``__doc__`` is None. Copy the real values across; without ``__doc__``
# the module docstring would be invisible to help()/pydoc.
_proxy.__dict__.update(
    {k: globals().get(k) for k in ("__doc__", "__spec__", "__file__", "__package__", "__loader__")}
)
sys.modules[__name__] = _proxy
Comment thread
tongke6 marked this conversation as resolved.
Loading