diff --git a/flash-attn2/build.toml b/flash-attn2/build.toml index 7d61631f..d9c11d97 100644 --- a/flash-attn2/build.toml +++ b/flash-attn2/build.toml @@ -169,7 +169,7 @@ src = [ backend = "xpu" depends = [ "torch", - "cutlass_sycl", + "sycl_tla", ] src = [ "flash_attn_xpu/flash_api.cpp", diff --git a/flash-attn2/flake.lock b/flash-attn2/flake.lock index 801c84bc..f0e88e05 100644 --- a/flash-attn2/flake.lock +++ b/flash-attn2/flake.lock @@ -41,16 +41,15 @@ "rust-overlay": "rust-overlay" }, "locked": { - "lastModified": 1769443799, - "narHash": "sha256-iSJyElXgv2SWMXwJuTAHFcZQI8ViWNJpMKFZ1JxyfPE=", + "lastModified": 1772639128, + "narHash": "sha256-2tmxkshScPr/NK3JFJ5KVr/U1qhmdsqlut+Kfe7zLNY=", "owner": "huggingface", "repo": "kernels", - "rev": "30685a79203ed855b328cee874698d13e82fb3ae", + "rev": "20d9f69982d233bcd3fdb3d0daf35fe03a233413", "type": "github" }, "original": { "owner": "huggingface", - "ref": "v0.12.1", "repo": "kernels", "type": "github" } diff --git a/flash-attn2/flake.nix b/flash-attn2/flake.nix index ff53a773..22ad0199 100644 --- a/flash-attn2/flake.nix +++ b/flash-attn2/flake.nix @@ -2,7 +2,7 @@ description = "Flake for flash-attn kernel"; inputs = { - kernel-builder.url = "github:huggingface/kernels/v0.12.1"; + kernel-builder.url = "github:huggingface/kernels"; }; outputs =