|
| 1 | +using Boltz: Vision |
| 2 | +using Lux: Lux |
| 3 | +using MLDataDevices: AbstractDevice, CPUDevice, CUDADevice |
| 4 | +using Random: Random |
| 5 | +using Reactant: Reactant, @compile |
| 6 | + |
| 7 | +using Enzyme: Enzyme |
| 8 | +using Zygote: Zygote |
| 9 | + |
| 10 | +# Helper Functions |
# Device synchronization hook, selected by dispatch on the device type.

# CPU: no-op, always returns `nothing`.
@inline function synchronize(::CPUDevice)
    return nothing
end

# CUDA: forwards to `CUDA.synchronize()` and returns whatever that call returns.
# NOTE(review): `CUDA` is not among the `using` lines visible in this file; presumably
# it is brought into scope by whatever includes this file. Confirm.
@inline function synchronize(::CUDADevice)
    return CUDA.synchronize()
end
| 13 | + |
# Memory reclamation hook, selected by dispatch on the device type.

# CPU: runs Julia's garbage collector via `GC.gc()`.
@inline function reclaim(::CPUDevice)
    return GC.gc()
end

# CUDA: forwards to `CUDA.reclaim()`.
# NOTE(review): `CUDA` is not among the `using` lines visible in this file; presumably
# it is brought into scope by whatever includes this file. Confirm.
@inline function reclaim(::CUDADevice)
    return CUDA.reclaim()
end
| 16 | + |
# Scalar reduction helpers: sum of `abs2` over a model's output.

# Explicit parameter/state form: only the first element of what `Lux.apply` returns
# is reduced; the rest is discarded.
@inline function sumabs2(model, x, p, st)
    output = first(Lux.apply(model, x, p, st))
    return sum(abs2, output)
end

# Plain-callable form: `model(x)` is reduced directly.
@inline function sumabs2(model, x)
    output = model(x)
    return sum(abs2, output)
end
| 19 | + |
# Map a benchmark-group label to a device object.
#
# "CPU"  -> `CPUDevice()`
# "CUDA" -> `CUDADevice()`
# Any other label raises an error (via `error`) that names the offending label.
function benchmark_group_to_backend(benchmark_group::String)
    if benchmark_group == "CPU"
        return CPUDevice()
    elseif benchmark_group == "CUDA"
        return CUDADevice()
    else
        return error("Unknown backend: $(benchmark_group)")
    end
end
| 25 | + |
# Initialise a Lux model and, optionally, a random Float32 input for it.
#
# Returns `(ps, st)` when `x_dims === nothing`, otherwise `(x, ps, st)` where `x` is
# `randn(rng, Float32, x_dims)`. The same RNG is used first for `Lux.setup` and then
# for the input, in that order.
function general_lux_setup(model, x_dims)
    rng = Random.default_rng()  # don't use any other rng
    params, states = Lux.setup(rng, model)
    if x_dims === nothing
        return params, states
    end
    input = randn(rng, Float32, x_dims)
    return input, params, states
end
| 33 | + |
# Entry point: populate `suite` with every benchmark defined in this file for the
# given backend label. The label is converted to a device object with
# `benchmark_group_to_backend`, which errors on an unrecognised label.
# Always returns `nothing`; the suite is updated in place by the callee.
function setup_benchmarks!(suite::BenchmarkGroup, backend::String)
    setup_vit_benchmark!(suite, backend, benchmark_group_to_backend(backend))
    return nothing
end
| 41 | + |
| 42 | +# Lux Benchmarks |
# Register forward-pass benchmarks for `Vision.ViT` over every combination of
# mode (:tiny, :small, :base) and batch size (4, 16, 32), with input dimensions
# (256, 256, 3, bsize). Modes vary in the outer loop, batch sizes in the inner one.
# Returns `nothing`.
function setup_vit_benchmark!(suite::BenchmarkGroup, backend, dev::AbstractDevice)
    modes = (:tiny, :small, :base)
    batch_sizes = (4, 16, 32)

    for mode in modes
        for bsize in batch_sizes
            label = "ViT $(mode) (256 x 256 x 3 x $(bsize))"
            vit = Vision.ViT(mode)
            input_dims = (256, 256, 3, bsize)
            setup_lux_forward_pass_benchmark!(suite, label, backend, vit, input_dims, dev)
        end
    end

    return nothing
end
| 52 | + |
# Register two forward-pass benchmarks for `model` under
# `suite[benchmark_name]["forward"][backend]`:
#
#   "Lux"      - calls `Lux.apply` on data moved to `dev`, then `synchronize(dev)`.
#   "Reactant" - calls a function produced by `@compile` on `Reactant.to_rarray`
#                copies of the data, then `Reactant.synchronize` on the output.
#
# In both cases the `setup` phase builds the input, parameters and test-mode state
# with `general_lux_setup` / `Lux.testmode`, and runs `GC.gc()` plus `reclaim(dev)`
# before and after that preparation. Returns `nothing`.
function setup_lux_forward_pass_benchmark!(
    suite::BenchmarkGroup,
    benchmark_name::String,
    backend::String,
    model,
    x_dims,
    dev::AbstractDevice,
)
    # Both entries live in the same sub-group, so look it up once.
    forward_group = suite[benchmark_name]["forward"][backend]

    forward_group["Lux"] = @benchmarkable begin
        Lux.apply($model, input, params, states_test)
        synchronize($dev)
    end setup = begin
        GC.gc()
        reclaim($dev)
        # The whole (x, ps, st) tuple is passed through the device object at once.
        input, params, states = $dev(general_lux_setup($model, $x_dims))
        states_test = Lux.testmode(states)
        GC.gc()
        reclaim($dev)
    end

    forward_group["Reactant"] = @benchmarkable begin
        output, _ = compiled_apply($model, input_ra, params_ra, states_test_ra)
        Reactant.synchronize(output)
    end setup = begin
        GC.gc()
        reclaim($dev)
        # Here the data is not passed through `dev`; it is converted with
        # `Reactant.to_rarray` instead.
        input, params, states = general_lux_setup($model, $x_dims)
        states_test = Lux.testmode(states)
        input_ra = Reactant.to_rarray(input)
        params_ra = Reactant.to_rarray(params)
        states_test_ra = Reactant.to_rarray(states_test)
        # Compilation happens in `setup`, so it is outside the timed expression.
        compiled_apply = @compile Lux.apply($model, input_ra, params_ra, states_test_ra)
        GC.gc()
        reclaim($dev)
    end

    return nothing
end
0 commit comments