Skip to content

Commit 9dfb17d

Browse files
Merge pull request #383 from Ambar-13/fix-saveat-kernel
Resolve Metal.jl type instability for saveat and literals; Metal and AMDGPU build fix
2 parents a802496 + 70cf542 commit 9dfb17d

File tree

5 files changed

+132
-129
lines changed

5 files changed

+132
-129
lines changed

src/ensemblegpukernel/lowerlevel_solve.jl

Lines changed: 97 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
3636

3737
prob = convert(ImmutableODEProblem, prob)
3838
dt = convert(eltype(prob.tspan), dt)
39+
saveat_converted = nothing
3940

4041
if saveat === nothing
4142
if save_everystep
@@ -51,34 +52,30 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
5152
fill!(ts, prob.tspan[1])
5253
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
5354
else
54-
saveat = if saveat isa AbstractRange
55-
_saveat = range(convert(eltype(prob.tspan), first(saveat)),
56-
convert(eltype(prob.tspan), last(saveat)),
57-
length = length(saveat))
58-
convert(
59-
StepRangeLen{
60-
eltype(_saveat),
61-
eltype(_saveat),
62-
eltype(_saveat),
63-
eltype(_saveat) === Float32 ? Int32 : Int64
64-
},
65-
_saveat)
55+
# Get the time type from the problem
56+
Tt = eltype(prob.tspan)
57+
58+
# FIX for Issue #379: Convert saveat to proper type
59+
saveat_converted = if saveat isa AbstractRange
60+
Tt.(collect(range(Tt(first(saveat)), Tt(last(saveat)), length = length(saveat))))
6661
elseif saveat isa AbstractVector
67-
adapt(backend, convert.(eltype(prob.tspan), saveat))
62+
Tt.(collect(saveat))
6863
else
69-
_saveat = prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
70-
convert(
71-
StepRangeLen{
72-
eltype(_saveat),
73-
eltype(_saveat),
74-
eltype(_saveat),
75-
eltype(_saveat) === Float32 ? Int32 : Int64
76-
},
77-
_saveat)
64+
# saveat is a Number (step size)
65+
t0, tf = Tt.(prob.tspan)
66+
if Tt(saveat) == Tt(0.0)
67+
Tt.([t0, tf])
68+
else
69+
num_points = Int(ceil(abs(tf - t0) / abs(Tt(saveat)))) + 1
70+
Tt.(collect(range(t0, tf, length = num_points)))
71+
end
7872
end
79-
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
73+
74+
saveat_converted = adapt(backend, saveat_converted)
75+
76+
ts = allocate(backend, typeof(dt), (length(saveat_converted), length(probs)))
8077
fill!(ts, prob.tspan[1])
81-
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
78+
us = allocate(backend, typeof(prob.u0), (length(saveat_converted), length(probs)))
8279
end
8380

8481
tstops = adapt(backend, tstops)
@@ -89,15 +86,15 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
8986
@warn "Running the kernel on CPU"
9087
end
9188

