Skip to content

Commit 1e5bc44

Browse files
Copilotavik-palwsmoses
authored
feat: add export_to_enzymejax for automated JAX/EnzymeAD integration (#1934)
* Initial plan * Add export_to_enzymeax function for JAX integration Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> * Add comprehensive tests for export_to_enzymeax Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> * Fix code review issues: binary mode, ComplexF16 support, Python indentation Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> * Add documentation for export_to_enzymeax function Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> * chore: run formatter * chore: cleanup * fix: proper export * Use existing seen_args infrastructure instead of manual inverse map Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com> * Rename export_to_enzymeax to export_to_enzymejax Co-authored-by: wsmoses <30564094+wsmoses@users.noreply.github.com> * feat: add path to docstring * fix: use NPZ for proper export * feat: add size checks * feat: automatically run jit * feat: preserve sharding * test: exported functions * test: cleanup * Apply suggestions from code review * Apply suggestion from @avik-pal * Apply suggestion from @avik-pal --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: avik-pal <30564094+avik-pal@users.noreply.github.com> Co-authored-by: Avik Pal <avikpal@mit.edu> Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com> Co-authored-by: wsmoses <30564094+wsmoses@users.noreply.github.com>
1 parent e9fbc25 commit 1e5bc44

File tree

13 files changed

+691
-47
lines changed

13 files changed

+691
-47
lines changed

CondaPkg.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[deps]
2-
python = "<=3.13,>=3.9,<4"
2+
python = "<=3.12,>=3.9,<4"
33

44
[pip.deps]
5-
jax = ">= 0.6"
5+
jax = ">= 0.5"
66
tensorflow = ">= 2.17"
7-
numpy = ">= 2"
7+
numpy = ">= 1, >= 2"

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
3939
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
4040
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
4141
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
42+
NPZ = "15e1cf62-19b3-5cfa-8e77-841668bca605"
4243
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
4344
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
4445
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
@@ -62,6 +63,7 @@ ReactantFloat8sExt = "Float8s"
6263
ReactantKernelAbstractionsExt = "KernelAbstractions"
6364
ReactantMPIExt = "MPI"
6465
ReactantNNlibExt = ["NNlib", "Statistics"]
66+
ReactantNPZExt = "NPZ"
6567
ReactantOffsetArraysExt = "OffsetArrays"
6668
ReactantOneHotArraysExt = "OneHotArrays"
6769
ReactantPythonCallExt = "PythonCall"
@@ -96,6 +98,7 @@ Libdl = "1.10"
9698
LinearAlgebra = "1.10"
9799
MPI = "0.20"
98100
NNlib = "0.9.26"
101+
NPZ = "0.4"
99102
OffsetArrays = "1"
100103
OneHotArrays = "0.2.10"
101104
OrderedCollections = "1"

docs/src/api/serialization.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,22 @@ or [TensorFlow Hub](https://tensorflow.org/hub). Refer to the
2727
```@docs
2828
Reactant.Serialization.export_as_tf_saved_model
2929
```
30+
31+
## Exporting to JAX via EnzymeAD
32+
33+
!!! note "Load NPZ"
34+
35+
This export functionality requires the `NPZ` package to be loaded.
36+
37+
This export functionality generates:
38+
39+
1. A `.mlir` file containing the StableHLO representation of your Julia function
40+
2. Input `.npz` files containing the input arrays for the function
41+
3. A Python script that wraps the function for use with `enzyme_ad.jax.hlo_call`
42+
43+
The generated Python script can be immediately used with JAX and EnzymeAD without any
44+
additional Julia dependencies.
45+
46+
```@docs
47+
Reactant.Serialization.EnzymeJAX.export_to_enzymejax
48+
```

ext/ReactantNPZExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
module ReactantNPZExt

using NPZ: npzwrite
using Reactant.Serialization: Serialization, EnzymeJAX

# Advertise that NPZ-based serialization is available once this extension loads.
Serialization.serialization_supported(::Val{:NPZ}) = true

"""
    EnzymeJAX.save_inputs_npz_impl(output_path, inputs)

Write every entry of `inputs` (name => array or scalar) into a single `.npz`
archive at `output_path`, returning `output_path`.

Arrays are axis-reversed via `permutedims` before writing so that Julia's
column-major data reads back with the expected shape in row-major NumPy;
scalars are written unchanged.
"""
function EnzymeJAX.save_inputs_npz_impl(
    output_path::String, inputs::Dict{String,<:Union{AbstractArray,Number}}
)
    # Reverse the dimension order of each array for NumPy's row-major layout.
    converted = Dict{String,Union{AbstractArray,Number}}()
    for (key, value) in inputs
        if value isa Number
            converted[key] = value
        else
            converted[key] = permutedims(value, reverse(1:ndims(value)))
        end
    end

    # Bundle everything into one .npz archive.
    npzwrite(output_path, converted)
    return output_path
end

end # module

ext/ReactantPythonCallExt/ReactantPythonCallExt.jl

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module ReactantPythonCallExt
33
using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist
44
using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay
55
using Reactant.Ops: @opcall
6+
using Reactant.Serialization: NUMPY_SIMPLE_TYPES
67

78
const jaxptr = Ref{Py}()
89
const jnpptr = Ref{Py}()
@@ -15,24 +16,6 @@ const npptr = Ref{Py}()
1516

1617
const SAVED_MODEL_EXPORT_SUPPORTED = Ref{Bool}(false)
1718

18-
const NUMPY_SIMPLE_TYPES = Dict(
19-
Bool => :bool,
20-
Int8 => :int8,
21-
Int16 => :int16,
22-
Int32 => :int32,
23-
Int64 => :int64,
24-
UInt8 => :uint8,
25-
UInt16 => :uint16,
26-
UInt32 => :uint32,
27-
UInt64 => :uint64,
28-
Float16 => :float16,
29-
Float32 => :float32,
30-
Float64 => :float64,
31-
ComplexF16 => :complex16,
32-
ComplexF32 => :complex32,
33-
ComplexF64 => :complex64,
34-
)
35-
3619
function __init__()
3720
try
3821
jaxptr[] = pyimport("jax")

src/Compiler.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,7 +1411,7 @@ function __get_compile_options_and_kwargs(;
14111411
)
14121412
end
14131413

1414-
function compile_mlir(f, args; client=nothing, kwargs...)
1414+
function compile_mlir(f, args; client=nothing, drop_unsupported_attributes=false, kwargs...)
14151415
client = client !== nothing ? client : XLA.default_backend()
14161416
backend = XLA.platform_name(client)
14171417

@@ -1441,6 +1441,11 @@ function compile_mlir(f, args; client=nothing, kwargs...)
14411441
mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
14421442
)
14431443

1444+
if drop_unsupported_attributes
1445+
# Drop some of our attributes
1446+
run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")
1447+
end
1448+
14441449
return mod, mlir_fn_res
14451450
end
14461451

@@ -3571,6 +3576,9 @@ function compile_xla(
35713576
mod, f; mlir_fn_res.num_partitions, mlir_fn_res.num_replicas
35723577
)
35733578

3579+
# Drop some of our attributes
3580+
run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")
3581+
35743582
# compile MLIR module to XLA executable
35753583
global_device_ids = collect(Int64, mlir_fn_res.global_device_ids)
35763584
mlir_fn_res.is_sharded && (device = nothing)
@@ -3584,9 +3592,6 @@ function compile_xla(
35843592
module_string = ""
35853593
end
35863594

3587-
# Drop some of our attributes
3588-
run_pass_pipeline!(mod, "drop-unsupported-attributes", "drop_enzymexla_attributes")
3589-
35903595
if before_xla_optimizations
35913596
exec = nothing
35923597
hlo_modules = XLA.HloModule(mod)

src/Sharding.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ function HloSharding(sharding::NamedSharding, client::XLA.IFRT.Client, _, x)
949949
data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding)
950950

951951
# XXX: Can we auto-pad this case too? Will think about it later, for now use
952-
# NamedSharidng
952+
# NamedSharding
953953
return data, ShardInfo(hlo_sharding, device_to_array_slices), nothing
954954
end
955955

@@ -997,7 +997,7 @@ function (sharding::HloSharding)(
997997
data = XLA.IFRT.AsyncArray(client, x, ifrt_sharding)
998998

999999
# XXX: Can we auto-pad this case too? Will think about it later, for now use
1000-
# NamedSharidng
1000+
# NamedSharding
10011001
return data, ShardInfo(sharding, device_to_array_slices), nothing
10021002
end
10031003

0 commit comments

Comments
 (0)