Skip to content

Commit f593c26

Browse files
authored
feat: precompile common workloads (#1485)

* feat: precompile common workloads
* fix: try reactivating precompile
* fix: add behind a check
* perf: run compilation time experiments on each commit
* fix: add all deps
* fix: only pre-compile forward pass for now
* test: disable some testing (drop me)
* fix: try moving to module
* fix: bad rebase
* fix: clear oc cache
1 parent 18336f8 commit f593c26

File tree

4 files changed

+90
-7
lines changed

4 files changed

+90
-7
lines changed

Project.toml

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
5656
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5757

5858
[sources]
59-
LuxCore = {path = "lib/LuxCore"}
60-
LuxLib = {path = "lib/LuxLib"}
61-
MLDataDevices = {path = "lib/MLDataDevices"}
62-
WeightInitializers = {path = "lib/WeightInitializers"}
59+
LuxCore = { path = "lib/LuxCore" }
60+
LuxLib = { path = "lib/LuxLib" }
61+
MLDataDevices = { path = "lib/MLDataDevices" }
62+
WeightInitializers = { path = "lib/WeightInitializers" }
6363

6464
[extensions]
6565
LuxComponentArraysExt = "ComponentArrays"
@@ -71,7 +71,16 @@ LuxMLUtilsExt = "MLUtils"
7171
LuxMPIExt = "MPI"
7272
LuxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
7373
LuxMooncakeExt = "Mooncake"
74-
LuxReactantExt = ["Enzyme", "Reactant"]
74+
LuxReactantExt = [
75+
"Enzyme",
76+
"Reactant",
77+
"ReactantCore",
78+
"LuxLib",
79+
"MLDataDevices",
80+
"Optimisers",
81+
"NNlib",
82+
"Statistics",
83+
]
7584
LuxReverseDiffExt = ["FunctionWrappers", "ReverseDiff"]
7685
LuxSimpleChainsExt = "SimpleChains"
7786
LuxTrackerExt = "Tracker"

ext/LuxReactantExt/LuxReactantExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
module LuxReactantExt
22

3+
using ADTypes: AutoEnzyme
34
using Enzyme: Enzyme, Active, Const, Duplicated
45
using Functors: Functors
56
using Preferences: load_preference
7+
using Random: Random
68
using Optimisers: Optimisers
79
using Reactant:
810
Reactant,
@@ -23,7 +25,8 @@ using Lux: Lux, LuxOps, Training, Utils, StatefulLuxLayer
2325
using Lux.Training: TrainingBackendCache, ReactantBackend
2426
using Lux: get_time_dimension, time_dimension_size, init_recurrent_state
2527
using LuxCore: LuxCore, AbstractLuxLayer
26-
using MLDataDevices: MLDataDevices, ReactantDevice, get_device
28+
using LuxLib: LuxLib
29+
using MLDataDevices: MLDataDevices, ReactantDevice, reactant_device, get_device
2730

2831
Lux.is_extension_loaded(::Val{:Reactant}) = true
2932

@@ -61,4 +64,6 @@ include("layers.jl")
6164
include("tracing.jl")
6265
include("saved_model.jl")
6366

67+
include("precompile_workloads.jl")
68+
6469
end
ext/LuxReactantExt/precompile_workloads.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using PrecompileTools: @setup_workload, @compile_workload
2+
3+
# Loss helpers for the precompile workloads below. They are intentionally generic in
# `model`, `ps`, `st`, and `data` so the exact same methods get compiled here as when
# user code calls `Lux.Training.single_train_step` with an equivalent loss.
module PrecompileWorkloads

# Loss for models whose forward pass returns a nested `((y, extra), st)` tuple
# (e.g. attention layers that also return attention weights): squared L2 norm of `y`.
function sumabs2attnloss(model, ps, st, data)
    (output, _), updated_st = model(data, ps, st)
    return sum(abs2, output), updated_st, NamedTuple()
end

# Loss for models whose forward pass returns a plain `(y, st)` tuple:
# squared L2 norm of `y`.
function sumabs2loss(model, ps, st, data)
    output, updated_st = model(data, ps, st)
    return sum(abs2, output), updated_st, NamedTuple()
end

end
end
16+
17+
# Precompile representative Reactant workloads — an attention model and a conv+dense
# model, each with one inference forward pass and one AutoEnzyme training step — so
# downstream users pay less first-call latency.
# Guard: only run when the Reactant_jll runtime artifact is available on this platform.
if Reactant.Reactant_jll.is_available()
    @setup_workload begin
        # Remember the user's default XLA backend so it can be restored afterwards.
        orig_backend = Reactant.XLA.default_backend()
        Reactant.set_default_backend("cpu") # always precompile on CPU

        # NOTE(review): `force=true` presumably errors instead of silently falling
        # back to a non-Reactant device — confirm against MLDataDevices docs.
        dev = reactant_device(; force=true)

        # attention model
        mha = Lux.MultiHeadAttention(4; nheads=2)
        ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev

        # (q, k, v) inputs: 4 features, sequence length 3, batch of 2.
        q = rand(Float32, (4, 3, 2)) |> dev
        k = rand(Float32, (4, 3, 2)) |> dev
        v = rand(Float32, (4, 3, 2)) |> dev

        # convolution + dense model
        conv_model = Lux.Chain(
            Lux.Conv((3, 3), 3 => 32),
            Lux.Conv((3, 3), 32 => 64),
            Lux.GlobalMaxPool(),
            Lux.FlattenLayer(),
            Lux.Dense(64 => 10),
        )
        ps_conv_model, st_conv_model = Lux.setup(Random.default_rng(), conv_model) |> dev

        # 28x28, 3-channel input, batch of 2.
        x = rand(Float32, (28, 28, 3, 2)) |> dev

        @compile_workload begin
            # Inference forward pass of the attention model (testmode state).
            @compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))

            # One AutoEnzyme training step on the attention model.
            Lux.Training.single_train_step(
                AutoEnzyme(),
                PrecompileWorkloads.sumabs2attnloss,
                (q, k, v),
                Lux.Training.TrainState(mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)),
            )

            # Inference forward pass of the conv+dense model (testmode state).
            @compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))

            # One AutoEnzyme training step on the conv+dense model.
            Lux.Training.single_train_step(
                AutoEnzyme(),
                PrecompileWorkloads.sumabs2loss,
                x,
                Lux.Training.TrainState(
                    conv_model, ps_conv_model, st_conv_model, Optimisers.Adam(0.001f0)
                ),
            )
        end

        # Drop caches built up during the workload ("clear oc cache" per the commit
        # message) and restore the user's backend — NOTE(review): confirm what
        # `clear_oc_cache` retains/discards against the Reactant documentation.
        Reactant.clear_oc_cache()
        Reactant.set_default_backend(orig_backend)
    end
end

src/Lux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ include("serialization/serialization.jl")
123123
# Deprecations for v2
124124
include("deprecations.jl")
125125

126-
# Precompile common workloads
126+
# Precompile Common Workloads
127127
include("precompile_workloads.jl")
128128

129129
# Layers

0 commit comments

Comments
 (0)