From 2b2e45b41b5420774d25bab39b8bb37a19577b52 Mon Sep 17 00:00:00 2001
From: Adam Foster
Date: Tue, 27 May 2025 11:05:29 +0000
Subject: [PATCH 1/3] Use TritonCompilerParams

---
 .../pallas/attention/custom_gradients.py    | 21 ++++++++++---------
 .../pallas/attention/forward_laplacian.py   |  9 ++++----
 folx/experimental/pallas/attention/mhsa.py  |  5 +++--
 folx/experimental/pallas/attention/mhsea.py |  5 +++--
 4 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/folx/experimental/pallas/attention/custom_gradients.py b/folx/experimental/pallas/attention/custom_gradients.py
index 6615723..969406a 100644
--- a/folx/experimental/pallas/attention/custom_gradients.py
+++ b/folx/experimental/pallas/attention/custom_gradients.py
@@ -5,6 +5,7 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 
 from .mhsa import mhsa_kernel, reference_mhsa_kernel
 from .mhsea import mhsea_kernel, reference_mhsea_kernel
@@ -53,8 +54,8 @@ def mhsa_forward(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
@@ -113,8 +114,8 @@ def mhsa_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
@@ -268,8 +269,8 @@ def mhsea_forward(
                 shape=(batch_len, seq_len, num_heads), dtype=v.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
@@ -372,8 +373,8 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype
            ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
@@ -433,8 +434,8 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype
            ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
diff --git a/folx/experimental/pallas/attention/forward_laplacian.py b/folx/experimental/pallas/attention/forward_laplacian.py
index 236f17f..fa6407d 100644
--- a/folx/experimental/pallas/attention/forward_laplacian.py
+++ b/folx/experimental/pallas/attention/forward_laplacian.py
@@ -5,6 +5,7 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 
 from folx import forward_laplacian
 from folx.api import FwdJacobian, FwdLaplArray
@@ -153,8 +154,8 @@ def mhsa_forward_laplacian(
                 dtype=q.dtype,  # o.laplacian
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
@@ -588,8 +589,8 @@ def mhsea_forward_laplacian(
                 dtype=v.dtype,  # o.laplacian
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
diff --git a/folx/experimental/pallas/attention/mhsa.py b/folx/experimental/pallas/attention/mhsa.py
index 32dc4b9..0d549cb 100644
--- a/folx/experimental/pallas/attention/mhsa.py
+++ b/folx/experimental/pallas/attention/mhsa.py
@@ -5,6 +5,7 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 
 from .utils import (
     big_number,
@@ -58,8 +59,8 @@ def mhsa(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,
diff --git a/folx/experimental/pallas/attention/mhsea.py b/folx/experimental/pallas/attention/mhsea.py
index 543f75e..11e1c06 100644
--- a/folx/experimental/pallas/attention/mhsea.py
+++ b/folx/experimental/pallas/attention/mhsea.py
@@ -5,6 +5,7 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
+from jax.experimental.pallas import gpu as plgpu
 
 from .utils import (
     big_number,
@@ -58,8 +59,8 @@ def mhsea(
                 shape=(batch_len, seq_len, num_heads), dtype=q.dtype
             ),  # lse
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
+        compiler_params=plgpu.TritonCompilerParams(
+            num_warps=num_warps, num_stages=num_stages
         ),
         debug=False,
         interpret=interpret,

From a4a1401065f9e0ed38e5a230c2212791f156fc92 Mon Sep 17 00:00:00 2001
From: Adam Foster
Date: Tue, 27 May 2025 11:12:49 +0000
Subject: [PATCH 2/3] Import statement has to differ across versions

---
 folx/experimental/pallas/attention/custom_gradients.py  | 6 +++++-
 folx/experimental/pallas/attention/forward_laplacian.py | 6 +++++-
 folx/experimental/pallas/attention/mhsa.py              | 6 +++++-
 folx/experimental/pallas/attention/mhsea.py             | 6 +++++-
 4 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/folx/experimental/pallas/attention/custom_gradients.py b/folx/experimental/pallas/attention/custom_gradients.py
index 969406a..0b5d8bf 100644
--- a/folx/experimental/pallas/attention/custom_gradients.py
+++ b/folx/experimental/pallas/attention/custom_gradients.py
@@ -5,7 +5,11 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
-from jax.experimental.pallas import gpu as plgpu
+
+try:
+    from jax.experimental.pallas import triton as plgpu
+except ImportError:
+    from jax.experimental.pallas import gpu as plgpu
 
 from .mhsa import mhsa_kernel, reference_mhsa_kernel
 from .mhsea import mhsea_kernel, reference_mhsea_kernel
diff --git a/folx/experimental/pallas/attention/forward_laplacian.py b/folx/experimental/pallas/attention/forward_laplacian.py
index fa6407d..551eb79 100644
--- a/folx/experimental/pallas/attention/forward_laplacian.py
+++ b/folx/experimental/pallas/attention/forward_laplacian.py
@@ -5,7 +5,11 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
-from jax.experimental.pallas import gpu as plgpu
+
+try:
+    from jax.experimental.pallas import triton as plgpu
+except ImportError:
+    from jax.experimental.pallas import gpu as plgpu
 
 from folx import forward_laplacian
 from folx.api import FwdJacobian, FwdLaplArray
diff --git a/folx/experimental/pallas/attention/mhsa.py b/folx/experimental/pallas/attention/mhsa.py
index 0d549cb..3276adb 100644
--- a/folx/experimental/pallas/attention/mhsa.py
+++ b/folx/experimental/pallas/attention/mhsa.py
@@ -5,7 +5,11 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
-from jax.experimental.pallas import gpu as plgpu
+
+try:
+    from jax.experimental.pallas import triton as plgpu
+except ImportError:
+    from jax.experimental.pallas import gpu as plgpu
 
 from .utils import (
     big_number,
diff --git a/folx/experimental/pallas/attention/mhsea.py b/folx/experimental/pallas/attention/mhsea.py
index 11e1c06..bd19586 100644
--- a/folx/experimental/pallas/attention/mhsea.py
+++ b/folx/experimental/pallas/attention/mhsea.py
@@ -5,7 +5,11 @@
 import jax
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
-from jax.experimental.pallas import gpu as plgpu
+
+try:
+    from jax.experimental.pallas import triton as plgpu
+except ImportError:
+    from jax.experimental.pallas import gpu as plgpu
 
 from .utils import (
     big_number,

From 2634e24ad7cb19af67df5dcef9d7829a631dea04 Mon Sep 17 00:00:00 2001
From: Adam Foster
Date: Wed, 28 May 2025 12:25:03 +0000
Subject: [PATCH 3/3] Fix for lower versions

---
 .../pallas/attention/custom_gradients.py    | 26 +++++--------------
 .../pallas/attention/forward_laplacian.py   | 14 +++-------
 folx/experimental/pallas/attention/mhsa.py  | 10 ++-----
 folx/experimental/pallas/attention/mhsea.py | 10 ++-----
 folx/experimental/pallas/attention/utils.py | 15 +++++++++++
 5 files changed, 28 insertions(+), 47 deletions(-)

diff --git a/folx/experimental/pallas/attention/custom_gradients.py b/folx/experimental/pallas/attention/custom_gradients.py
index 0b5d8bf..e20609c 100644
--- a/folx/experimental/pallas/attention/custom_gradients.py
+++ b/folx/experimental/pallas/attention/custom_gradients.py
@@ -6,15 +6,11 @@
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
-try:
-    from jax.experimental.pallas import triton as plgpu
-except ImportError:
-    from jax.experimental.pallas import gpu as plgpu
-
 from .mhsa import mhsa_kernel, reference_mhsa_kernel
 from .mhsea import mhsea_kernel, reference_mhsea_kernel
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_lse_block_spec,
@@ -58,9 +54,7 @@ def mhsa_forward(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_forward',
@@ -118,9 +112,7 @@ def mhsa_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
             ),
         ],
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_backward',
@@ -273,9 +265,7 @@ def mhsea_forward(
                 shape=(batch_len, seq_len, num_heads), dtype=v.dtype
             ),
         ],
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhea_forward',
@@ -377,9 +367,7 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype
            ),
         ],
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_backward_q_vjp',
@@ -438,9 +426,7 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype
            ),
         ],
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_backward_kv_vjp',
diff --git a/folx/experimental/pallas/attention/forward_laplacian.py b/folx/experimental/pallas/attention/forward_laplacian.py
index 551eb79..f19d38a 100644
--- a/folx/experimental/pallas/attention/forward_laplacian.py
+++ b/folx/experimental/pallas/attention/forward_laplacian.py
@@ -6,11 +6,6 @@
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
-try:
-    from jax.experimental.pallas import triton as plgpu
-except ImportError:
-    from jax.experimental.pallas import gpu as plgpu
-
 from folx import forward_laplacian
 from folx.api import FwdJacobian, FwdLaplArray
 
@@ -18,6 +13,7 @@
 from .mhsea import reference_mhsea_kernel
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_input_mask_block_spec,
@@ -158,9 +154,7 @@ def mhsa_forward_laplacian(
                 dtype=q.dtype,  # o.laplacian
             ),
         ],
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_forward_laplacian',
@@ -593,9 +587,7 @@ def mhsea_forward_laplacian(
                 dtype=v.dtype,  # o.laplacian
            ),
         ],
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_forward_laplacian',
diff --git a/folx/experimental/pallas/attention/mhsa.py b/folx/experimental/pallas/attention/mhsa.py
index 3276adb..67ea1bf 100644
--- a/folx/experimental/pallas/attention/mhsa.py
+++ b/folx/experimental/pallas/attention/mhsa.py
@@ -6,13 +6,9 @@
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
-try:
-    from jax.experimental.pallas import triton as plgpu
-except ImportError:
-    from jax.experimental.pallas import gpu as plgpu
-
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_mask_block_spec,
@@ -63,9 +59,7 @@ def mhsa(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa',
diff --git a/folx/experimental/pallas/attention/mhsea.py b/folx/experimental/pallas/attention/mhsea.py
index bd19586..3970924 100644
--- a/folx/experimental/pallas/attention/mhsea.py
+++ b/folx/experimental/pallas/attention/mhsea.py
@@ -6,13 +6,9 @@
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
-try:
-    from jax.experimental.pallas import triton as plgpu
-except ImportError:
-    from jax.experimental.pallas import gpu as plgpu
-
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_lse_block_spec,
@@ -63,9 +59,7 @@ def mhsea(
                 shape=(batch_len, seq_len, num_heads), dtype=q.dtype
             ),  # lse
         ],
-        compiler_params=plgpu.TritonCompilerParams(
-            num_warps=num_warps, num_stages=num_stages
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea',
diff --git a/folx/experimental/pallas/attention/utils.py b/folx/experimental/pallas/attention/utils.py
index 70109ce..142fac5 100644
--- a/folx/experimental/pallas/attention/utils.py
+++ b/folx/experimental/pallas/attention/utils.py
@@ -4,6 +4,14 @@
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
+try:
+    from jax.experimental.pallas import triton as plgpu
+except ImportError:
+    from jax.experimental.pallas import gpu as plgpu
+
+
+from packaging.version import Version
+
 
 def sum_columns(x: jax.Array) -> jax.Array:
     return x.sum(axis=1, keepdims=True)
@@ -210,3 +218,10 @@ def big_number(dtype) -> float:
         return 1e40
     else:
         raise ValueError(f'Unexpected dtype {dtype}')
+
+
+def compiler_params(num_warps, num_stages):
+    if Version(jax.__version__) >= Version('0.4.34'):
+        return plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages)
+    else:
+        return dict(triton=dict(num_warps=num_warps, num_stages=num_stages))
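
Note (not part of the patches): a minimal usage sketch of the version gate that the final patch adds to utils.py. It assumes a checkout of folx with these patches applied is importable; the num_warps/num_stages values are illustrative only.

    import jax
    from folx.experimental.pallas.attention.utils import compiler_params

    # Resolve compiler parameters for whichever JAX version is installed.
    params = compiler_params(num_warps=4, num_stages=2)

    # Per the version check in utils.py: on jax >= 0.4.34 this is a
    # plgpu.TritonCompilerParams instance, on older versions it is the legacy
    # nested dict {'triton': {'num_warps': 4, 'num_stages': 2}}. Either form is
    # passed straight through as the compiler_params argument of pl.pallas_call.
    print(jax.__version__, params)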