Skip to content

Commit 62c74c8

Browse files
committed
test: setup
1 parent e08c951 commit 62c74c8

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

ext/LuxReactantExt/batched_jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ struct ApplyWithReshape{F,SZ}
1414
sz::SZ
1515
end
1616

17-
(f::ApplyWithReshape)(x) = f.f(reshape(x, f.sz))
17+
(f::ApplyWithReshape)(x) = reshape(f.f(reshape(x, f.sz)), :, size(x, ndims(x)))
1818

1919
function (f::ApplyWithReshape)(y, x)
2020
res = f.f(reshape(x, f.sz))

test/reactant/autodiff_tests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,31 @@
5555
end
5656
end
5757
end
58+
59+
@testitem "AutoDiff APIs: Batched Jacobian" tags = [:reactant] setup = [SharedTestSetup] begin
60+
using Reactant, Lux, Zygote, Random, ForwardDiff, Enzyme
61+
62+
fn(x) = reshape(sum(abs2, x; dims=(1, 2, 3)), 1, :)
63+
64+
rng = Random.default_rng()
65+
66+
@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
67+
if mode == "amdgpu"
68+
@warn "Skipping AMDGPU tests for Reactant"
69+
continue
70+
end
71+
72+
if ongpu
73+
Reactant.set_default_backend("gpu")
74+
else
75+
Reactant.set_default_backend("cpu")
76+
end
77+
78+
dev = reactant_device(; force=true)
79+
80+
x = rand(rng, Float32, 2, 3, 4, 5)
81+
x_ra = dev(x)
82+
83+
# TODO: ....
84+
end
85+
end

0 commit comments

Comments
 (0)