diff --git a/.github/workflows/test-whl.yaml b/.github/workflows/test-whl.yaml index f62531e9..1b6a4850 100644 --- a/.github/workflows/test-whl.yaml +++ b/.github/workflows/test-whl.yaml @@ -59,6 +59,12 @@ jobs: docker exec flydsl_test bash -c "mkdir -p /flydsl/build-fly/bin && cp /dist/fly-opt /flydsl/build-fly/bin/fly-opt && chmod +x /flydsl/build-fly/bin/fly-opt" docker exec flydsl_test bash -c "git config --global --add safe.directory /flydsl" + - name: Install JAX ROCm (optional) + continue-on-error: true + run: | + docker exec flydsl_test bash -c "python3 -m pip install jax jaxlib jax-rocm7-pjrt jax-rocm7-plugin 2>&1 | tail -5" + docker exec flydsl_test bash -c "python3 -c 'import jax; print(\"JAX\", jax.__version__, \"backend:\", jax.default_backend())' 2>/dev/null || echo 'JAX installation failed (non-blocking)'" + - name: Run tests id: tests run: | diff --git a/examples/04-vectorAdd-jax.py b/examples/04-vectorAdd-jax.py new file mode 100644 index 00000000..16757499 --- /dev/null +++ b/examples/04-vectorAdd-jax.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Vector addition example using FlyDSL with JAX arrays. + +This is the JAX equivalent of ``01-vectorAdd.py``. It demonstrates both: + +- **Level 1** (eager): wrapping JAX arrays via ``from_jax`` and calling + a ``@flyc.jit`` function directly. +- **Level 2** (``jax.jit``): wrapping a ``@flyc.jit`` function with + ``jax_kernel`` so it can be called inside ``jax.jit``. 
+ +Requirements: + pip install jax[rocm] # ROCm backend for AMD GPUs +""" + +import sys + +try: + import jax + import jax.numpy as jnp +except ImportError: + print("SKIP: JAX not installed") + sys.exit(0) + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.jax import from_jax, jax_kernel + + +# ---------- Kernel definition (identical to 01-vectorAdd.py) ---------- + + +@flyc.kernel +def vectorAddKernel( + A: fx.Tensor, + B: fx.Tensor, + C: fx.Tensor, + block_dim: fx.Constexpr[int], +): + bid = fx.block_idx.x + tid = fx.thread_idx.x + fx.printf("[kernel] bid={}, tid={}", bid, tid) + + A = fx.rocdl.make_buffer_tensor(A) + + tA = fx.logical_divide(A, fx.make_layout(block_dim, 1)) + tB = fx.logical_divide(B, fx.make_layout(block_dim, 1)) + tC = fx.logical_divide(C, fx.make_layout(block_dim, 1)) + + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + tA = fx.logical_divide(tA, fx.make_layout(1, 1)) + tB = fx.logical_divide(tB, fx.make_layout(1, 1)) + tC = fx.logical_divide(tC, fx.make_layout(1, 1)) + + RABMemRefTy = fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + + copyAtom = fx.make_copy_atom(fx.UniversalCopy32b(), fx.Float32) + copyAtomBuffer = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), fx.Float32) + + rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(1, 1)) + + fx.copy_atom_call(copyAtomBuffer, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) + + vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB)) + fx.memref_store_vec(vC, rC) + + fx.copy_atom_call(copyAtom, rC, fx.slice(tC, (None, tid))) + + +# ---------- JIT launcher (identical to 01-vectorAdd.py) ---------- + + +@flyc.jit +def vectorAdd( + A: fx.Tensor, + B: fx.Tensor, + C, + n: fx.Int32, + const_n: fx.Constexpr[int], + stream: fx.Stream = 
fx.Stream(None), +): + block_dim = 64 + grid_x = (n + block_dim - 1) // block_dim + fx.printf("> vectorAdd: n={}, grid_x={}", n, grid_x) + + vectorAddKernel(A, B, C, block_dim).launch( + grid=(grid_x, 1, 1), block=[block_dim, 1, 1], stream=stream + ) + + +# ---------- JAX eager execution ---------- + + +def run_eager_jax(): + """Eager-mode execution with JAX arrays.""" + n = 128 + + # Create JAX arrays on the GPU. + key = jax.random.PRNGKey(42) + A = jax.random.randint(key, (n,), 0, 10).astype(jnp.float32) + B = jax.random.randint(jax.random.PRNGKey(7), (n,), 0, 10).astype(jnp.float32) + C = jnp.zeros(n, dtype=jnp.float32) + + # Wrap JAX arrays for FlyDSL. + tA = from_jax(A).mark_layout_dynamic(leading_dim=0, divisibility=4) + tB = from_jax(B) + tC = from_jax(C) + + # Ensure JAX computations are complete before launching the kernel. + jax.block_until_ready(A) + jax.block_until_ready(B) + + # Launch kernel (uses default HIP stream). + vectorAdd(tA, tB, tC, n, n + 1) + + # Synchronize and verify. + # Note: C was written to in-place on the GPU. We need to read it back. + # Since FlyDSL wrote to C's device buffer directly, the JAX array C + # still points to the same buffer. + expected = A + B + is_close = jnp.allclose(C, expected) + print(f"[JAX Eager] Result correct: {is_close}") + if not is_close: + print(" A:", A[:16]) + print(" B:", B[:16]) + print(" C:", C[:16]) + print(" expected:", expected[:16]) + return bool(is_close) + + +# ---------- Level 2: jax.jit integration via jax_kernel ---------- + + +# Wrap the @flyc.jit function so it can be used inside jax.jit. +# - out_shapes: tells JAX the shape and dtype of each output. +# - constexpr_kwargs: compile-time constants (Constexpr parameters). +# - runtime_scalars: non-tensor runtime args baked into the compiled kernel. +# The scalar 'n' is traced with value 128 during FlyDSL compilation. 
+vectorAdd_jax = jax_kernel( + vectorAdd, + out_shapes=lambda a, b: [ + (a.shape, a.dtype), # output C has same shape/dtype as A + ], + constexpr_kwargs={"const_n": 129}, + runtime_scalars={"n": 128}, +) + + +def run_jit_jax(): + """jax.jit-compiled execution with JAX arrays. + + The FlyDSL kernel is compiled once and registered as an XLA custom call. + Subsequent calls reuse the compiled kernel with zero Python overhead. + """ + n = 128 + + key = jax.random.PRNGKey(42) + A = jax.random.randint(key, (n,), 0, 10).astype(jnp.float32) + B = jax.random.randint(jax.random.PRNGKey(7), (n,), 0, 10).astype(jnp.float32) + + @jax.jit + def add_vectors(a, b): + # vectorAdd_jax receives only JAX arrays; scalar args (n) are baked + # into the compiled kernel via runtime_scalars. + (c,) = vectorAdd_jax(a, b) + return c + + C = add_vectors(A, B) + expected = A + B + is_close = jnp.allclose(C, expected) + print(f"[JAX jit] Result correct: {is_close}") + if not is_close: + print(" A:", A[:16]) + print(" B:", B[:16]) + print(" C:", C[:16]) + print(" expected:", expected[:16]) + return bool(is_close) + + +if __name__ == "__main__": + print("=" * 50) + print("Test 1: FlyDSL + JAX (Eager)") + print("=" * 50) + ok1 = run_eager_jax() + + print() + print("=" * 50) + print("Test 2: FlyDSL + JAX (jax.jit)") + print("=" * 50) + try: + ok2 = run_jit_jax() + except Exception as e: + print(f"[JAX jit] FAILED with exception: {e}") + ok2 = False + + print(f"\nAll passed: {ok1 and ok2}") diff --git a/examples/05-tiledCopy-jax.py b/examples/05-tiledCopy-jax.py new file mode 100644 index 00000000..838d8b02 --- /dev/null +++ b/examples/05-tiledCopy-jax.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Tiled copy example using FlyDSL with JAX arrays. + +JAX equivalent of ``02-tiledCopy.py``. Demonstrates tiled copy with +partitioned tensors using the layout algebra DSL, running on JAX arrays. 
+ +Requirements: + pip install jax[rocm] +""" + +import sys + +try: + import jax + import jax.numpy as jnp +except ImportError: + print("SKIP: JAX not installed") + sys.exit(0) + +import numpy as np + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.jax import from_jax, jax_kernel + + +# ---------- Kernel (identical to 02-tiledCopy.py) ---------- + + +@flyc.kernel +def copy_kernel( + A: fx.Tensor, + B: fx.Tensor, +): + tid = fx.thread_idx.x + bid = fx.block_idx.x + + block_m = 8 + block_n = 24 + tile = fx.make_tile([fx.make_layout(block_m, 1), fx.make_layout(block_n, 1)]) + + A = fx.rocdl.make_buffer_tensor(A) + B = fx.rocdl.make_buffer_tensor(B) + + bA = fx.zipped_divide(A, tile) + bB = fx.zipped_divide(B, tile) + bA = fx.slice(bA, (None, bid)) + bB = fx.slice(bB, (None, bid)) + + thr_layout = fx.make_layout((4, 1), (1, 1)) + val_layout = fx.make_layout((1, 8), (1, 1)) + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32) + layout_thr_val = fx.logical_product(thr_layout, val_layout) + layout_thr_val = fx.raked_product(thr_layout, val_layout) + + tile_mn = fx.make_tile(4, 8) + + tiled_copy = fx.make_tiled_copy(copy_atom, layout_thr_val, tile_mn) + thr_copy = tiled_copy.get_slice(tid) + + partition_src = thr_copy.partition_S(bA) + partition_dst = thr_copy.partition_D(bB) + + frag = fx.make_fragment_like(partition_src) + + fx.copy(copy_atom, partition_src, frag) + fx.copy(copy_atom, frag, partition_dst) + + +# ---------- JIT launcher ---------- + + +@flyc.jit +def tiledCopy( + A: fx.Tensor, + B: fx.Tensor, + stream: fx.Stream = fx.Stream(None), +): + copy_kernel(A, B).launch(grid=(15, 1, 1), block=(4, 1, 1), stream=stream) + + +# ---------- Eager ---------- + + +def run_eager(): + M, N = 8 * 3, 24 * 5 + A = jnp.arange(M * N, dtype=jnp.float32).reshape(M, N) + B = jnp.zeros((M, N), dtype=jnp.float32) + + tA = from_jax(A) + tB = from_jax(B) + + jax.block_until_ready(A) + tiledCopy(tA, tB) + + is_correct = 
np.allclose(np.asarray(A), np.asarray(B)) + print(f"[Eager] Result correct: {is_correct}") + if not is_correct: + print(" A[:2,:8]:", np.asarray(A)[:2, :8]) + print(" B[:2,:8]:", np.asarray(B)[:2, :8]) + return is_correct + + +# ---------- jax.jit ---------- + + +tiledCopy_jax = jax_kernel( + tiledCopy, + out_shapes=lambda a: [(a.shape, a.dtype)], +) + + +def run_jit(): + M, N = 8 * 3, 24 * 5 + A = jnp.arange(M * N, dtype=jnp.float32).reshape(M, N) + + @jax.jit + def f(a): + (b,) = tiledCopy_jax(a) + return b + + B = f(A) + + is_correct = np.allclose(np.asarray(A), np.asarray(B)) + print(f"[jax.jit] Result correct: {is_correct}") + if not is_correct: + print(" A[:2,:8]:", np.asarray(A)[:2, :8]) + print(" B[:2,:8]:", np.asarray(B)[:2, :8]) + return is_correct + + +if __name__ == "__main__": + print("=" * 50) + print("Test 1: Tiled Copy (Eager)") + print("=" * 50) + ok1 = run_eager() + + print() + print("=" * 50) + print("Test 2: Tiled Copy (jax.jit)") + print("=" * 50) + try: + ok2 = run_jit() + except Exception as e: + print(f"[jax.jit] FAILED: {e}") + ok2 = False + + print(f"\nAll passed: {ok1 and ok2}") diff --git a/examples/06-tiledMma-jax.py b/examples/06-tiledMma-jax.py new file mode 100644 index 00000000..57aca584 --- /dev/null +++ b/examples/06-tiledMma-jax.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Tiled MMA (matrix multiply accumulate) example using FlyDSL with JAX. + +JAX equivalent of ``03-tiledMma.py``. Demonstrates a single-tile GEMM +using MFMA instructions on AMD GPUs, running on JAX arrays. 
+ +Requirements: + pip install jax[rocm] +""" + +import sys + +try: + import jax + import jax.numpy as jnp +except ImportError: + print("SKIP: JAX not installed") + sys.exit(0) + +import numpy as np + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.jax import from_jax, jax_kernel + +block_m = 64 +block_n = 64 +block_k = 8 + + +# ---------- Kernel (identical to 03-tiledMma.py) ---------- + + +@flyc.kernel +def gemm_kernel( + A: fx.Tensor, + B: fx.Tensor, + C: fx.Tensor, +): + tid = fx.thread_idx.x + bid = fx.block_idx.x + + tileA = fx.make_tile(block_m, block_k) + tileB = fx.make_tile(block_n, block_k) + tileC = fx.make_tile(block_m, block_n) + + A = fx.rocdl.make_buffer_tensor(A) + B = fx.rocdl.make_buffer_tensor(B) + C = fx.rocdl.make_buffer_tensor(C) + + bA = fx.zipped_divide(A, tileA) + bB = fx.zipped_divide(B, tileB) + bC = fx.zipped_divide(C, tileC) + + bA = fx.slice(bA, (None, bid)) + bB = fx.slice(bB, (None, bid)) + bC = fx.slice(bC, (None, bid)) + + mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 4, fx.Float32)) + tiled_mma = fx.make_tiled_mma(mma_atom, fx.make_layout((2, 2, 1), (1, 2, 0))) + thr_mma = tiled_mma.thr_slice(tid) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), fx.Float32) + tiled_copy_A = fx.make_tiled_copy_A(copy_atom, tiled_mma) + tiled_copy_B = fx.make_tiled_copy_B(copy_atom, tiled_mma) + tiled_copy_C = fx.make_tiled_copy_C(copy_atom, tiled_mma) + + thr_copy_A = tiled_copy_A.get_slice(tid) + thr_copy_B = tiled_copy_B.get_slice(tid) + thr_copy_C = tiled_copy_C.get_slice(tid) + + copy_src_A = thr_copy_A.partition_S(bA) + copy_src_B = thr_copy_B.partition_S(bB) + copy_dst_C = thr_copy_C.partition_S(bC) + + partition_A = thr_mma.partition_A(bA) + partition_B = thr_mma.partition_B(bB) + partition_C = thr_mma.partition_C(bC) + + frag_A = thr_mma.make_fragment_A(partition_A) + frag_B = thr_mma.make_fragment_B(partition_B) + frag_C = thr_mma.make_fragment_C(partition_C) + + copy_frag_A = 
thr_copy_A.retile(frag_A) + copy_frag_B = thr_copy_B.retile(frag_B) + copy_frag_C = thr_copy_C.retile(frag_C) + + fx.copy(copy_atom, copy_src_A, copy_frag_A, pred=None) + fx.copy(copy_atom, copy_src_B, copy_frag_B, pred=None) + + fx.gemm(mma_atom, frag_C, frag_A, frag_B, frag_C) + + fx.copy(copy_atom, copy_frag_C, copy_dst_C, pred=None) + + +# ---------- JIT launcher ---------- + + +@flyc.jit +def tiledMma( + A: fx.Tensor, + B: fx.Tensor, + C: fx.Tensor, + stream: fx.Stream = fx.Stream(None), +): + gemm_kernel(A, B, C).launch(grid=(1, 1, 1), block=(256, 1, 1), stream=stream) + + +# ---------- Eager ---------- + + +def run_eager(): + M, N, K = block_m, block_n, block_k + A = jax.random.normal(jax.random.PRNGKey(0), (M, K), dtype=jnp.float32) + B = jax.random.normal(jax.random.PRNGKey(1), (N, K), dtype=jnp.float32) + C = jnp.zeros((M, N), dtype=jnp.float32) + + tA = from_jax(A) + tB = from_jax(B) + tC = from_jax(C) + + jax.block_until_ready(A) + jax.block_until_ready(B) + tiledMma(tA, tB, tC) + + expected = np.asarray(A) @ np.asarray(B).T + result = np.asarray(C) + max_diff = np.max(np.abs(result - expected)) + is_correct = max_diff < 1e-5 + print(f"[Eager] Result correct: {is_correct} (max diff: {max_diff:.2e})") + if not is_correct: + print(" Expected[:2,:4]:", expected[:2, :4]) + print(" Got[:2,:4]: ", result[:2, :4]) + return is_correct + + +# ---------- jax.jit ---------- + + +tiledMma_jax = jax_kernel( + tiledMma, + out_shapes=lambda a, b: [((a.shape[0], b.shape[0]), jnp.float32)], +) + + +def run_jit(): + M, N, K = block_m, block_n, block_k + A = jax.random.normal(jax.random.PRNGKey(0), (M, K), dtype=jnp.float32) + B = jax.random.normal(jax.random.PRNGKey(1), (N, K), dtype=jnp.float32) + + @jax.jit + def f(a, b): + (c,) = tiledMma_jax(a, b) + return c + + C = f(A, B) + + expected = np.asarray(A) @ np.asarray(B).T + result = np.asarray(C) + max_diff = np.max(np.abs(result - expected)) + is_correct = max_diff < 1e-5 + print(f"[jax.jit] Result correct: 
{is_correct} (max diff: {max_diff:.2e})") + if not is_correct: + print(" Expected[:2,:4]:", expected[:2, :4]) + print(" Got[:2,:4]: ", result[:2, :4]) + return is_correct + + +if __name__ == "__main__": + print("=" * 50) + print("Test 1: Tiled MMA (Eager)") + print("=" * 50) + ok1 = run_eager() + + print() + print("=" * 50) + print("Test 2: Tiled MMA (jax.jit)") + print("=" * 50) + try: + ok2 = run_jit() + except Exception as e: + print(f"[jax.jit] FAILED: {e}") + ok2 = False + + print(f"\nAll passed: {ok1 and ok2}") diff --git a/python/flydsl/compiler/__init__.py b/python/flydsl/compiler/__init__.py index eef90c97..87d1622e 100644 --- a/python/flydsl/compiler/__init__.py +++ b/python/flydsl/compiler/__init__.py @@ -17,3 +17,13 @@ "kernel", "register_backend", ] + + +def from_jax(array, *, assumed_align=None, use_32bit_stride=False): + """Convenience re-export of :func:`flydsl.jax.from_jax`. + + Available only when JAX is installed. + """ + from ..jax.adapter import from_jax as _from_jax + + return _from_jax(array, assumed_align=assumed_align, use_32bit_stride=use_32bit_stride) diff --git a/python/flydsl/compiler/jit_argument.py b/python/flydsl/compiler/jit_argument.py index fc861388..d96ba750 100644 --- a/python/flydsl/compiler/jit_argument.py +++ b/python/flydsl/compiler/jit_argument.py @@ -3,28 +3,13 @@ import ctypes import inspect -from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple, Type, get_origin -import torch - from .._mlir._mlir_libs._fly import DLTensorAdaptor from ..expr.typing import Constexpr, Int32, Stream, Tensor from .protocol import DslType, JitArgument -_FLOAT8_DTYPES = tuple( - dt - for dt in ( - getattr(torch, "float8_e4m3fn", None), - getattr(torch, "float8_e5m2", None), - getattr(torch, "float8_e4m3fnuz", None), - getattr(torch, "float8_e5m2fnuz", None), - ) - if dt is not None -) - - class JitArgumentRegistry: registry: Dict[type, Tuple[Callable, Type[DslType]]] = {} jit_arg2dsl_type: Dict[type, 
Type[DslType]] = {} @@ -135,84 +120,117 @@ def convert_to_jit_arguments( return param_names, jit_args, dsl_types, constexpr_values -# ================================ Common useful JitArguments ================================ +# ================================ Framework-agnostic registrations ================================ +JitArgumentRegistry.register(int)(Int32) -@JitArgumentRegistry.register(torch.Tensor, dsl_type=Tensor) -class TensorAdaptor: - def __init__( - self, - tensor: torch.Tensor, - assumed_align: Optional[int] = None, - use_32bit_stride: bool = False, - ): - self._tensor_keepalive = tensor - dlpack_tensor = tensor - if _FLOAT8_DTYPES and tensor.dtype in _FLOAT8_DTYPES: - dlpack_tensor = tensor.view(torch.uint8) - self._tensor_keepalive = dlpack_tensor - self.tensor_adaptor = DLTensorAdaptor(dlpack_tensor.__dlpack__(stream=-1), assumed_align, use_32bit_stride) - self.assumed_align = assumed_align - self.use_32bit_stride = use_32bit_stride - self._orig_dtype = tensor.dtype - self._orig_shape = tensor.shape - self._orig_strides = tensor.stride() +# ================================ PyTorch support (optional) ================================ - @staticmethod - def _extract_data_ptr(arg): - return arg.data_ptr() +try: + import torch - @classmethod - def _reusable_slot_spec(cls, arg): - """Reusable slot for tensor arguments. - - For bare-pointer calling convention, only the data pointer changes - between calls with the same shape/dtype/strides. 
- """ - if not hasattr(arg, 'data_ptr'): - return None - return ctypes.c_void_p, cls._extract_data_ptr - - def requires_memref_desc(func): - def wrapper(self, *args, **kwargs): - self.tensor_adaptor.build_memref_desc() - return func(self, *args, **kwargs) - - return wrapper - - @requires_memref_desc - def __fly_types__(self): - return [self.tensor_adaptor.get_memref_type()] - - @requires_memref_desc - def __fly_ptrs__(self): - return self.tensor_adaptor.get_c_pointers() - - @staticmethod - def raw_cache_signature(tensor: torch.Tensor): - """Lightweight cache sig from a raw tensor, no DLPack overhead.""" - return (tensor.dtype,) - - def __cache_signature__(self): - return ( - self._orig_dtype, - self.assumed_align, - self.use_32bit_stride, + _FLOAT8_DTYPES = tuple( + dt + for dt in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e5m2", None), + getattr(torch, "float8_e4m3fnuz", None), + getattr(torch, "float8_e5m2fnuz", None), ) + if dt is not None + ) - def mark_layout_dynamic(self, leading_dim: Optional[int] = None, divisibility: int = 1): - if leading_dim is None: - leading_dim = -1 - self.tensor_adaptor.mark_layout_dynamic(leading_dim, divisibility) - return self - - -def from_dlpack( - tensor: torch.Tensor, *, assumed_align: Optional[int] = None, use_32bit_stride: bool = False -) -> TensorAdaptor: - return TensorAdaptor(tensor, assumed_align, use_32bit_stride) - - -JitArgumentRegistry.register(int)(Int32) -JitArgumentRegistry.register(torch.cuda.Stream)(Stream) + @JitArgumentRegistry.register(torch.Tensor, dsl_type=Tensor) + class TensorAdaptor: + def __init__( + self, + tensor: torch.Tensor, + assumed_align: Optional[int] = None, + use_32bit_stride: bool = False, + ): + self._tensor_keepalive = tensor + dlpack_tensor = tensor + if _FLOAT8_DTYPES and tensor.dtype in _FLOAT8_DTYPES: + dlpack_tensor = tensor.view(torch.uint8) + self._tensor_keepalive = dlpack_tensor + + self.tensor_adaptor = DLTensorAdaptor(dlpack_tensor.__dlpack__(stream=-1), 
assumed_align, use_32bit_stride) + self.assumed_align = assumed_align + self.use_32bit_stride = use_32bit_stride + self._orig_dtype = tensor.dtype + self._orig_shape = tensor.shape + self._orig_strides = tensor.stride() + self._dynamic_leading_dim = None + self._dynamic_divisibility = None + + @staticmethod + def _extract_data_ptr(arg): + return arg.data_ptr() + + @classmethod + def _reusable_slot_spec(cls, arg): + """Reusable slot for tensor arguments. + + For bare-pointer calling convention, only the data pointer changes + between calls with the same shape/dtype/strides. + """ + if not hasattr(arg, 'data_ptr'): + return None + return ctypes.c_void_p, cls._extract_data_ptr + + def requires_memref_desc(func): + def wrapper(self, *args, **kwargs): + self.tensor_adaptor.build_memref_desc() + return func(self, *args, **kwargs) + + return wrapper + + @requires_memref_desc + def __fly_types__(self): + return [self.tensor_adaptor.get_memref_type()] + + @requires_memref_desc + def __fly_ptrs__(self): + return self.tensor_adaptor.get_c_pointers() + + @staticmethod + def raw_cache_signature(tensor: torch.Tensor): + """Lightweight cache sig from a raw tensor, no DLPack overhead.""" + return (tensor.dtype,) + + def __cache_signature__(self): + return ( + self._orig_dtype, + self.assumed_align, + self.use_32bit_stride, + self._dynamic_leading_dim, + self._dynamic_divisibility, + ) + + def mark_layout_dynamic(self, leading_dim: Optional[int] = None, divisibility: int = 1): + if leading_dim is None: + leading_dim = -1 + self._dynamic_leading_dim = leading_dim + self._dynamic_divisibility = divisibility + self.tensor_adaptor.mark_layout_dynamic(leading_dim, divisibility) + return self + + def from_dlpack( + tensor: torch.Tensor, *, assumed_align: Optional[int] = None, use_32bit_stride: bool = False + ) -> TensorAdaptor: + return TensorAdaptor(tensor, assumed_align, use_32bit_stride) + + JitArgumentRegistry.register(torch.cuda.Stream)(Stream) + +except ImportError: + # torch is 
not installed — PyTorch tensor support is unavailable. + # JAX arrays can still be used via flydsl.jax. + TensorAdaptor = None + + def from_dlpack(tensor, *, assumed_align=None, use_32bit_stride=False): + raise ImportError( + "from_dlpack requires PyTorch. Install it with:\n" + " pip install torch\n" + "For JAX arrays, use flydsl.jax.from_jax instead." + ) diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index 1be37d85..ce3c6bb2 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -601,6 +601,19 @@ def __init__(self, func: Callable): self._sig = None # lazy: set on first call self._mem_cache = {} + def get_last_artifact(self) -> Optional["CompiledArtifact"]: + """Return the most recently compiled artifact, or None. + + Used by external integrations (e.g. ``flydsl.jax``) to retrieve + compiled kernels for registration with framework-specific runtimes. + """ + artifact = getattr(self, "_last_compiled", None) + if artifact is None and self._mem_cache: + candidate = next(reversed(self._mem_cache.values())) + if isinstance(candidate, CompiledArtifact): + artifact = candidate + return artifact + def _ensure_sig(self): """Initialize signature + param metadata on first call (not at decoration time).""" if self._sig is not None: @@ -715,6 +728,11 @@ def __call__(self, *args, **kwargs): self._mem_cache[cache_key] = cached_func if cached_func is not None: + # Keep _last_compiled up to date so external integrations + # (e.g. flydsl.jax) can retrieve the artifact for this + # compilation even when the fast path is taken. + self._last_compiled = cached_func + # Build CallState via JitArgument registry (same dispatch as compile path) try: state = _build_call_state( @@ -788,6 +806,11 @@ def __call__(self, *args, **kwargs): original_ir, ) + # Always keep a reference to the last compiled artifact so that + # external tools (e.g. 
flydsl.jax) can retrieve it even when + # caching is disabled via FLYDSL_RUNTIME_ENABLE_CACHE=0. + self._last_compiled = compiled_func + if use_cache: self._mem_cache[cache_key] = compiled_func if self.cache_manager and not env.debug.dump_ir: diff --git a/python/flydsl/jax/__init__.py b/python/flydsl/jax/__init__.py new file mode 100644 index 00000000..5f1fdfd4 --- /dev/null +++ b/python/flydsl/jax/__init__.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""FlyDSL JAX integration. + +Provides two levels of integration: + +Level 1 — Eager mode (``from_jax``): + Wrap JAX arrays as FlyDSL JitArguments so they can be passed directly to + ``@flyc.jit`` functions. Requires ``jax.block_until_ready()`` for + synchronization. + +Level 2 — ``jax.jit`` integration (``jax_kernel``): + Register compiled FlyDSL kernels as JAX primitives via the XLA FFI so they + compose with ``jax.jit``. + +Usage (eager):: + + import jax.numpy as jnp + import flydsl.compiler as flyc + from flydsl.jax import from_jax + + a = jnp.ones(1024, dtype=jnp.float32) + ta = from_jax(a) + my_jit_func(ta, ...) + +Usage (jax.jit):: + + from flydsl.jax import jax_kernel + + wrapped = jax_kernel( + my_flyc_jit_func, + out_shapes=lambda a, b: [(a.shape, a.dtype)], + ) + + @jax.jit + def f(a, b): + (c,) = wrapped(a, b) + return c +""" + +from flydsl.jax.adapter import JaxTensorAdaptor, from_jax + +# Lazy imports: these modules depend on torch (via flydsl.compiler) and +# JAX internal MLIR dialects, so we defer them to avoid hard import-time +# failures in environments where torch is not installed. 
+ + +def compile_and_register(*args, **kwargs): + """Lazy wrapper for :func:`flydsl.jax.ffi_bridge.compile_and_register`.""" + from flydsl.jax.ffi_bridge import compile_and_register as _car + + return _car(*args, **kwargs) + + +def jax_kernel(*args, **kwargs): + """Lazy wrapper for :func:`flydsl.jax.primitive.jax_kernel`.""" + from flydsl.jax.primitive import jax_kernel as _jk + + return _jk(*args, **kwargs) + + +__all__ = [ + "compile_and_register", + "from_jax", + "JaxTensorAdaptor", + "jax_kernel", +] diff --git a/python/flydsl/jax/_xla_bridge.c b/python/flydsl/jax/_xla_bridge.c new file mode 100644 index 00000000..07131fd3 --- /dev/null +++ b/python/flydsl/jax/_xla_bridge.c @@ -0,0 +1,110 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright (c) 2025 FlyDSL Project Contributors + * + * Thin C trampoline that bridges XLA's GPU custom-call convention to + * FlyDSL's bare-pointer convention. + * + * XLA GPU custom call (API_VERSION_STATUS_RETURNING_UNIFIED): + * void fn(hipStream_t stream, void** buffers, const char* opaque, size_t opaque_len) + * + * FlyDSL bare-pointer convention: + * void fn(void** ptrs) + * where ptrs[i] = &storage[i], storage[i] = device_ptr or stream value + * + * Compiled at import time via: cc -shared -fPIC -o _xla_bridge.so _xla_bridge.c -lpthread + */ + +#include +#include +#include +#include + +#define MAX_BUFFERS 64 +#define MAX_SCALARS 16 +#define MAX_TARGETS 256 + +typedef void (*flydsl_func_t)(void **ptrs); + +typedef struct { + flydsl_func_t func; + int n_buffers; + int n_scalars; + int64_t scalar_vals[MAX_SCALARS]; +} target_slot_t; + +static target_slot_t g_targets[MAX_TARGETS]; +static int g_n_targets = 0; +static pthread_mutex_t g_lock = PTHREAD_MUTEX_INITIALIZER; + +int flydsl_xla_register(void *func_ptr, int n_buffers, int n_scalars, + int64_t *scalar_vals) { + if (n_buffers > MAX_BUFFERS || n_scalars > MAX_SCALARS) + return -1; + + pthread_mutex_lock(&g_lock); + if (g_n_targets >= MAX_TARGETS) { + 
pthread_mutex_unlock(&g_lock); + return -1; + } + int idx = g_n_targets++; + /* Populate slot while holding lock so concurrent dispatchers never + observe an allocated index with uninitialized data. */ + g_targets[idx].func = (flydsl_func_t)func_ptr; + g_targets[idx].n_buffers = n_buffers; + g_targets[idx].n_scalars = n_scalars; + for (int i = 0; i < n_scalars; i++) + g_targets[idx].scalar_vals[i] = scalar_vals ? scalar_vals[i] : 0; + pthread_mutex_unlock(&g_lock); + + return idx; +} + +static void xla_bridge_dispatch(void *stream, void **buffers, + const char *opaque, size_t opaque_len) { + /* Validate opaque: must be exactly sizeof(int) carrying the slot index. */ + if (opaque == NULL || opaque_len != sizeof(int)) + return; + + int idx; + memcpy(&idx, opaque, sizeof(int)); + + if (idx < 0 || idx >= g_n_targets) + return; + + target_slot_t *t = &g_targets[idx]; + int nb = t->n_buffers; + int ns = t->n_scalars; + + /* Build FlyDSL's ptrs array on the stack. + * Layout: [buf0, buf1, ..., scalar0, scalar1, ..., stream] + * + * For scalars, FlyDSL reads through ptrs[i] -> &storage[i] as a + * pointer-sized value. We store the int64_t in a dedicated array + * and point into it, avoiding integer-to-pointer casts. + */ + void *storage[MAX_BUFFERS + 1]; + int64_t scalar_storage[MAX_SCALARS]; + void *packed[MAX_BUFFERS + MAX_SCALARS + 1]; + + /* Tensor buffers. */ + for (int i = 0; i < nb; i++) { + storage[i] = buffers[i]; + packed[i] = &storage[i]; + } + /* Scalar values (from the registered slot, not from opaque). */ + for (int i = 0; i < ns; i++) { + scalar_storage[i] = t->scalar_vals[i]; + packed[nb + i] = &scalar_storage[i]; + } + /* Stream in the last slot. 
*/ + storage[nb] = stream; + packed[nb + ns] = &storage[nb]; + + t->func(packed); +} + +void *flydsl_xla_get_bridge(int idx) { + (void)idx; + return (void *)&xla_bridge_dispatch; +} diff --git a/python/flydsl/jax/adapter.py b/python/flydsl/jax/adapter.py new file mode 100644 index 00000000..d687d6f3 --- /dev/null +++ b/python/flydsl/jax/adapter.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""JAX array adapter for FlyDSL's JitArgument protocol. + +Wraps ``jax.Array`` objects via DLPack so they can be passed to +``@flyc.jit`` functions in eager mode. +""" + +import ctypes +import threading +from typing import Optional + +try: + import jax + import jax.numpy as jnp +except ImportError as exc: + raise ImportError( + "JAX is required for flydsl.jax. Install it with:\n" + " pip install jax[rocm] # or jax[cuda12] for NVIDIA GPUs" + ) from exc + +from .._mlir._mlir_libs._fly import DLTensorAdaptor + + +# JAX float8 dtypes that need uint8 view treatment (analogous to PyTorch float8). +_JAX_FLOAT8_DTYPES = tuple( + dt + for dt in ( + getattr(jnp, "float8_e3m4", None), + getattr(jnp, "float8_e4m3", None), + getattr(jnp, "float8_e4m3fn", None), + getattr(jnp, "float8_e4m3fnuz", None), + getattr(jnp, "float8_e4m3b11fnuz", None), + getattr(jnp, "float8_e5m2", None), + getattr(jnp, "float8_e5m2fnuz", None), + getattr(jnp, "float8_e8m0fnu", None), + ) + if dt is not None +) + +# Lazy registration flag — we register JaxTensorAdaptor with +# JitArgumentRegistry only when torch (and therefore jit_argument.py) +# is actually importable. This avoids a hard torch dependency. +_registered = False +_register_lock = threading.Lock() + + +def _ensure_registered(): + """Register JaxTensorAdaptor with FlyDSL's JitArgumentRegistry. + + Called lazily so that ``import flydsl.jax`` works even when + torch is not installed. Registration is only needed for the + eager path (``@flyc.jit`` calls with JAX arrays). 
+ """ + global _registered + if _registered: + return + with _register_lock: + if _registered: + return + try: + from ..compiler.jit_argument import JitArgumentRegistry + from ..expr.typing import Tensor + + JitArgumentRegistry.register(jax.Array, dsl_type=Tensor)(JaxTensorAdaptor) + except ImportError: + pass # torch not available — eager path won't work, but jax.jit path can + _registered = True + + +class JaxTensorAdaptor: + """Adapt a ``jax.Array`` to FlyDSL's JitArgument protocol via DLPack. + + Parameters + ---------- + array : jax.Array + A JAX array on a GPU device. + assumed_align : int, optional + Override pointer alignment assumption (bytes). + use_32bit_stride : bool + Use 32-bit strides in the MLIR memref descriptor. + + Notes + ----- + - The array must reside on a single GPU device (no sharded arrays). + - JAX does not expose HIP streams directly. The caller is responsible + for synchronization (``jax.block_until_ready`` before launch, and + device synchronization after). + - Float8 arrays are handled by extracting DLPack from a uint8 view. + """ + + def __init__( + self, + array: "jax.Array", + assumed_align: Optional[int] = None, + use_32bit_stride: bool = False, + ): + if not isinstance(array, jax.Array): + raise TypeError(f"Expected jax.Array, got {type(array).__name__}") + + # Ensure the array is materialized and on a single device. + array = jax.device_put(array) + if callable(getattr(array, "is_deleted", None)) and array.is_deleted(): + raise ValueError("Cannot adapt a deleted JAX array") + + self._array_keepalive = array + self._orig_dtype = array.dtype + self._orig_shape = array.shape + self._orig_strides = _jax_strides(array) + self.assumed_align = assumed_align + self.use_32bit_stride = use_32bit_stride + self._dynamic_leading_dim = None + self._dynamic_divisibility = None + + # Float8 arrays: extract DLPack from a uint8 view. 
        # Hand the array to DLPack as-is; float8 arrays are re-viewed as
        # uint8 first (DLPack consumers may not understand float8 codes).
        dlpack_array = array
        if _JAX_FLOAT8_DTYPES and array.dtype in _JAX_FLOAT8_DTYPES:
            dlpack_array = array.view(jnp.uint8)
            # Keep the uint8 view (not the original) alive; presumably the
            # view retains the underlying device buffer — TODO confirm.
            self._array_keepalive = dlpack_array

        # Extract DLPack capsule.
        # JAX __dlpack__(stream=) semantics: stream=0 means "may be used on any
        # stream" (JAX default stream). We pass 0 rather than a specific stream.
        dl_capsule = dlpack_array.__dlpack__(stream=0)
        self.tensor_adaptor = DLTensorAdaptor(dl_capsule, assumed_align, use_32bit_stride)

    # ------------------------------------------------------------------
    # JitArgument protocol
    # ------------------------------------------------------------------

    def _build_desc(func):
        """Decorator: ensure memref descriptor is built before access.

        NOTE(review): defined as a plain function inside the class body and
        used as a decorator below — it is not a method, and ``functools.wraps``
        is not applied, so the wrapped methods lose their original
        ``__name__``/``__doc__``.
        """

        def wrapper(self, *args, **kwargs):
            self.tensor_adaptor.build_memref_desc()
            return func(self, *args, **kwargs)

        return wrapper

    @_build_desc
    def __fly_types__(self):
        # MLIR memref type(s) describing this tensor argument.
        return [self.tensor_adaptor.get_memref_type()]

    @_build_desc
    def __fly_ptrs__(self):
        # C-level pointers handed to the compiled kernel at launch time.
        return self.tensor_adaptor.get_c_pointers()

    # ------------------------------------------------------------------
    # Cache signature (same structure as TensorAdaptor)
    # ------------------------------------------------------------------

    def __cache_signature__(self):
        # Shape is intentionally absent: shape specialization is handled
        # by JitFunction, so only dtype/alignment/layout knobs participate.
        return (
            self._orig_dtype,
            self.assumed_align,
            self.use_32bit_stride,
            self._dynamic_leading_dim,
            self._dynamic_divisibility,
        )

    # ------------------------------------------------------------------
    # Fast-path reuse
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_data_ptr(arg):
        """Extract the raw device pointer from a jax.Array."""
        # For single-device arrays, unsafe_buffer_pointer gives the device ptr.
