Skip to content

Commit 1860d73

Browse files
committed
Convert dt once instead of multiple times per iteration
Converting dt only once to make it simpler
1 parent b57808c commit 1860d73

File tree

2 files changed

+19
-17
lines changed

2 files changed

+19
-17
lines changed

src/ensemblegpukernel/perform_step/gpu_em_perform_step.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
# FIX for Issue #379: Get time type from tspan
1919
Tt = typeof(tspan[1])
20+
dt = Tt(dt)
2021

2122
is_diagonal_noise = SciMLBase.is_diagonal_noise(prob)
2223
cur_t = 0
@@ -32,31 +33,31 @@
3233
end
3334

3435
# FIX: Use Tt for sqrt to ensure proper type
35-
sqdt = sqrt(Tt(dt))
36+
sqdt = sqrt(dt)
3637
u = copy(u0)
3738
t = copy(tspan[1])
3839

3940
# FIX: Ensure n calculation uses proper types
4041
t0, tf = tspan[1], tspan[2]
41-
n = floor(Int, abs(tf - t0) / abs(Tt(dt))) + 1
42+
n = floor(Int, abs(tf - t0) / abs(dt)) + 1
4243

4344
for j in 2:n
4445
uprev = u
4546
if is_diagonal_noise
46-
u = uprev + f(uprev, p, t) * Tt(dt) +
47+
u = uprev + f(uprev, p, t) * dt +
4748
sqdt * g(uprev, p, t) .* randn(typeof(u0))
4849
else
49-
u = uprev + f(uprev, p, t) * Tt(dt) +
50+
u = uprev + f(uprev, p, t) * dt +
5051
sqdt * g(uprev, p, t) * randn(typeof(prob.noise_rate_prototype[1, :]))
5152
end
52-
t += Tt(dt)
53+
t += dt
5354
if saveat === nothing && save_everystep
5455
@inbounds us[j] = u
5556
@inbounds ts[j] = t
5657
elseif saveat !== nothing
5758
while cur_t <= length(saveat) && saveat[cur_t] <= t
5859
savet = saveat[cur_t]
59-
Θ = (savet - (t - Tt(dt))) / Tt(dt)
60+
Θ = (savet - (t - dt)) / dt
6061
# Linear Interpolation
6162
@inbounds us[cur_t] = uprev + (u - uprev) * Θ
6263
@inbounds ts[cur_t] = savet

src/ensemblegpukernel/perform_step/gpu_siea_perform_step.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ end
6363

6464
# FIX for Issue #379: Get time type from tspan
6565
Tt = typeof(tspan[1])
66+
dt = Tt(dt)
6667

6768
is_diagonal_noise = SciMLBase.is_diagonal_noise(prob)
6869
cur_t = 0
@@ -78,13 +79,13 @@ end
7879
end
7980

8081
# FIX: Use Tt for sqrt to ensure proper type
81-
sqdt = sqrt(Tt(dt))
82+
sqdt = sqrt(dt)
8283
u = copy(u0)
8384
t = copy(tspan[1])
8485

8586
# FIX: Ensure n calculation uses proper types
8687
t0, tf = tspan[1], tspan[2]
87-
n = floor(Int, abs(tf - t0) / abs(Tt(dt))) + 1
88+
n = floor(Int, abs(tf - t0) / abs(dt)) + 1
8889

8990
cache = SIEAConstantCache(eltype(u0), Tt)
9091
@unpack α1, α2, γ1, λ1, λ2, λ3, µ1, µ2, µ3, µ0, µbar0, λ0, λbar0, ν1, ν2, β2, β3, δ2,
@@ -98,24 +99,24 @@ end
9899
if is_diagonal_noise
99100
dW = sqdt * randn(typeof(u0))
100101
W2 = (dW) .^ 2 / sqdt
101-
W3 = ν2 * (dW) .^ 3 / Tt(dt)
102-
k1 = f(uprev + λ0 * k0 * Tt(dt) + ν1 * g0 .* dW + g0 .* W3, p, t + µ0 * Tt(dt))
103-
g1 = g(uprev + λbar0 * k0 * Tt(dt) + β2 * g0 * sqdt + β3 * g0 .* W2, p,
104-
t + µbar0 * Tt(dt))
105-
g2 = g(uprev + λbar0 * k0 * Tt(dt) + δ2 * g0 * sqdt + δ3 * g0 .* W2, p,
106-
t + µbar0 * Tt(dt))
107-
u = uprev + (α1 * k0 + α2 * k1) * Tt(dt)
102+
W3 = ν2 * (dW) .^ 3 / dt
103+
k1 = f(uprev + λ0 * k0 * dt + ν1 * g0 .* dW + g0 .* W3, p, t + µ0 * dt)
104+
g1 = g(uprev + λbar0 * k0 * dt + β2 * g0 * sqdt + β3 * g0 .* W2, p,
105+
t + µbar0 * dt)
106+
g2 = g(uprev + λbar0 * k0 * dt + δ2 * g0 * sqdt + δ3 * g0 .* W2, p,
107+
t + µbar0 * dt)
108+
u = uprev + (α1 * k0 + α2 * k1) * dt
108109
u += γ1 * g0 .* dW + (λ1 .* dW .+ λ2 * sqdt + λ3 .* W2) .* g1 +
109110
(µ1 .* dW .+ µ2 * sqdt + µ3 .* W2) .* g2
110111
end
111-
t += Tt(dt)
112+
t += dt
112113
if saveat === nothing && save_everystep
113114
@inbounds us[j] = u
114115
@inbounds ts[j] = t
115116
elseif saveat !== nothing
116117
while cur_t <= length(saveat) && saveat[cur_t] <= t
117118
savet = saveat[cur_t]
118-
Θ = (savet - (t - Tt(dt))) / Tt(dt)
119+
Θ = (savet - (t - dt)) / dt
119120
# Linear Interpolation
120121
@inbounds us[cur_t] = uprev + (u - uprev) * Θ
121122
@inbounds ts[cur_t] = savet

0 commit comments

Comments
 (0)