@@ -36,6 +36,7 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
3636
3737 prob = convert (ImmutableODEProblem, prob)
3838 dt = convert (eltype (prob. tspan), dt)
39+ saveat_converted = nothing
3940
4041 if saveat === nothing
4142 if save_everystep
@@ -51,34 +52,30 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
5152 fill! (ts, prob. tspan[1 ])
5253 us = allocate (backend, typeof (prob. u0), (len, length (probs)))
5354 else
54- saveat = if saveat isa AbstractRange
55- _saveat = range (convert (eltype (prob. tspan), first (saveat)),
56- convert (eltype (prob. tspan), last (saveat)),
57- length = length (saveat))
58- convert (
59- StepRangeLen{
60- eltype (_saveat),
61- eltype (_saveat),
62- eltype (_saveat),
63- eltype (_saveat) === Float32 ? Int32 : Int64
64- },
65- _saveat)
55+ # Get the time type from the problem
56+ Tt = eltype (prob. tspan)
57+
58+ # FIX for Issue #379: Convert saveat to proper type
59+ saveat_converted = if saveat isa AbstractRange
60+ Tt .(collect (range (Tt (first (saveat)), Tt (last (saveat)), length = length (saveat))))
6661 elseif saveat isa AbstractVector
67- adapt (backend, convert .( eltype (prob . tspan), saveat))
62+ Tt .( collect ( saveat))
6863 else
69- _saveat = prob. tspan[1 ]: convert (eltype (prob. tspan), saveat): prob. tspan[end ]
70- convert (
71- StepRangeLen{
72- eltype (_saveat),
73- eltype (_saveat),
74- eltype (_saveat),
75- eltype (_saveat) === Float32 ? Int32 : Int64
76- },
77- _saveat)
64+ # saveat is a Number (step size)
65+ t0, tf = Tt .(prob. tspan)
66+ if Tt (saveat) == Tt (0.0 )
67+ Tt .([t0, tf])
68+ else
69+ num_points = Int (ceil (abs (tf - t0) / abs (Tt (saveat)))) + 1
70+ Tt .(collect (range (t0, tf, length = num_points)))
71+ end
7872 end
79- ts = allocate (backend, typeof (dt), (length (saveat), length (probs)))
73+
74+ saveat_converted = adapt (backend, saveat_converted)
75+
76+ ts = allocate (backend, typeof (dt), (length (saveat_converted), length (probs)))
8077 fill! (ts, prob. tspan[1 ])
81- us = allocate (backend, typeof (prob. u0), (length (saveat ), length (probs)))
78+ us = allocate (backend, typeof (prob. u0), (length (saveat_converted ), length (probs)))
8279 end
8380
8481 tstops = adapt (backend, tstops)
@@ -89,15 +86,15 @@ function vectorized_solve(probs, prob::ODEProblem, alg;
8986 @warn " Running the kernel on CPU"
9087 end
9188
92- kernel (probs, alg, us, ts, dt, callback, tstops, nsteps, saveat ,
89+ kernel (probs, alg, us, ts, dt, callback, tstops, nsteps, saveat_converted ,
9390 Val (save_everystep);
9491 ndrange = length (probs))
9592
9693 # we build the actual solution object on the CPU because the GPU would create one
9794 # containing CuDeviceArrays, which we cannot use on the host (not GC tracked,
9895 # no useful operations, etc). That's unfortunate though, since this loop is
9996 # generally slower than the entire GPU execution, and necessitates synchronization
100- # EDIT: Done when using with DiffEqGPU
97+ # EDIT: Done when using with DiffEqGPU
10198 ts, us
10299end
103100
@@ -111,7 +108,7 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
111108 backend = maybe_prefer_blocks (backend)
112109
113110 dt = convert (eltype (prob. tspan), dt)
114-
111+ saveat_converted = nothing
115112 if saveat === nothing
116113 if save_everystep
117114 len = length (prob. tspan[1 ]: dt: prob. tspan[2 ])
@@ -122,20 +119,32 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
122119 fill! (ts, prob. tspan[1 ])
123120 us = allocate (backend, typeof (prob. u0), (len, length (probs)))
124121 else
125- saveat = if saveat isa AbstractRange
126- range (convert (eltype (prob. tspan), first (saveat)),
127- convert (eltype (prob. tspan), last (saveat)),
128- length = length (saveat))
122+ # Get the time type from the problem
123+ Tt = eltype (prob. tspan)
124+
125+ # FIX for Issue #379: Convert saveat to proper type
126+ saveat_converted = if saveat isa AbstractRange
127+ Tt .(collect (range (Tt (first (saveat)), Tt (last (saveat)), length = length (saveat))))
129128 elseif saveat isa AbstractVector
130- convert .( eltype (prob . tspan), adapt (backend, saveat))
129+ Tt .( collect ( saveat))
131130 else
132- prob. tspan[1 ]: convert (eltype (prob. tspan), saveat): prob. tspan[end ]
131+ # saveat is a Number (step size)
132+ t0, tf = Tt .(prob. tspan)
133+ if Tt (saveat) == Tt (0.0 )
134+ Tt .([t0, tf])
135+ else
136+ num_points = Int (ceil (abs (tf - t0) / abs (Tt (saveat)))) + 1
137+ Tt .(collect (range (t0, tf, length = num_points)))
138+ end
133139 end
134- ts = allocate (backend, typeof (dt), (length (saveat), length (probs)))
140+
141+ ts = allocate (backend, typeof (dt), (length (saveat_converted), length (probs)))
135142 fill! (ts, prob. tspan[1 ])
136- us = allocate (backend, typeof (prob. u0), (length (saveat), length (probs)))
143+ us = allocate (backend, typeof (prob. u0), (length (saveat_converted), length (probs)))
144+ end
145+ if saveat_converted != = nothing
146+ saveat_converted = adapt (backend, saveat_converted)
137147 end
138-
139148 if alg isa GPUEM
140149 kernel = em_kernel (backend)
141150 elseif alg isa Union{GPUSIEA}
@@ -148,7 +157,7 @@ function vectorized_solve(probs, prob::SDEProblem, alg;
148157 @warn " Running the kernel on CPU"
149158 end
150159
151- kernel (probs, us, ts, dt, saveat , Val (save_everystep);
160+ kernel (probs, us, ts, dt, saveat_converted , Val (save_everystep);
152161 ndrange = length (probs))
153162 ts, us
154163end
@@ -184,62 +193,86 @@ function vectorized_asolve(probs, prob::ODEProblem, alg;
184193 abstol = 1.0f-6 , reltol = 1.0f-3 ,
185194 debug = false , callback = CallbackSet (nothing ), tstops = nothing ,
186195 kwargs... )
196+
187197 backend = get_backend (probs)
188198 backend = maybe_prefer_blocks (backend)
189199
190- prob = convert (ImmutableODEProblem, prob)
200+ # Get the time type from the problem
201+ Tt = eltype (prob. tspan)
202+
203+ # FIX for Issue #379: Convert saveat to eliminate
204+ # StepRangeLen's internal Float64 fields which crash Metal
205+
206+ if saveat != = nothing
207+ if saveat isa Number
208+ # Handle edge case: saveat = 0.0 means only save endpoints
209+ if Tt (saveat) == Tt (0.0 )
210+ saveat_converted = Tt .([prob. tspan[1 ], prob. tspan[2 ]])
211+ else
212+ # Create proper range with correct type
213+ t0, tf = Tt .(prob. tspan)
214+
215+ # Handle both forward and reverse time integration
216+ num_points = Int (ceil (abs (tf - t0) / abs (Tt (saveat)))) + 1
217+
218+ # Safety check: prevent massive arrays
219+ max_saveat_length = 100_000
220+ if num_points > max_saveat_length
221+ error (" saveat would create too many save points ($num_points ). " *
222+ " Consider using a larger saveat value." )
223+ end
224+
225+ # Create range and convert to pure Vector{Tt}
226+ saveat_range = range (t0, tf, length = num_points)
227+ saveat_converted = Tt .(collect (saveat_range))
228+ end
229+ elseif saveat isa AbstractRange || saveat isa AbstractArray
230+ # Range or array - convert all elements to Tt
231+ # This eliminates StepRangeLen's Float64 internals
232+ saveat_converted = Tt .(collect (saveat))
233+ else
234+ # Already in correct form
235+ saveat_converted = saveat
236+ end
237+ else
238+ saveat_converted = nothing
239+ end
191240
241+ prob = convert (ImmutableODEProblem, prob)
192242 dt = convert (eltype (prob. tspan), dt)
193- abstol = convert (eltype (prob. tspan), abstol)
194- reltol = convert (eltype (prob. tspan), reltol)
195- # if saveat is specified, we'll use a vector of timestamps.
196- # otherwise it's a matrix that may be different for each ODE.
197- if saveat === nothing
243+
244+ if saveat_converted === nothing
198245 if save_everystep
199- error ( " Don't use adaptive version with saveat == nothing and save_everystep = true " )
246+ len = ceil (Int, (prob . tspan[ 2 ] - prob . tspan[ 1 ]) / dt) + 1
200247 else
201248 len = 2
202249 end
203- # if tstops !== nothing
204- # len += length(tstops)
205- # end
206250 ts = allocate (backend, typeof (dt), (len, length (probs)))
207251 fill! (ts, prob. tspan[1 ])
208252 us = allocate (backend, typeof (prob. u0), (len, length (probs)))
209253 else
210- saveat = if saveat isa AbstractRange
211- range (convert (eltype (prob. tspan), first (saveat)),
212- convert (eltype (prob. tspan), last (saveat)),
213- length = length (saveat))
214- elseif saveat isa AbstractVector
215- adapt (backend, convert .(eltype (prob. tspan), saveat))
216- else
217- prob. tspan[1 ]: convert (eltype (prob. tspan), saveat): prob. tspan[end ]
218- end
219- ts = allocate (backend, typeof (dt), (length (saveat), length (probs)))
254+ ts = allocate (backend, typeof (dt), (length (saveat_converted), length (probs)))
220255 fill! (ts, prob. tspan[1 ])
221- us = allocate (backend, typeof (prob. u0), (length (saveat ), length (probs)))
256+ us = allocate (backend, typeof (prob. u0), (length (saveat_converted ), length (probs)))
222257 end
223258
224259 us = adapt (backend, us)
225260 ts = adapt (backend, ts)
226261 tstops = adapt (backend, tstops)
227262
263+ if saveat_converted != = nothing
264+ saveat_converted = adapt (backend, saveat_converted)
265+ end
228266 kernel = ode_asolve_kernel (backend)
229267
230268 if backend isa CPU
231269 @warn " Running the kernel on CPU"
232270 end
233271
234272 kernel (probs, alg, us, ts, dt, callback, tstops,
235- abstol, reltol, saveat , Val (save_everystep);
273+ abstol, reltol, saveat_converted , Val (save_everystep);
236274 ndrange = length (probs))
237275
238- # we build the actual solution object on the CPU because the GPU would create one
239- # containig CuDeviceArrays, which we cannot use on the host (not GC tracked,
240- # no useful operations, etc). That's unfortunate though, since this loop is
241- # generally slower than the entire GPU execution, and necessitates synchronization
242- # EDIT: Done when using with DiffEqGPU
243276 ts, us
244277end
245278
0 commit comments