92-
kernel(probs, alg, us, ts, dt, callback, tstops, nsteps, saveat,
89+
kernel(probs, alg, us, ts, dt, callback, tstops, nsteps, saveat_converted,
9390
Val(save_everystep);
9491
ndrange = length(probs))
9592

9693
# we build the actual solution object on the CPU because the GPU would create one
9794
# containing CuDeviceArrays, which we cannot use on the host (not GC tracked,
9895
# no useful operations, etc). That's unfortunate though, since this loop is
9996
# generally slower than the entire GPU execution, and necessitates synchronization
100-
#EDIT: Done when using with DiffEqGPU
97+
# EDIT: Done when using with DiffEqGPU
10198
ts, us
10299
end
103100

@@ -111,7 +108,7 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
111108
backend = maybe_prefer_blocks(backend)
112109

113110
dt = convert(eltype(prob.tspan), dt)
114-
111+
saveat_converted = nothing
115112
if saveat === nothing
116113
if save_everystep
117114
len = length(prob.tspan[1]:dt:prob.tspan[2])
@@ -122,20 +119,32 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
122119
fill!(ts, prob.tspan[1])
123120
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
124121
else
125-
saveat = if saveat isa AbstractRange
126-
range(convert(eltype(prob.tspan), first(saveat)),
127-
convert(eltype(prob.tspan), last(saveat)),
128-
length = length(saveat))
122+
# Get the time type from the problem
123+
Tt = eltype(prob.tspan)
124+
125+
# FIX for Issue #379: Convert saveat to proper type
126+
saveat_converted = if saveat isa AbstractRange
127+
Tt.(collect(range(Tt(first(saveat)), Tt(last(saveat)), length = length(saveat))))
129128
elseif saveat isa AbstractVector
130-
convert.(eltype(prob.tspan), adapt(backend, saveat))
129+
Tt.(collect(saveat))
131130
else
132-
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
131+
# saveat is a Number (step size)
132+
t0, tf = Tt.(prob.tspan)
133+
if Tt(saveat) == Tt(0.0)
134+
Tt.([t0, tf])
135+
else
136+
num_points = Int(ceil(abs(tf - t0) / abs(Tt(saveat)))) + 1
137+
Tt.(collect(range(t0, tf, length = num_points)))
138+
end
133139
end
134-
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
140+
141+
ts = allocate(backend, typeof(dt), (length(saveat_converted), length(probs)))
135142
fill!(ts, prob.tspan[1])
136-
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
143+
us = allocate(backend, typeof(prob.u0), (length(saveat_converted), length(probs)))
144+
end
145+
if saveat_converted !== nothing
146+
saveat_converted = adapt(backend, saveat_converted)
137147
end
138-
139148
if alg isa GPUEM
140149
kernel = em_kernel(backend)
141150
elseif alg isa Union{GPUSIEA}
@@ -148,7 +157,7 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
148157
@warn "Running the kernel on CPU"
149158
end
150159

151-
kernel(probs, us, ts, dt, saveat, Val(save_everystep);
160+
kernel(probs, us, ts, dt, saveat_converted, Val(save_everystep);
152161
ndrange = length(probs))
153162
ts, us
154163
end
@@ -184,62 +193,86 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
184193
abstol = 1.0f-6, reltol = 1.0f-3,
185194
debug = false, callback = CallbackSet(nothing), tstops = nothing,
186195
kwargs...)
196+
187197
backend = get_backend(probs)
188198
backend = maybe_prefer_blocks(backend)
189199

190-
prob = convert(ImmutableODEProblem, prob)
200+
# Get the time type from the problem
201+
Tt = eltype(prob.tspan)
202+
203+
# FIX for Issue #379: Convert saveat to eliminate
204+
# StepRangeLen's internal Float64 fields which crash Metal
205+
206+
if saveat !== nothing
207+
if saveat isa Number
208+
# Handle edge case: saveat = 0.0 means only save endpoints
209+
if Tt(saveat) == Tt(0.0)
210+
saveat_converted = Tt.([prob.tspan[1], prob.tspan[2]])
211+
else
212+
# Create proper range with correct type
213+
t0, tf = Tt.(prob.tspan)
214+
215+
# Handle both forward and reverse time integration
216+
num_points = Int(ceil(abs(tf - t0) / abs(Tt(saveat)))) + 1
217+
218+
# Safety check: prevent massive arrays
219+
max_saveat_length = 100_000
220+
if num_points > max_saveat_length
221+
error("saveat would create too many save points ($num_points). " *
222+
"Consider using a larger saveat value.")
223+
end
224+
225+
# Create range and convert to pure Vector{Tt}
226+
saveat_range = range(t0, tf, length = num_points)
227+
saveat_converted = Tt.(collect(saveat_range))
228+
end
229+
elseif saveat isa AbstractRange || saveat isa AbstractArray
230+
# Range or array - convert all elements to Tt
231+
# This eliminates StepRangeLen's Float64 internals
232+
saveat_converted = Tt.(collect(saveat))
233+
else
234+
# Already in correct form
235+
saveat_converted = saveat
236+
end
237+
else
238+
saveat_converted = nothing
239+
end
191240

241+
prob = convert(ImmutableODEProblem, prob)
192242
dt = convert(eltype(prob.tspan), dt)
193-
abstol = convert(eltype(prob.tspan), abstol)
194-
reltol = convert(eltype(prob.tspan), reltol)
195-
# if saveat is specified, we'll use a vector of timestamps.
196-
# otherwise it's a matrix that may be different for each ODE.
197-
if saveat === nothing
243+
244+
if saveat_converted === nothing
198245
if save_everystep
199-
error("Don't use adaptive version with saveat == nothing and save_everystep = true")
246+
len = ceil(Int, (prob.tspan[2] - prob.tspan[1]) / dt) + 1
200247
else
201248
len = 2
202249
end
203-
# if tstops !== nothing
204-
# len += length(tstops)
205-
# end
206250
ts = allocate(backend, typeof(dt), (len, length(probs)))
207251
fill!(ts, prob.tspan[1])
208252
us = allocate(backend, typeof(prob.u0), (len, length(probs)))
209253
else
210-
saveat = if saveat isa AbstractRange
211-
range(convert(eltype(prob.tspan), first(saveat)),
212-
convert(eltype(prob.tspan), last(saveat)),
213-
length = length(saveat))
214-
elseif saveat isa AbstractVector
215-
adapt(backend, convert.(eltype(prob.tspan), saveat))
216-
else
217-
prob.tspan[1]:convert(eltype(prob.tspan), saveat):prob.tspan[end]
218-
end
219-
ts = allocate(backend, typeof(dt), (length(saveat), length(probs)))
254+
ts = allocate(backend, typeof(dt), (length(saveat_converted), length(probs)))
220255
fill!(ts, prob.tspan[1])
221-
us = allocate(backend, typeof(prob.u0), (length(saveat), length(probs)))
256+
us = allocate(backend, typeof(prob.u0), (length(saveat_converted), length(probs)))
222257
end
223258

224259
us = adapt(backend, us)
225260
ts = adapt(backend, ts)
226261
tstops = adapt(backend, tstops)
227262

263+
if saveat_converted !== nothing
264+
saveat_converted = adapt(backend, saveat_converted)
265+
end
228266
kernel = ode_asolve_kernel(backend)
229267

230268
if backend isa CPU
231269
@warn "Running the kernel on CPU"
232270
end
233271

234272
kernel(probs, alg, us, ts, dt, callback, tstops,
235-
abstol, reltol, saveat, Val(save_everystep);
273+
abstol, reltol, saveat_converted, Val(save_everystep);
236274
ndrange = length(probs))
237275

238-
# we build the actual solution object on the CPU because the GPU would create one
239-
# containig CuDeviceArrays, which we cannot use on the host (not GC tracked,
240-
# no useful operations, etc). That's unfortunate though, since this loop is
241-
# generally slower than the entire GPU execution, and necessitates synchronization
242-
#EDIT: Done when using with DiffEqGPU
243276
ts, us
244277
end
245278

src/ensemblegpukernel/nlsolve/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
Δz = linear_solve(W_eval, f_rhs)
1818
z_i = z_i - Δz
1919

20-
if norm(dt * integrator.f(tmp + γ * z_i, p, t + c * dt) - z_i) < abstol
20+
if diffeqgpunorm(dt * integrator.f(tmp + γ * z_i, p, t + c * dt) - z_i, t) < abstol
2121
break
2222
end
2323
end
Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,25 @@
11
@kernel function em_kernel(@Const(probs), _us, _ts, dt,
22
saveat, ::Val{save_everystep}) where {save_everystep}
33
i = @index(Global, Linear)
4-
54
# get the actual problem for this thread
65
prob = @inbounds probs[i]
7-
86
Random.seed!(prob.seed)
9-
107
# get the input/output arrays for this thread
118
ts = @inbounds view(_ts, :, i)
129
us = @inbounds view(_us, :, i)
13-
1410
_saveat = get(prob.kwargs, :saveat, nothing)
15-
1611
saveat = _saveat === nothing ? saveat : _saveat
17-
1812
f = prob.f
1913
g = prob.g
2014
u0 = prob.u0
2115
tspan = prob.tspan
2216
p = prob.p
23-
17+
18+
# FIX for Issue #379: Get time type from tspan
19+
Tt = typeof(tspan[1])
20+
dt = Tt(dt)
21+
2422
is_diagonal_noise = SciMLBase.is_diagonal_noise(prob)
25-
2623
cur_t = 0
2724
if saveat !== nothing
2825
cur_t = 1
@@ -34,25 +31,26 @@
3431
@inbounds ts[1] = tspan[1]
3532
@inbounds us[1] = u0
3633
end
37-
34+
35+
# FIX: dt was converted to Tt above, so sqrt(dt) is computed in Tt as well
3836
sqdt = sqrt(dt)
3937
u = copy(u0)
4038
t = copy(tspan[1])
41-
n = length(tspan[1]:dt:tspan[2])
42-
39+
40+
# FIX: Ensure n calculation uses proper types
41+
t0, tf = tspan[1], tspan[2]
42+
n = floor(Int, abs(tf - t0) / abs(dt)) + 1
43+
4344
for j in 2:n
4445
uprev = u
45-
4646
if is_diagonal_noise
4747
u = uprev + f(uprev, p, t) * dt +
4848
sqdt * g(uprev, p, t) .* randn(typeof(u0))
4949
else
5050
u = uprev + f(uprev, p, t) * dt +
5151
sqdt * g(uprev, p, t) * randn(typeof(prob.noise_rate_prototype[1, :]))
5252
end
53-
5453
t += dt
55-
5654
if saveat === nothing && save_everystep
5755
@inbounds us[j] = u
5856
@inbounds ts[j] = t
@@ -67,9 +65,8 @@
6765
end
6866
end
6967
end
70-
7168
if saveat === nothing && !save_everystep
7269
@inbounds us[2] = u
7370
@inbounds ts[2] = t
7471
end
75-
end
72+
end

0 commit comments

Comments
 (0)