diff --git a/folx/experimental/pallas/attention/custom_gradients.py b/folx/experimental/pallas/attention/custom_gradients.py
index 6615723..e20609c 100644
--- a/folx/experimental/pallas/attention/custom_gradients.py
+++ b/folx/experimental/pallas/attention/custom_gradients.py
@@ -10,6 +10,7 @@ from .mhsea import mhsea_kernel, reference_mhsea_kernel
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_lse_block_spec,
@@ -53,9 +54,7 @@ def mhsa_forward(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_forward',
@@ -113,9 +112,7 @@ def mhsa_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_backward',
@@ -268,9 +265,7 @@ def mhsea_forward(
                 shape=(batch_len, seq_len, num_heads), dtype=v.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhea_forward',
@@ -372,9 +367,7 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_backward_q_vjp',
@@ -433,9 +426,7 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_backward_kv_vjp',
diff --git a/folx/experimental/pallas/attention/forward_laplacian.py b/folx/experimental/pallas/attention/forward_laplacian.py
index 236f17f..f19d38a 100644
--- a/folx/experimental/pallas/attention/forward_laplacian.py
+++ b/folx/experimental/pallas/attention/forward_laplacian.py
@@ -13,6 +13,7 @@ from .mhsea import reference_mhsea_kernel
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_input_mask_block_spec,
@@ -153,9 +154,7 @@ def mhsa_forward_laplacian(
                 dtype=q.dtype,  # o.laplacian
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_forward_laplacian',
@@ -588,9 +587,7 @@ def mhsea_forward_laplacian(
                 dtype=v.dtype,  # o.laplacian
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_forward_laplacian',
diff --git a/folx/experimental/pallas/attention/mhsa.py b/folx/experimental/pallas/attention/mhsa.py
index 32dc4b9..67ea1bf 100644
--- a/folx/experimental/pallas/attention/mhsa.py
+++ b/folx/experimental/pallas/attention/mhsa.py
@@ -8,6 +8,7 @@
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_mask_block_spec,
@@ -58,9 +59,7 @@ def mhsa(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa',
diff --git a/folx/experimental/pallas/attention/mhsea.py b/folx/experimental/pallas/attention/mhsea.py
index 543f75e..3970924 100644
--- a/folx/experimental/pallas/attention/mhsea.py
+++ b/folx/experimental/pallas/attention/mhsea.py
@@ -8,6 +8,7 @@
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_lse_block_spec,
@@ -58,9 +59,7 @@ def mhsea(
                 shape=(batch_len, seq_len, num_heads), dtype=q.dtype
             ),  # lse
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea',
diff --git a/folx/experimental/pallas/attention/utils.py b/folx/experimental/pallas/attention/utils.py
index 70109ce..142fac5 100644
--- a/folx/experimental/pallas/attention/utils.py
+++ b/folx/experimental/pallas/attention/utils.py
@@ -4,6 +4,14 @@ import jax.numpy as jnp
 from jax.experimental import pallas as pl
+try:
+    from jax.experimental.pallas import triton as plgpu
+except ImportError:
+    from jax.experimental.pallas import gpu as plgpu
+
+
+from packaging.version import Version
+
 
 
 def sum_columns(x: jax.Array) -> jax.Array:
     return x.sum(axis=1, keepdims=True)
@@ -210,3 +218,10 @@ def big_number(dtype) -> float:
         return 1e40
     else:
         raise ValueError(f'Unexpected dtype {dtype}')
+
+
+def compiler_params(num_warps, num_stages):
+    if Version(jax.__version__) >= Version('0.4.34'):
+        return plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages)
+    else:
+        return dict(triton=dict(num_warps=num_warps, num_stages=num_stages))
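
Note (not part of the patch): the new compiler_params helper dispatches on the installed jax version, returning plgpu.TritonCompilerParams on jax >= 0.4.34 and the legacy nested dict otherwise. The following is a minimal smoke-test sketch of that dispatch; it assumes the helper is importable as folx.experimental.pallas.attention.utils.compiler_params (path taken from the diff above) and that the packaging dependency is installed.

# Minimal smoke test for the version dispatch in compiler_params.
# Assumption: the import path below matches the utils.py file touched in this patch.
import jax
from packaging.version import Version

from folx.experimental.pallas.attention.utils import compiler_params

params = compiler_params(num_warps=4, num_stages=2)

if Version(jax.__version__) >= Version('0.4.34'):
    # Newer jax: expect the TritonCompilerParams dataclass from jax.experimental.pallas.triton.
    from jax.experimental.pallas import triton as plgpu

    assert isinstance(params, plgpu.TritonCompilerParams)
else:
    # Older jax: expect the legacy nested-dict form previously passed to pl.pallas_call.
    assert params == dict(triton=dict(num_warps=4, num_stages=2))

print('compiler_params dispatch OK:', params)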