Skip to content

Commit 19c0db0

Browse files
author
Frankie Robertson
committed
Add SafeLikelihoodEstimator
1 parent b8046ec commit 19c0db0

File tree

1 file changed

+50
-1
lines changed

1 file changed

+50
-1
lines changed

src/aggregators/ability_estimator.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ function Integrators.normdenom(rett::IntReturnType,
1111
rett(integrator(IntegralCoeffs.one, 0, est, tracked_responses))
1212
end
1313

14+
# This is not type piracy, but maybe a slightly distasteful overload
15+
# TODO: Fix this interface?
1416
function pdf(ability_est::DistributionAbilityEstimator,
1517
tracked_responses::TrackedResponses,
1618
x)
@@ -42,14 +44,61 @@ function pdf(est::PriorAbilityEstimator,
4244
AbilityLikelihood(tracked_responses))
4345
end
4446

47+
function multiple_response_types_guard(tracked_responses)
48+
if length(tracked_responses.responses.values) == 0
49+
return false
50+
end
51+
seen_value = tracked_responses.responses.values[1]
52+
for value in tracked_responses.responses.values
53+
if value !== seen_value
54+
return true
55+
end
56+
end
57+
return false
58+
end
59+
60+
struct GuardedAbilityEstimator{T <: DistributionAbilityEstimator, U <: DistributionAbilityEstimator, F} <: DistributionAbilityEstimator
61+
est::T
62+
fallback::U
63+
guard::F
64+
end
65+
66+
function pdf(est::GuardedAbilityEstimator,
67+
tracked_responses::TrackedResponses)
68+
if est.guard(tracked_responses)
69+
return pdf(est.est, tracked_responses)
70+
else
71+
return pdf(est.fallback, tracked_responses)
72+
end
73+
end
74+
75+
function SafeLikelihoodAbilityEstimator(args...; kwargs...)
76+
GuardedAbilityEstimator(
77+
LikelihoodAbilityEstimator(),
78+
PriorAbilityEstimator(args...),
79+
multiple_response_types_guard
80+
)
81+
end
82+
83+
unlog(x) = x
84+
unlog(x::Logarithmic{T}) where {T} = T(x)
85+
unlog(x::ULogarithmic{T}) where {T} = T(x)
86+
unlog(x::AbstractVector{Logarithmic{T}}) where {T} = T.(x)
87+
unlog(x::AbstractVector{ULogarithmic{T}}) where {T} = T.(x)
88+
#=unlog(x::ErrorIntegrationResult{Logarithmic{T}}) where {T} = T(x)
89+
unlog(x::ErrorIntegrationResult{ULogarithmic{T}}) where {T} = T(x)
90+
unlog(x::ErrorIntegrationResult{AbstractVector{Logarithmic{T}}}) where {T} = T.(x)
91+
unlog(x::ErrorIntegrationResult{AbstractVector{ULogarithmic{T}}}) where {T} = T.(x)
92+
=#
93+
4594
function expectation(rett::IntReturnType,
4695
f::F,
4796
ncomp,
4897
integrator::AbilityIntegrator,
4998
est::DistributionAbilityEstimator,
5099
tracked_responses::TrackedResponses,
51100
denom = normdenom(rett, integrator, est, tracked_responses)) where {F}
52-
rett(integrator(f, ncomp, est, tracked_responses)) / denom
101+
unlog(rett(integrator(f, ncomp, est, tracked_responses)) / denom)
53102
end
54103

55104
function expectation(f::F,

0 commit comments

Comments
 (0)