|
63 | 63 |
|
64 | 64 | # FIX for Issue #379: Get time type from tspan |
65 | 65 | Tt = typeof(tspan[1]) |
| 66 | + dt = Tt(dt) |
66 | 67 |
|
67 | 68 | is_diagonal_noise = SciMLBase.is_diagonal_noise(prob) |
68 | 69 | cur_t = 0 |
|
78 | 79 | end |
79 | 80 |
|
80 | 81 | # FIX: Use Tt for sqrt to ensure proper type |
81 | | - sqdt = sqrt(Tt(dt)) |
| 82 | + sqdt = sqrt(dt) |
82 | 83 | u = copy(u0) |
83 | 84 | t = copy(tspan[1]) |
84 | 85 |
|
85 | 86 | # FIX: Ensure n calculation uses proper types |
86 | 87 | t0, tf = tspan[1], tspan[2] |
87 | | - n = floor(Int, abs(tf - t0) / abs(Tt(dt))) + 1 |
| 88 | + n = floor(Int, abs(tf - t0) / abs(dt)) + 1 |
88 | 89 |
|
89 | 90 | cache = SIEAConstantCache(eltype(u0), Tt) |
90 | 91 | @unpack α1, α2, γ1, λ1, λ2, λ3, µ1, µ2, µ3, µ0, µbar0, λ0, λbar0, ν1, ν2, β2, β3, δ2, |
|
98 | 99 | if is_diagonal_noise |
99 | 100 | dW = sqdt * randn(typeof(u0)) |
100 | 101 | W2 = (dW) .^ 2 / sqdt |
101 | | - W3 = ν2 * (dW) .^ 3 / Tt(dt) |
102 | | - k1 = f(uprev + λ0 * k0 * Tt(dt) + ν1 * g0 .* dW + g0 .* W3, p, t + µ0 * Tt(dt)) |
103 | | - g1 = g(uprev + λbar0 * k0 * Tt(dt) + β2 * g0 * sqdt + β3 * g0 .* W2, p, |
104 | | - t + µbar0 * Tt(dt)) |
105 | | - g2 = g(uprev + λbar0 * k0 * Tt(dt) + δ2 * g0 * sqdt + δ3 * g0 .* W2, p, |
106 | | - t + µbar0 * Tt(dt)) |
107 | | - u = uprev + (α1 * k0 + α2 * k1) * Tt(dt) |
| 102 | + W3 = ν2 * (dW) .^ 3 / dt |
| 103 | + k1 = f(uprev + λ0 * k0 * dt + ν1 * g0 .* dW + g0 .* W3, p, t + µ0 * dt) |
| 104 | + g1 = g(uprev + λbar0 * k0 * dt + β2 * g0 * sqdt + β3 * g0 .* W2, p, |
| 105 | + t + µbar0 * dt) |
| 106 | + g2 = g(uprev + λbar0 * k0 * dt + δ2 * g0 * sqdt + δ3 * g0 .* W2, p, |
| 107 | + t + µbar0 * dt) |
| 108 | + u = uprev + (α1 * k0 + α2 * k1) * dt |
108 | 109 | u += γ1 * g0 .* dW + (λ1 .* dW .+ λ2 * sqdt + λ3 .* W2) .* g1 + |
109 | 110 | (µ1 .* dW .+ µ2 * sqdt + µ3 .* W2) .* g2 |
110 | 111 | end |
111 | | - t += Tt(dt) |
| 112 | + t += dt |
112 | 113 | if saveat === nothing && save_everystep |
113 | 114 | @inbounds us[j] = u |
114 | 115 | @inbounds ts[j] = t |
115 | 116 | elseif saveat !== nothing |
116 | 117 | while cur_t <= length(saveat) && saveat[cur_t] <= t |
117 | 118 | savet = saveat[cur_t] |
118 | | - Θ = (savet - (t - Tt(dt))) / Tt(dt) |
| 119 | + Θ = (savet - (t - dt)) / dt |
119 | 120 | # Linear Interpolation |
120 | 121 | @inbounds us[cur_t] = uprev + (u - uprev) * Θ |
121 | 122 | @inbounds ts[cur_t] = savet |
|
0 commit comments