Skip to content

Commit 749dd47

Browse files
better retcodes
1 parent 110b804 commit 749dd47

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

src/cost_functions.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ function (f::L2Loss)(sol::DiffEqBase.DESolution)
3939
colloc_grad = f.colloc_grad
4040
dudt = f.dudt
4141

42-
if any((s.retcode != :Success for s in sol)) && any((s.retcode != :Terminated for s in sol))
43-
return Inf
42+
if sol_tmp isa DiffEqBase.AbstractEnsembleSolution
43+
failure = any((s.retcode != :Success for s in sol_tmp)) && any((s.retcode != :Terminated for s in sol_tmp))
44+
else
45+
failure = sol_tmp.retcode != :Success && sol_tmp != :Terminated
4446
end
47+
failure && return Inf
4548

4649
sumsq = 0.0
4750

@@ -124,9 +127,12 @@ LogLikeLoss(t,data_distributions,diff_distributions) = LogLikeLoss(t,matrixize(d
124127

125128
function (f::LogLikeLoss)(sol::DESolution)
126129
distributions = f.data_distributions
127-
if any((s.retcode != :Success for s in sol)) && any((s.retcode != :Terminated for s in sol))
128-
return Inf
130+
if sol_tmp isa DiffEqBase.AbstractEnsembleSolution
131+
failure = any((s.retcode != :Success for s in sol_tmp)) && any((s.retcode != :Terminated for s in sol_tmp))
132+
else
133+
failure = sol_tmp.retcode != :Success && sol_tmp != :Terminated
129134
end
135+
failure && return Inf
130136
ll = 0.0
131137

132138
if eltype(distributions) <: UnivariateDistribution
@@ -170,9 +176,12 @@ end
170176

171177
function (f::LogLikeLoss)(sol::DiffEqBase.AbstractMonteCarloSolution)
172178
distributions = f.data_distributions
173-
if any((s.retcode != :Success for s in sol)) && any((s.retcode != :Terminated for s in sol))
174-
return Inf
179+
if sol_tmp isa DiffEqBase.AbstractEnsembleSolution
180+
failure = any((s.retcode != :Success for s in sol_tmp)) && any((s.retcode != :Terminated for s in sol_tmp))
181+
else
182+
failure = sol_tmp.retcode != :Success && sol_tmp != :Terminated
175183
end
184+
failure && return Inf
176185
ll = 0.0
177186
if eltype(distributions) <: UnivariateDistribution
178187
for j in 1:length(f.t), i in 1:length(sol[1][1])

0 commit comments

Comments
 (0)