21 changes: 6 additions & 15 deletions folx/experimental/pallas/attention/custom_gradients.py
@@ -10,6 +10,7 @@
 from .mhsea import mhsea_kernel, reference_mhsea_kernel
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_lse_block_spec,
@@ -53,9 +54,7 @@ def mhsa_forward(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_forward',
@@ -113,9 +112,7 @@ def mhsa_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_backward',
@@ -268,9 +265,7 @@ def mhsea_forward(
                 shape=(batch_len, seq_len, num_heads), dtype=v.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhea_forward',
@@ -372,9 +367,7 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, seq_len), dtype=e.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_backward_q_vjp',
@@ -433,9 +426,7 @@ def mhsea_backward(
                 shape=(batch_len, seq_len, num_heads, head_len), dtype=v.dtype
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_backward_kv_vjp',
9 changes: 3 additions & 6 deletions folx/experimental/pallas/attention/forward_laplacian.py
@@ -13,6 +13,7 @@
 from .mhsea import reference_mhsea_kernel
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_input_mask_block_spec,
@@ -153,9 +154,7 @@ def mhsa_forward_laplacian(
                 dtype=q.dtype,  # o.laplacian
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa_forward_laplacian',
@@ -588,9 +587,7 @@ def mhsea_forward_laplacian(
                 dtype=v.dtype,  # o.laplacian
             ),
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea_forward_laplacian',
5 changes: 2 additions & 3 deletions folx/experimental/pallas/attention/mhsa.py
@@ -8,6 +8,7 @@
 
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_mask_block_spec,
@@ -58,9 +59,7 @@ def mhsa(
         out_shape=jax.ShapeDtypeStruct(
             shape=(batch_len, seq_len, num_heads, head_len), dtype=q.dtype
         ),
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsa',
5 changes: 2 additions & 3 deletions folx/experimental/pallas/attention/mhsea.py
@@ -8,6 +8,7 @@
 
 from .utils import (
     big_number,
+    compiler_params,
     compute_q_and_kv_block_len,
     create_grid,
     get_lse_block_spec,
@@ -58,9 +59,7 @@ def mhsea(
                 shape=(batch_len, seq_len, num_heads), dtype=q.dtype
             ),  # lse
         ],
-        compiler_params=dict(
-            triton=dict(num_warps=num_warps, num_stages=num_stages)
-        ),
+        compiler_params=compiler_params(num_warps=num_warps, num_stages=num_stages),
         debug=False,
         interpret=interpret,
         name='mhsea',
15 changes: 15 additions & 0 deletions folx/experimental/pallas/attention/utils.py
@@ -4,6 +4,14 @@
 import jax.numpy as jnp
 from jax.experimental import pallas as pl
 
+try:
+    from jax.experimental.pallas import triton as plgpu
+except ImportError:
+    from jax.experimental.pallas import gpu as plgpu
+
+
+from packaging.version import Version
+
 
 def sum_columns(x: jax.Array) -> jax.Array:
     return x.sum(axis=1, keepdims=True)
@@ -210,3 +218,10 @@ def big_number(dtype) -> float:
         return 1e40
     else:
         raise ValueError(f'Unexpected dtype {dtype}')
+
+
+def compiler_params(num_warps, num_stages):
+    if Version(jax.__version__) >= Version('0.4.34'):
+        return plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages)
+    else:
+        return dict(triton=dict(num_warps=num_warps, num_stages=num_stages))
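
For reference, the new helper keeps every pl.pallas_call site version-agnostic: on JAX >= 0.4.34 it returns the plgpu.TritonCompilerParams dataclass that newer Pallas expects, and on older releases it falls back to the legacy nested dict. A minimal usage sketch (not part of the diff; the kernel body, shapes, and the num_warps/num_stages values are illustrative placeholders) could look like this:

# Hypothetical usage sketch of the compiler_params helper added above.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

from folx.experimental.pallas.attention.utils import compiler_params

def add_one_kernel(x_ref, o_ref):
    # Trivial elementwise kernel to keep the sketch self-contained.
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((8, 128), dtype=jnp.float32)
out = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # Resolves to plgpu.TritonCompilerParams on jax >= 0.4.34 and to the
    # legacy dict(triton=dict(...)) form on earlier releases.
    compiler_params=compiler_params(num_warps=4, num_stages=2),
)(x)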