+ buf = arg.addressable_data(0) + return buf.unsafe_buffer_pointer() + + @classmethod + def _reusable_slot_spec(cls, arg): + if not isinstance(arg, jax.Array): + return None + return ctypes.c_void_p, cls._extract_data_ptr + + # ------------------------------------------------------------------ + # Layout dynamism (mirrors TensorAdaptor) + # ------------------------------------------------------------------ + + def mark_layout_dynamic(self, leading_dim: Optional[int] = None, divisibility: int = 1): + """Mark dimensions as dynamic for shape-polymorphic compilation.""" + if leading_dim is None: + leading_dim = -1 + self._dynamic_leading_dim = leading_dim + self._dynamic_divisibility = divisibility + self.tensor_adaptor.mark_layout_dynamic(leading_dim, divisibility) + return self + + +def _jax_strides(array: "jax.Array"): + """Return element strides for a JAX array (like torch.Tensor.stride()). + + JAX arrays are always contiguous (C-order by default), so we compute + strides from the shape. + """ + shape = array.shape + if not shape: + return () + strides = [1] * len(shape) + for i in range(len(shape) - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + return tuple(strides) + + +def from_jax( + array: "jax.Array", + *, + assumed_align: Optional[int] = None, + use_32bit_stride: bool = False, +) -> "JaxTensorAdaptor": + """Wrap a JAX array for use with ``@flyc.jit`` functions. + + Parameters + ---------- + array : jax.Array + A JAX array residing on a GPU device. + assumed_align : int, optional + Override pointer alignment hint (bytes). + use_32bit_stride : bool + Use 32-bit strides in the memref descriptor. + + Returns + ------- + JaxTensorAdaptor + An adapter implementing the FlyDSL ``JitArgument`` protocol. + + Examples + -------- + >>> import jax.numpy as jnp + >>> import flydsl.compiler as flyc + >>> from flydsl.jax import from_jax + >>> a = jnp.ones(1024, dtype=jnp.float32) + >>> ta = from_jax(a) + >>> my_jit_func(ta, ...) 
# pass to @flyc.jit function + """ + _ensure_registered() + return JaxTensorAdaptor(array, assumed_align, use_32bit_stride) diff --git a/python/flydsl/jax/ffi_bridge.py b/python/flydsl/jax/ffi_bridge.py new file mode 100644 index 00000000..760d4bba --- /dev/null +++ b/python/flydsl/jax/ffi_bridge.py @@ -0,0 +1,294 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Bridge between FlyDSL compiled kernels and JAX's XLA custom-call interface. + +This module compiles ``@flyc.jit`` functions via the normal FlyDSL MLIR +pipeline, then registers the resulting native function pointer as an XLA +custom-call target so that ``jax.jit`` can invoke it. + +Architecture +------------ +A compiled C trampoline (``_xla_bridge.so``) translates between XLA's +GPU custom-call convention and FlyDSL's bare-pointer convention: + +XLA GPU custom call (``API_VERSION_STATUS_RETURNING_UNIFIED``):: + + void fn(hipStream_t stream, void** buffers, const char* opaque, size_t opaque_len) + +Note: the api_version numbering differs between the XLA registration API +(``xla_client.register_custom_call_target``, where 0 = untyped custom call) +and StableHLO (``CustomCallOp``, where 2 = STATUS_RETURNING_UNIFIED for GPU). +Both refer to the same calling convention. + +FlyDSL bare-pointer convention:: + + void fn(void** ptrs) // ptrs[i] = &storage[i], storage[i] = device_ptr or stream + +The trampoline uses the ``opaque`` bytes to look up the registered FlyDSL +function and buffer count, then repacks the XLA buffers + stream into the +FlyDSL layout on the C stack. This avoids Python callbacks entirely — +critical because XLA dispatches custom calls from C++ threads. 
+""" + +from __future__ import annotations + +import ctypes +import hashlib +import struct +import subprocess +import threading +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple + +try: + import jax + import jax.numpy as jnp +except ImportError as exc: + raise ImportError( + "JAX is required for flydsl.jax. Install with:\n" + " pip install jax[rocm]" + ) from exc + + +# --------------------------------------------------------------------------- +# Load (or build) the C trampoline shared library +# --------------------------------------------------------------------------- + +_THIS_DIR = Path(__file__).resolve().parent +_BRIDGE_SO = _THIS_DIR / "_xla_bridge.so" +_BRIDGE_C = _THIS_DIR / "_xla_bridge.c" + + +def _ensure_bridge_lib() -> ctypes.CDLL: + """Load ``_xla_bridge.so``, compiling from source if necessary. + + Recompiles if the .c source is newer than the .so to avoid stale binaries. + """ + needs_compile = not _BRIDGE_SO.exists() + if not needs_compile and _BRIDGE_C.exists(): + needs_compile = _BRIDGE_C.stat().st_mtime > _BRIDGE_SO.stat().st_mtime + if needs_compile: + if not _BRIDGE_C.exists(): + raise FileNotFoundError( + f"Cannot find XLA bridge source: {_BRIDGE_C}\n" + f"Please rebuild or reinstall flydsl." 
+ ) + subprocess.check_call( + ["cc", "-shared", "-fPIC", "-O2", "-lpthread", + "-o", str(_BRIDGE_SO), str(_BRIDGE_C)], + cwd=str(_THIS_DIR), + ) + lib = ctypes.CDLL(str(_BRIDGE_SO)) + + # int flydsl_xla_register(void *func_ptr, int n_buffers, int n_scalars, int64_t *scalar_vals) + lib.flydsl_xla_register.restype = ctypes.c_int + lib.flydsl_xla_register.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_int, ctypes.c_void_p] + + # void *flydsl_xla_get_bridge(int idx) + lib.flydsl_xla_get_bridge.restype = ctypes.c_void_p + lib.flydsl_xla_get_bridge.argtypes = [ctypes.c_int] + + return lib + + +_bridge_lib: Optional[ctypes.CDLL] = None + + +def _get_bridge_lib() -> ctypes.CDLL: + global _bridge_lib + if _bridge_lib is None: + _bridge_lib = _ensure_bridge_lib() + return _bridge_lib + + +# --------------------------------------------------------------------------- +# Thread-safe registry of compiled kernels +# --------------------------------------------------------------------------- + +_lock = threading.Lock() +_registered_targets: Dict[str, "_RegisteredTarget"] = {} + + +class _RegisteredTarget: + """Bookkeeping for a registered XLA custom-call target.""" + + __slots__ = ("name", "slot_idx", "n_buffers", "opaque_bytes", + "artifact", "bridge_func_ptr") + + def __init__(self, name: str, artifact, n_buffers: int, slot_idx: int, bridge_ptr: int): + self.name = name + self.artifact = artifact + self.n_buffers = n_buffers + self.slot_idx = slot_idx + self.opaque_bytes = struct.pack("i", slot_idx) # 4-byte opaque for XLA + self.bridge_func_ptr = bridge_ptr + + +# --------------------------------------------------------------------------- +# Compilation + registration +# --------------------------------------------------------------------------- + + +def compile_and_register( + flyc_func: Callable, + *, + input_shapes: List[Tuple[Tuple[int, ...], Any]], + output_shapes: List[Tuple[Tuple[int, ...], Any]], + constexpr_kwargs: Optional[dict] = None, + runtime_scalars: 
Optional[dict] = None, +) -> str: + """Compile a ``@flyc.jit`` function and register it as an XLA custom-call target. + + Parameters + ---------- + flyc_func : callable + A FlyDSL ``@flyc.jit``-decorated function (``JitFunction``). + input_shapes : list of (shape, dtype) + Shape and dtype of each input tensor. + output_shapes : list of (shape, dtype) + Shape and dtype of each output tensor. + constexpr_kwargs : dict, optional + Compile-time constant keyword arguments (``Constexpr`` parameters). + runtime_scalars : dict, optional + Runtime scalar arguments that are not tensors (e.g. ``n: Int32``). + Keys are parameter names, values are representative values used + during compilation tracing. + + Returns + ------- + str + The registered custom-call target name. + """ + if constexpr_kwargs is None: + constexpr_kwargs = {} + if runtime_scalars is None: + runtime_scalars = {} + + # Build a unique name based on function + shapes + constexprs. + func_name = flyc_func.func.__name__ if hasattr(flyc_func, "func") else str(flyc_func) + sig_parts = [func_name] + for shape, dtype in input_shapes: + sig_parts.append(f"i{shape}:{dtype}") + for shape, dtype in output_shapes: + sig_parts.append(f"o{shape}:{dtype}") + for k, v in sorted(constexpr_kwargs.items()): + sig_parts.append(f"c{k}={v}") + for k, v in sorted(runtime_scalars.items()): + sig_parts.append(f"r{k}={v}") + + name_hash = hashlib.sha256("|".join(sig_parts).encode()).hexdigest()[:16] + target_name = f"flydsl_{name_hash}" + + # Hold the lock through the full compile+register path to prevent + # concurrent threads from compiling and registering the same target. + with _lock: + if target_name in _registered_targets: + return target_name + + # Create concrete JAX arrays for each tensor argument. 
+ from .adapter import from_jax + + all_arrays = [] + for shape, dtype in list(input_shapes) + list(output_shapes): + all_arrays.append(jnp.zeros(shape, dtype=dtype)) + + jit_args = [from_jax(a) for a in all_arrays] + + # Build the full argument list for the @flyc.jit function. + call_args = list(jit_args) + for _name, val in sorted(runtime_scalars.items()): + call_args.append(val) + + # Trigger compilation. The cache key now includes dynamic layout + # state, so a prior eager call with mark_layout_dynamic won't + # collide with this plain-array compilation. + flyc_func(*call_args, **constexpr_kwargs) + + artifact = flyc_func.get_last_artifact() + if artifact is None: + raise RuntimeError( + "FlyDSL compilation did not produce a cached artifact. " + "Ensure the function is a @flyc.jit-decorated function." + ) + + # Get the native function pointer from the compiled artifact. + func_exe = artifact._get_func_exe() + fly_func_ptr = ctypes.cast(func_exe, ctypes.c_void_p).value + + # Register with the C trampoline. + # Scalar values are baked into the trampoline slot so they're inserted + # between the tensor buffers and the stream at dispatch time. + n_buffers = len(input_shapes) + len(output_shapes) + scalar_values = [v for _name, v in sorted(runtime_scalars.items())] + n_scalars = len(scalar_values) + + # Pack scalar values as int64 array for the C bridge. + if n_scalars > 0: + ScalarArray = ctypes.c_int64 * n_scalars + scalar_arr = ScalarArray(*scalar_values) + scalar_ptr = ctypes.cast(scalar_arr, ctypes.c_void_p) + else: + scalar_ptr = None + + lib = _get_bridge_lib() + slot_idx = lib.flydsl_xla_register(fly_func_ptr, n_buffers, n_scalars, scalar_ptr) + if slot_idx < 0: + raise RuntimeError("Failed to register FlyDSL kernel in XLA bridge (too many targets?)") + + bridge_ptr = lib.flydsl_xla_get_bridge(slot_idx) + + target = _RegisteredTarget(target_name, artifact, n_buffers, slot_idx, bridge_ptr) + + # Register with JAX's XLA custom-call mechanism. 
+ _register_with_xla(target) + + _registered_targets[target_name] = target + + return target_name + + +def _register_with_xla(target: _RegisteredTarget) -> None: + """Register the C bridge function as an XLA custom-call target.""" + # Ensure the JAX backend is initialized so the custom-call handler is + # registered. Without this, registrations are queued but never flushed. + import jax as _jax + _jax.default_backend() + + # Use JAX's own pycapsule to ensure the capsule name matches what XLA expects. + from jax import ffi as _jax_ffi + capsule = _jax_ffi.pycapsule(ctypes.c_void_p(target.bridge_func_ptr)) + + # Use the internal xla_client directly to bypass JAX's platform name + # mapping which maps "gpu" -> "CUDA" (wrong for ROCm). + from jax._src.lib import xla_client as _xla_client + + # Detect the XLA internal platform name. + if "ROCM" in _xla_client._custom_callback_handler: + xla_platform_name = "ROCM" + elif "CUDA" in _xla_client._custom_callback_handler: + xla_platform_name = "CUDA" + else: + # Fallback: try both + xla_platform_name = "ROCM" + + # Registration api_version=0 selects the untyped custom-call convention: + # void fn(stream, void** buffers, const char* opaque, size_t opaque_len) + # This corresponds to StableHLO api_version=2 (STATUS_RETURNING_UNIFIED) + # used in the CustomCallOp emitted by primitive.py. + _xla_client.register_custom_call_target( + target.name, capsule, xla_platform_name, api_version=0, + ) + + +def get_opaque_for(target_name: str) -> bytes: + """Return the opaque bytes that XLA should pass to the custom call. + + The opaque encodes the slot index so the C trampoline can look up + the correct FlyDSL function. 
+ """ + with _lock: + target = _registered_targets.get(target_name) + if target is None: + raise KeyError(f"Target {target_name!r} not registered") + return target.opaque_bytes diff --git a/python/flydsl/jax/primitive.py b/python/flydsl/jax/primitive.py new file mode 100644 index 00000000..ba70aead --- /dev/null +++ b/python/flydsl/jax/primitive.py @@ -0,0 +1,319 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""JAX primitive wrapping compiled FlyDSL kernels for ``jax.jit`` integration. + +This module registers FlyDSL-compiled GPU kernels as JAX custom-call +primitives so they can participate in JAX's tracing and compilation +pipeline. + +Architecture +------------ +1. At decoration time (``jax_kernel``), a FlyDSL ``@flyc.jit`` function is + wrapped. No compilation happens yet. +2. When the wrapper is called inside a ``jax.jit``-traced function, a JAX + primitive (``flydsl_call_p``) is bound. Its abstract-eval rule + propagates shapes/dtypes to JAX. +3. At XLA lowering time, the FlyDSL kernel is JIT-compiled for the concrete + shapes, and the resulting GPU binary is registered as a custom-call + target. The XLA ``CustomCall`` HLO is emitted. +4. At execution time, XLA invokes the custom call on its own HIP stream — + no explicit stream management is needed. + +Limitations +----------- +- **In-place semantics**: FlyDSL kernels write to pre-allocated output + buffers. The wrapper pre-allocates outputs and passes them as additional + XLA buffers to the custom call. +- **No autograd**: ``jax.grad`` is not supported (would need explicit VJP + rules with a backward kernel). +- **No vmap**: Batching rules are not yet implemented. +- **Shape-specialization**: Each unique set of input shapes triggers a new + compilation, cached by FlyDSL's existing cache. +- **Scalar args baked at compile time**: Non-tensor runtime arguments (e.g. + ``n: Int32``) are traced with their concrete values during compilation. 
+ Changing them requires recompilation (use ``Constexpr`` or pass via + ``runtime_scalars``). +""" + +from __future__ import annotations + +import functools +from typing import Callable, Dict, Optional, Tuple + +try: + import jax + import jax.numpy as jnp + from jax._src import core + from jax.interpreters import mlir as jax_mlir + + # StableHLO/HLO dialect ops and MLIR IR used in the lowering rule. + from jax._src.lib.mlir import ir as jax_ir + from jax._src.lib.mlir.dialects import hlo as stablehlo +except ImportError as exc: + raise ImportError( + "JAX is required for flydsl.jax. Install with:\n" + " pip install jax[rocm]" + ) from exc + +from .adapter import JaxTensorAdaptor, from_jax + +# --------------------------------------------------------------------------- +# JAX Primitive +# --------------------------------------------------------------------------- + +flydsl_call_p = core.Primitive("flydsl_call") +flydsl_call_p.multiple_results = True + + +def _flydsl_abstract_eval( + *args: core.ShapedArray, + out_avals: Tuple[core.ShapedArray, ...], + **_kwargs, +) -> Tuple[core.ShapedArray, ...]: + """Abstract evaluation: propagate output shapes/dtypes.""" + return out_avals + + +flydsl_call_p.def_abstract_eval(_flydsl_abstract_eval) + + +# --------------------------------------------------------------------------- +# Impl rule (eager fallback for un-jitted calls) +# --------------------------------------------------------------------------- + + +def _flydsl_impl( + *args, + flyc_func: Callable, + out_avals: Tuple[core.ShapedArray, ...], + constexpr_kwargs: tuple, + runtime_scalars: tuple, + **_kwargs, +): + """Eager implementation: compile and run via FlyDSL's normal JIT path. + + Note: the kernel executes asynchronously on the default HIP stream. + Callers should use ``jax.block_until_ready()`` on the returned arrays. + """ + # Unpack frozen tuples back to dicts. 
    # constexpr_kwargs / runtime_scalars arrive as frozen (key, value)
    # tuples (primitive params must be hashable); rebuild dicts locally.
    constexpr_dict = dict(constexpr_kwargs)
    scalars_dict = dict(runtime_scalars)

    # Allocate output arrays.
    # NOTE(review): the FlyDSL kernel writes into these jnp.zeros buffers
    # in place, which bypasses JAX's immutability contract — relies on
    # nothing else aliasing them; confirm this is safe for the backend.
    outputs = []
    for aval in out_avals:
        outputs.append(jnp.zeros(aval.shape, dtype=aval.dtype))

    # Convert all JAX arrays to JaxTensorAdaptors.
    jit_args = []
    for a in list(args) + outputs:
        jit_args.append(from_jax(a))

    # Append runtime scalar values, sorted by name so the ordering matches
    # the sorted() ordering used in ffi_bridge.compile_and_register.
    call_args = list(jit_args)
    for _name, val in sorted(scalars_dict.items()):
        call_args.append(val)

    # Call via FlyDSL JIT (uses default stream).
    flyc_func(*call_args, **constexpr_dict)

    return tuple(outputs)


