@@ -11,6 +11,8 @@ function Integrators.normdenom(rett::IntReturnType,
1111 rett (integrator (IntegralCoeffs. one, 0 , est, tracked_responses))
1212end
1313
14+ # This is not type piracy, but maybe a slightly distasteful overload
15+ # TODO : Fix this interface?
1416function pdf (ability_est:: DistributionAbilityEstimator ,
1517 tracked_responses:: TrackedResponses ,
1618 x)
@@ -42,14 +44,61 @@ function pdf(est::PriorAbilityEstimator,
4244 AbilityLikelihood (tracked_responses))
4345end
4446
47+ function multiple_response_types_guard (tracked_responses)
48+ if length (tracked_responses. responses. values) == 0
49+ return false
50+ end
51+ seen_value = tracked_responses. responses. values[1 ]
52+ for value in tracked_responses. responses. values
53+ if value != = seen_value
54+ return true
55+ end
56+ end
57+ return false
58+ end
59+
60+ struct GuardedAbilityEstimator{T <: DistributionAbilityEstimator , U <: DistributionAbilityEstimator , F} <: DistributionAbilityEstimator
61+ est:: T
62+ fallback:: U
63+ guard:: F
64+ end
65+
66+ function pdf (est:: GuardedAbilityEstimator ,
67+ tracked_responses:: TrackedResponses )
68+ if est. guard (tracked_responses)
69+ return pdf (est. est, tracked_responses)
70+ else
71+ return pdf (est. fallback, tracked_responses)
72+ end
73+ end
74+
75+ function SafeLikelihoodAbilityEstimator (args... ; kwargs... )
76+ GuardedAbilityEstimator (
77+ LikelihoodAbilityEstimator (),
78+ PriorAbilityEstimator (args... ),
79+ multiple_response_types_guard
80+ )
81+ end
82+
83+ unlog (x) = x
84+ unlog (x:: Logarithmic{T} ) where {T} = T (x)
85+ unlog (x:: ULogarithmic{T} ) where {T} = T (x)
86+ unlog (x:: AbstractVector{Logarithmic{T}} ) where {T} = T .(x)
87+ unlog (x:: AbstractVector{ULogarithmic{T}} ) where {T} = T .(x)
88+ #= unlog(x::ErrorIntegrationResult{Logarithmic{T}}) where {T} = T(x)
89+ unlog(x::ErrorIntegrationResult{ULogarithmic{T}}) where {T} = T(x)
90+ unlog(x::ErrorIntegrationResult{AbstractVector{Logarithmic{T}}}) where {T} = T.(x)
91+ unlog(x::ErrorIntegrationResult{AbstractVector{ULogarithmic{T}}}) where {T} = T.(x)
92+ =#
93+
4594function expectation (rett:: IntReturnType ,
4695 f:: F ,
4796 ncomp,
4897 integrator:: AbilityIntegrator ,
4998 est:: DistributionAbilityEstimator ,
5099 tracked_responses:: TrackedResponses ,
51100 denom = normdenom (rett, integrator, est, tracked_responses)) where {F}
52- rett (integrator (f, ncomp, est, tracked_responses)) / denom
101+ unlog ( rett (integrator (f, ncomp, est, tracked_responses)) / denom)
53102end
54103
55104function expectation (f:: F ,
0 commit comments