flydsl_call_p.def_impl(_flydsl_impl)


# ---------------------------------------------------------------------------
# XLA lowering rule
# ---------------------------------------------------------------------------

# Cache: hashable key -> registered target name.
# Process-global and never evicted: one entry per unique
# (function, input shapes, output shapes, constexprs, scalars) combination.
_lowering_cache: Dict[tuple, str] = {}


def _shapes_key(avals):
    """Create a hashable key from a sequence of abstract values."""
    return tuple((a.shape, a.dtype) for a in avals)


def _flydsl_lowering(
    ctx: jax_mlir.LoweringRuleContext,
    *args,
    flyc_func: Callable,
    out_avals: Tuple[core.ShapedArray, ...],
    constexpr_kwargs: tuple,
    runtime_scalars: tuple,
):
    """MLIR lowering rule: emit a ``stablehlo.custom_call`` backed by the
    compiled FlyDSL kernel.

    This function is called during ``jax.jit`` lowering. It:
    1. Computes input/output shape signatures from ``ctx.avals_in``/``ctx.avals_out``.
    2. Triggers FlyDSL compilation (via ``ffi_bridge.compile_and_register``) if
       the kernel hasn't been compiled for these shapes yet.
    3. Emits a ``stablehlo.CustomCallOp`` that XLA will dispatch to the
       registered bridge function at runtime.

    Non-tensor arguments (``runtime_scalars``) are baked into the compiled
    kernel during tracing — the XLA custom call only receives tensor buffers.
+ """ + from .ffi_bridge import compile_and_register, get_opaque_for + + avals_in = ctx.avals_in + avals_out = ctx.avals_out + + # constexpr_kwargs and runtime_scalars arrive as frozen tuples of + # (key, value) pairs (required for JAX hashability). + # Use function name (stable across calls) rather than id() which can + # be reused after GC for a different object. + func_id = flyc_func.func.__name__ if hasattr(flyc_func, "func") else str(flyc_func) + cache_key = ( + func_id, + _shapes_key(avals_in), + _shapes_key(avals_out), + constexpr_kwargs, + runtime_scalars, + ) + + target_name = _lowering_cache.get(cache_key) + if target_name is None: + # Compile the FlyDSL function and register it as an XLA custom-call target. + input_shapes = [(tuple(a.shape), a.dtype) for a in avals_in] + output_shapes = [(tuple(a.shape), a.dtype) for a in avals_out] + + target_name = compile_and_register( + flyc_func, + input_shapes=input_shapes, + output_shapes=output_shapes, + constexpr_kwargs=dict(constexpr_kwargs), + runtime_scalars=dict(runtime_scalars), + ) + _lowering_cache[cache_key] = target_name + + # The opaque bytes encode the slot index so the C trampoline can look up + # the correct FlyDSL compiled function. + opaque = get_opaque_for(target_name) + + # Build MLIR result types for each output. + result_types = [jax_mlir.aval_to_ir_type(aval) for aval in avals_out] + + # Emit the custom call. + # api_version in StableHLO CustomCallOp: + # 0 = API_VERSION_ORIGINAL (CPU: fn(out, ins)) + # 1 = API_VERSION_STATUS_RETURNING (CPU: fn(out, ins, status)) + # 2 = API_VERSION_STATUS_RETURNING_UNIFIED (GPU: fn(stream, buffers, opaque, opaque_len)) + # 4 = API_VERSION_TYPED_FFI + # We use 2 for GPU custom calls with the old untyped convention. + # backend_config carries the opaque bytes (slot index for the C trampoline). 
+ i32_type = jax_ir.IntegerType.get_signless(32) + call = stablehlo.CustomCallOp( + result_types, + list(args), + call_target_name=target_name, + api_version=jax_ir.IntegerAttr.get(i32_type, 2), + backend_config=jax_ir.StringAttr.get(opaque.decode("latin-1")), + has_side_effect=jax_ir.BoolAttr.get(True), + ) + + return call.results + + +# Register the lowering for the ROCm platform. +# JAX uses "rocm" as the platform name for AMD GPUs. +jax_mlir.register_lowering(flydsl_call_p, _flydsl_lowering, platform="rocm") + +# Also register for "gpu" platform (some JAX versions use this generically). +try: + jax_mlir.register_lowering(flydsl_call_p, _flydsl_lowering, platform="gpu") +except Exception: + pass + + +# --------------------------------------------------------------------------- +# jax_kernel wrapper +# --------------------------------------------------------------------------- + + +def jax_kernel( + flyc_func: Callable, + *, + out_shapes: Callable, + constexpr_kwargs: Optional[dict] = None, + runtime_scalars: Optional[dict] = None, +) -> Callable: + """Wrap a ``@flyc.jit`` function for use inside ``jax.jit``. + + Parameters + ---------- + flyc_func : callable + A FlyDSL ``@flyc.jit``-decorated function. + out_shapes : callable + A function ``(*wrapper_args) -> list[(shape, dtype)]`` that returns + the shape and dtype of each output tensor the kernel will produce. + FlyDSL kernels write to pre-allocated output buffers, so the caller + must specify the output layout. + constexpr_kwargs : dict, optional + Compile-time constant keyword arguments forwarded to the FlyDSL + function (``Constexpr`` parameters). + runtime_scalars : dict, optional + Non-tensor runtime arguments (e.g. ``{"n": 128}``). These are + passed to the FlyDSL function during compilation tracing but are + NOT passed through the XLA custom call at runtime — they are baked + into the compiled kernel. Changing them requires recompilation. 
+ + Returns + ------- + callable + A function that accepts only JAX array arguments and returns a + tuple of JAX output arrays. Compatible with ``jax.jit``. + + Examples + -------- + :: + + @flyc.jit + def my_add(A, B, C, n, const_n, stream): + ... + + wrapped = jax_kernel( + my_add, + out_shapes=lambda a, b: [(a.shape, a.dtype)], + constexpr_kwargs={"const_n": 129}, + runtime_scalars={"n": 128}, + ) + + @jax.jit + def f(a, b): + (c,) = wrapped(a, b) + return c + """ + if constexpr_kwargs is None: + constexpr_kwargs = {} + if runtime_scalars is None: + runtime_scalars = {} + + # JAX requires all primitive parameters to be hashable. + # Convert dicts to frozen tuples of sorted items for the bind() call. + frozen_constexpr = tuple(sorted(constexpr_kwargs.items())) + frozen_scalars = tuple(sorted(runtime_scalars.items())) + + @functools.wraps(flyc_func) + def wrapper(*args): + # Compute output abstract values from the user-provided function. + out_specs = out_shapes(*args) + out_avals = tuple( + core.ShapedArray(shape, dtype) for shape, dtype in out_specs + ) + + # Only JAX arrays are passed as primitive operands — scalars are + # carried as primitive parameters and baked into the compiled kernel. 
+ return flydsl_call_p.bind( + *args, + flyc_func=flyc_func, + out_avals=out_avals, + constexpr_kwargs=frozen_constexpr, + runtime_scalars=frozen_scalars, + ) + + return wrapper diff --git a/scripts/run_tests.sh b/scripts/run_tests.sh index ea70265c..c827412d 100644 --- a/scripts/run_tests.sh +++ b/scripts/run_tests.sh @@ -53,6 +53,10 @@ for example in "${REPO_ROOT}"/examples/*.py; do output=$(python3 "${example}" 2>&1) || { echo " FAIL ${name}"; echo "$output" | tail -10 | sed 's/^/ /'; exit 1 } + if echo "$output" | grep -q "^SKIP:"; then + echo " SKIP ${name} ($(echo "$output" | grep "^SKIP:" | head -1))" + continue + fi if echo "$output" | grep -qE "Result correct: False|All passed: False"; then echo " FAIL ${name}"; echo "$output" | tail -10 | sed 's/^/ /'; exit 1 fi diff --git a/tests/test_jax_vecadd.py b/tests/test_jax_vecadd.py new file mode 100644 index 00000000..111501cf --- /dev/null +++ b/tests/test_jax_vecadd.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Minimal vector-add test using JAX arrays with FlyDSL. + +Demonstrates the eager (Level 1) integration: JAX arrays are wrapped +via ``from_jax`` and passed directly to a ``@flyc.jit`` function. 
+""" + +import numpy as np +import pytest + +try: + import jax + import jax.numpy as jnp +except ImportError: + pytest.skip("JAX not installed", allow_module_level=True) + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.jax import from_jax + + +# ── Kernel (same as tests/kernels/test_vec_add.py) ───────────────────── + +@flyc.kernel +def vecAddKernel( + A: fx.Tensor, + B: fx.Tensor, + C: fx.Tensor, + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], +): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + tile_elems = block_dim * vec_width + + tA = fx.logical_divide(A, fx.make_layout(tile_elems, 1)) + tB = fx.logical_divide(B, fx.make_layout(tile_elems, 1)) + tC = fx.logical_divide(C, fx.make_layout(tile_elems, 1)) + + tA = fx.slice(tA, (None, bid)) + tB = fx.slice(tB, (None, bid)) + tC = fx.slice(tC, (None, bid)) + + tA = fx.logical_divide(tA, fx.make_layout(vec_width, 1)) + tB = fx.logical_divide(tB, fx.make_layout(vec_width, 1)) + tC = fx.logical_divide(tC, fx.make_layout(vec_width, 1)) + + copy_bits = vec_width * 32 + RABMemRefTy = fx.MemRefType.get( + fx.T.f32(), fx.LayoutType.get(vec_width, 1), fx.AddressSpace.Register + ) + copyAtom = fx.make_copy_atom(fx.UniversalCopy(copy_bits), fx.Float32) + + rA = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1)) + rB = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1)) + rC = fx.memref_alloca(RABMemRefTy, fx.make_layout(vec_width, 1)) + + fx.copy_atom_call(copyAtom, fx.slice(tA, (None, tid)), rA) + fx.copy_atom_call(copyAtom, fx.slice(tB, (None, tid)), rB) + + vC = fx.arith.addf(fx.memref_load_vec(rA), fx.memref_load_vec(rB)) + fx.memref_store_vec(vC, rC) + + fx.copy_atom_call(copyAtom, rC, fx.slice(tC, (None, tid))) + + +# ── JIT launcher ──────────────────────────────────────────────────────── + +@flyc.jit +def vecAdd( + A: fx.Tensor, + B: fx.Tensor, + C, + n: fx.Int32, + const_n: fx.Constexpr[int], + block_dim: fx.Constexpr[int], + vec_width: fx.Constexpr[int], + 
stream: fx.Stream = fx.Stream(None), +): + tile_elems = block_dim * vec_width + grid_x = (n + tile_elems - 1) // tile_elems + vecAddKernel(A, B, C, block_dim, vec_width).launch( + grid=(grid_x, 1, 1), block=(block_dim, 1, 1), stream=stream + ) + + +# ── Test ──────────────────────────────────────────────────────────────── + +def main(): + BLOCK_DIM = 256 + VEC_WIDTH = 4 + TILE = BLOCK_DIM * VEC_WIDTH + N = TILE * 100 # 102400 elements, aligned to tile + + print(f"JAX devices: {jax.devices()}") + print(f"Vector add: N={N}, block={BLOCK_DIM}, vec_width={VEC_WIDTH}") + + # Create JAX arrays on GPU. + key = jax.random.PRNGKey(42) + A = jax.random.normal(key, (N,), dtype=jnp.float32) + B = jax.random.normal(jax.random.PRNGKey(7), (N,), dtype=jnp.float32) + C = jnp.zeros(N, dtype=jnp.float32) + + # Wrap for FlyDSL. + tA = from_jax(A).mark_layout_dynamic(leading_dim=0, divisibility=VEC_WIDTH) + tB = from_jax(B) + tC = from_jax(C) + + # Ensure JAX computations are done before kernel launch. + jax.block_until_ready(A) + jax.block_until_ready(B) + + # Launch FlyDSL kernel. + print("Compiling and launching kernel...") + vecAdd(tA, tB, tC, N, N, BLOCK_DIM, VEC_WIDTH) + + # Synchronize (FlyDSL kernel runs on default HIP stream). + # Use a HIP device sync via JAX. + jax.block_until_ready(C) + + # Verify. 
+ expected = np.asarray(A) + np.asarray(B) + result = np.asarray(C) + max_err = np.max(np.abs(result - expected)) + print(f"Max error: {max_err:.2e}") + + if max_err < 1e-5: + print("PASSED") + else: + print("FAILED") + print(f" A[:8] = {np.asarray(A)[:8]}") + print(f" B[:8] = {np.asarray(B)[:8]}") + print(f" C[:8] = {result[:8]}") + print(f" expected[:8] = {expected[:8]}") + return False + return True + + +if __name__ == "__main__": + ok = main() + exit(0 if ok else 1) diff --git a/tests/unit/test_jax_integration.py b/tests/unit/test_jax_integration.py new file mode 100644 index 00000000..07915db5 --- /dev/null +++ b/tests/unit/test_jax_integration.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Unit tests for the FlyDSL JAX integration layer. + +Tests the adapter (from_jax), primitive (jax_kernel), and FFI bridge +independently, without requiring a full FlyDSL kernel compilation +(except where noted). 
+""" + +import ctypes +import struct + +import numpy as np +import pytest + +try: + import jax + import jax.numpy as jnp +except ImportError: + pytest.skip("JAX not installed", allow_module_level=True) + +from flydsl._mlir import ir + + +# ====================================================================== +# Adapter tests (from_jax / JaxTensorAdaptor) +# ====================================================================== + + +class TestJaxTensorAdaptor: + """Tests for flydsl.jax.adapter.""" + + def test_import_without_torch(self): + """Importing flydsl.jax should not require torch.""" + from flydsl.jax import from_jax, JaxTensorAdaptor + + assert from_jax is not None + assert JaxTensorAdaptor is not None + + def test_basic_f32(self): + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + ta = from_jax(a) + assert ta._orig_dtype == jnp.float32 + assert ta._orig_shape == (128,) + assert ta._orig_strides == (1,) + + def test_2d_array(self): + from flydsl.jax import from_jax + + a = jnp.ones((32, 64), dtype=jnp.float32) + ta = from_jax(a) + assert ta._orig_shape == (32, 64) + assert ta._orig_strides == (64, 1) + + @pytest.mark.parametrize( + "dtype", + [jnp.float32, jnp.float16, jnp.bfloat16, jnp.int32, jnp.int8, jnp.uint8], + ) + def test_multiple_dtypes(self, dtype): + from flydsl.jax import from_jax + + a = jnp.ones(64, dtype=dtype) + ta = from_jax(a) + assert ta._orig_dtype == dtype + + def test_fly_types_returns_memref(self): + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + ta = from_jax(a) + with ir.Context() as ctx: + ctx.load_all_available_dialects() + types = ta.__fly_types__() + assert len(types) == 1 + assert "f32" in str(types[0]) + assert "128" in str(types[0]) + + def test_fly_ptrs_returns_pointers(self): + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + ta = from_jax(a) + with ir.Context() as ctx: + ctx.load_all_available_dialects() + ta.__fly_types__() # build memref 
desc + ptrs = ta.__fly_ptrs__() + assert len(ptrs) >= 1 + assert isinstance(ptrs[0], int) or isinstance(ptrs[0], ctypes.c_void_p) + + def test_cache_signature_stable(self): + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + ta1 = from_jax(a) + ta2 = from_jax(a) + assert ta1.__cache_signature__() == ta2.__cache_signature__() + + def test_cache_signature_same_dtype_same_sig(self): + """Shape is not part of cache signature (handled by JitFunction).""" + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + b = jnp.ones(256, dtype=jnp.float32) + assert from_jax(a).__cache_signature__() == from_jax(b).__cache_signature__() + + def test_cache_signature_differs_by_dtype(self): + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + b = jnp.ones(128, dtype=jnp.float16) + assert from_jax(a).__cache_signature__() != from_jax(b).__cache_signature__() + + def test_mark_layout_dynamic(self): + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + ta = from_jax(a).mark_layout_dynamic(leading_dim=0, divisibility=4) + # Should return self for chaining + assert ta is not None + assert ta._orig_shape == (128,) + + def test_assumed_align(self): + from flydsl.jax import from_jax + + a = jnp.ones(128, dtype=jnp.float32) + ta = from_jax(a, assumed_align=16) + assert ta.assumed_align == 16 + sig = ta.__cache_signature__() + assert 16 in sig + + def test_rejects_non_jax_array(self): + from flydsl.jax.adapter import JaxTensorAdaptor + + with pytest.raises(TypeError, match="Expected jax.Array"): + JaxTensorAdaptor([1, 2, 3]) + + def test_float8_dtype_if_available(self): + """Float8 arrays should be handled via uint8 view.""" + from flydsl.jax import from_jax + + f8 = getattr(jnp, "float8_e4m3fn", None) + if f8 is None: + pytest.skip("float8_e4m3fn not available in this JAX version") + a = jnp.ones(64, dtype=f8) + ta = from_jax(a) + assert ta._orig_dtype == f8 + + +# 
# ======================================================================
# Primitive tests (jax_kernel / flydsl_call_p)
# ======================================================================


class TestJaxKernelPrimitive:
    """Tests for flydsl.jax.primitive."""

    def test_eager_single_output(self):
        """Eager dispatch produces one output and passes 3 adaptors."""
        from flydsl.jax import jax_kernel

        recorded = []

        def spy(*args, **kwargs):
            recorded.append({"n_args": len(args), "kwargs": dict(kwargs)})

        kernel = jax_kernel(
            spy,
            out_shapes=lambda a, b: [(a.shape, a.dtype)],
        )

        lhs = jnp.ones(64, dtype=jnp.float32)
        rhs = jnp.ones(64, dtype=jnp.float32)
        outs = kernel(lhs, rhs)

        assert len(outs) == 1
        assert outs[0].shape == (64,)
        assert outs[0].dtype == jnp.float32
        assert len(recorded) == 1
        # The wrapped kernel sees 2 input adaptors plus 1 output adaptor.
        assert recorded[0]["n_args"] == 3

    def test_eager_multiple_outputs(self):
        """out_shapes may request several outputs with distinct dtypes."""
        from flydsl.jax import jax_kernel

        def spy(*args, **kwargs):
            pass

        kernel = jax_kernel(
            spy,
            out_shapes=lambda a: [
                (a.shape, jnp.float32),
                (a.shape, jnp.int32),
            ],
        )

        outs = kernel(jnp.ones(32, dtype=jnp.float32))

        assert len(outs) == 2
        assert outs[0].dtype == jnp.float32
        assert outs[1].dtype == jnp.int32

    def test_constexpr_kwargs_forwarded(self):
        """constexpr_kwargs reach the wrapped function as keyword args."""
        from flydsl.jax import jax_kernel

        received = {}

        def spy(*args, **kwargs):
            received.update(kwargs)

        kernel = jax_kernel(
            spy,
            out_shapes=lambda a: [(a.shape, a.dtype)],
            constexpr_kwargs={"block_dim": 256, "vec_width": 4},
        )

        kernel(jnp.ones(64, dtype=jnp.float32))
        assert received == {"block_dim": 256, "vec_width": 4}

    def test_runtime_scalars_forwarded(self):
        """runtime_scalars are appended after the buffer arguments."""
        from flydsl.jax import jax_kernel

        recorded = []

        def spy(*args, **kwargs):
            recorded.append({"n_args": len(args), "args": list(args)})

        kernel = jax_kernel(
            spy,
            out_shapes=lambda a: [(a.shape, a.dtype)],
            runtime_scalars={"n": 128},
        )

        kernel(jnp.ones(64, dtype=jnp.float32))

        # 1 input + 1 output + 1 scalar = 3 positional args,
        # with the scalar value (128) in the last position.
        assert recorded[0]["n_args"] == 3
        assert recorded[0]["args"][-1] == 128

    def test_abstract_eval(self):
        """Primitive abstract eval should propagate shapes/dtypes."""
        from jax._src import core
        from flydsl.jax.primitive import flydsl_call_p

        in_aval = core.ShapedArray((128,), jnp.float32)
        expected_out = core.ShapedArray((64,), jnp.int32)

        outcome = flydsl_call_p.abstract_eval(
            in_aval,
            flyc_func=None,
            out_avals=(expected_out,),
            constexpr_kwargs=(),
            runtime_scalars=(),
        )

        # abstract_eval returns (out_avals, effects); only avals matter here.
        produced = outcome[0]
        assert len(produced) == 1
        assert produced[0].shape == (64,)
        assert produced[0].dtype == jnp.int32


# ======================================================================
# C trampoline tests
# ======================================================================


class TestXlaBridge:
    """Tests for the C trampoline (_xla_bridge.so)."""

    @pytest.fixture
    def bridge_lib(self):
        from flydsl.jax.ffi_bridge import _get_bridge_lib

        return _get_bridge_lib()

    def test_register_and_dispatch(self, bridge_lib):
        """Register a function and verify the bridge dispatches correctly."""
        observed = []

        FUNC_T = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

        def probe(ptrs_raw):
            slots = ctypes.cast(ptrs_raw, ctypes.POINTER(ctypes.c_void_p))
            # Slot 0 holds the buffer cell, slot 1 the stream cell.
            buf_cell = ctypes.cast(slots[0], ctypes.POINTER(ctypes.c_void_p))
            observed.append(("buf0", buf_cell[0]))
            stream_cell = ctypes.cast(slots[1], ctypes.POINTER(ctypes.c_void_p))
            observed.append(("stream", stream_cell[0]))

        # Keep the CFUNCTYPE object referenced so it is not collected
        # before the bridge invokes it.
        callback = FUNC_T(probe)
        callback_addr = ctypes.cast(callback, ctypes.c_void_p).value

        slot = bridge_lib.flydsl_xla_register(callback_addr, 1, 0, None)
        assert slot >= 0

        bridge_ptr = bridge_lib.flydsl_xla_get_bridge(slot)
        assert bridge_ptr != 0

        # Drive the bridge exactly the way XLA would.
        BRIDGE_T = ctypes.CFUNCTYPE(
            None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_size_t
        )
        bridge_fn = BRIDGE_T(bridge_ptr)

        fake_buf = ctypes.c_void_p(0xBEEF)
        buffers = (ctypes.c_void_p * 1)(fake_buf)
        opaque = struct.pack("i", slot)

        bridge_fn(0xDEAD, ctypes.cast(buffers, ctypes.c_void_p), opaque, len(opaque))

        assert observed[0] == ("buf0", 0xBEEF)
        assert observed[1] == ("stream", 0xDEAD)

    def test_scalar_insertion(self, bridge_lib):
        """Scalars should be inserted between buffers and stream."""
        observed = []

        FUNC_T = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

        def probe(ptrs_raw):
            slots = ctypes.cast(ptrs_raw, ctypes.POINTER(ctypes.c_void_p))
            # Expected layout: buf0, buf1, scalar, stream (2 + 1 + 1 slots).
            observed.extend(
                ctypes.cast(slots[i], ctypes.POINTER(ctypes.c_void_p))[0]
                for i in range(4)
            )

        callback = FUNC_T(probe)
        callback_addr = ctypes.cast(callback, ctypes.c_void_p).value

        scalar_val = (ctypes.c_int64 * 1)(42)
        slot = bridge_lib.flydsl_xla_register(
            callback_addr, 2, 1, ctypes.cast(scalar_val, ctypes.c_void_p)
        )

        BRIDGE_T = ctypes.CFUNCTYPE(
            None, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p, ctypes.c_size_t
        )
        bridge_fn = BRIDGE_T(bridge_lib.flydsl_xla_get_bridge(slot))

        bufs = (ctypes.c_void_p * 2)(ctypes.c_void_p(0xAA), ctypes.c_void_p(0xBB))
        opaque = struct.pack("i", slot)

        bridge_fn(0xCC, ctypes.cast(bufs, ctypes.c_void_p), opaque, len(opaque))

        # buf0, buf1, inserted scalar, then the stream pointer.
        assert observed == [0xAA, 0xBB, 42, 0xCC]

    def test_registration_returns_incremental_slots(self, bridge_lib):
        """Each registration takes the next slot index."""
        FUNC_T = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
        noop = FUNC_T(lambda p: None)
        addr = ctypes.cast(noop, ctypes.c_void_p).value

        first = bridge_lib.flydsl_xla_register(addr, 1, 0, None)
        second = bridge_lib.flydsl_xla_register(addr, 1, 0, None)
        assert second == first + 1

    
test_rejects_too_many_buffers(self, bridge_lib): + FUNC_T = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + noop = FUNC_T(lambda p: None) + ptr = ctypes.cast(noop, ctypes.c_void_p).value + + slot = bridge_lib.flydsl_xla_register(ptr, 999, 0, None) + assert slot == -1 + + def test_rejects_too_many_scalars(self, bridge_lib): + FUNC_T = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + noop = FUNC_T(lambda p: None) + ptr = ctypes.cast(noop, ctypes.c_void_p).value + + slot = bridge_lib.flydsl_xla_register(ptr, 1, 999, None) + assert slot == -1 + + +# ====================================================================== +# Registration deduplication +# ====================================================================== + + +class TestRegistrationDedup: + """Test that compile_and_register deduplicates by shape signature.""" + + def test_same_shapes_produce_same_target_name(self): + """Same function name + shapes + kwargs should hash to the same target.""" + import hashlib + + def _build_target_name(func_name, input_shapes, output_shapes, constexpr, scalars): + sig_parts = [func_name] + for shape, dtype in input_shapes: + sig_parts.append(f"i{shape}:{dtype}") + for shape, dtype in output_shapes: + sig_parts.append(f"o{shape}:{dtype}") + for k, v in sorted(constexpr.items()): + sig_parts.append(f"c{k}={v}") + for k, v in sorted(scalars.items()): + sig_parts.append(f"r{k}={v}") + name_hash = hashlib.sha256("|".join(sig_parts).encode()).hexdigest()[:16] + return f"flydsl_{name_hash}" + + name1 = _build_target_name( + "vecAdd", + [((128,), jnp.float32), ((128,), jnp.float32)], + [((128,), jnp.float32)], + {"const_n": 129}, + {"n": 128}, + ) + name2 = _build_target_name( + "vecAdd", + [((128,), jnp.float32), ((128,), jnp.float32)], + [((128,), jnp.float32)], + {"const_n": 129}, + {"n": 128}, + ) + assert name1 == name2 + + def test_different_shapes_produce_different_target_name(self): + import hashlib + + def _build_target_name(func_name, input_shapes, output_shapes, constexpr, 
scalars): + sig_parts = [func_name] + for shape, dtype in input_shapes: + sig_parts.append(f"i{shape}:{dtype}") + for shape, dtype in output_shapes: + sig_parts.append(f"o{shape}:{dtype}") + for k, v in sorted(constexpr.items()): + sig_parts.append(f"c{k}={v}") + for k, v in sorted(scalars.items()): + sig_parts.append(f"r{k}={v}") + name_hash = hashlib.sha256("|".join(sig_parts).encode()).hexdigest()[:16] + return f"flydsl_{name_hash}" + + name1 = _build_target_name( + "vecAdd", + [((128,), jnp.float32)], + [((128,), jnp.float32)], + {}, {}, + ) + name2 = _build_target_name( + "vecAdd", + [((256,), jnp.float32)], + [((256,), jnp.float32)], + {}, {}, + ) + assert name1 != name2