From 355209a9d31c36217912341e627d39b265b55fee Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Mon, 7 Oct 2024 10:46:25 +0300 Subject: [PATCH 1/6] Add postposterior covariance matrix based mirt rules --- src/aggregators/Aggregators.jl | 2 +- src/aggregators/ability_estimator.jl | 53 +++++++++++++++++++- src/next_item_rules/NextItemRules.jl | 5 ++ src/next_item_rules/mirt.jl | 61 +++++++++++++++++++++++ src/next_item_rules/objective_function.jl | 31 +++++------- test/ability_estimator_2d.jl | 34 +++++++++++-- 6 files changed, 162 insertions(+), 24 deletions(-) create mode 100644 src/next_item_rules/mirt.jl diff --git a/src/aggregators/Aggregators.jl b/src/aggregators/Aggregators.jl index 9c2b487..6b0a30d 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/aggregators/Aggregators.jl @@ -20,7 +20,7 @@ using PsychometricsBazaarBase.ConfigTools using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase: Integrators using PsychometricsBazaarBase.Optimizers -using PsychometricsBazaarBase.ConstDistributions: std_normal +using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal import FittedItemBanks import PsychometricsBazaarBase.IntegralCoeffs diff --git a/src/aggregators/ability_estimator.jl b/src/aggregators/ability_estimator.jl index e26c947..070839a 100644 --- a/src/aggregators/ability_estimator.jl +++ b/src/aggregators/ability_estimator.jl @@ -28,7 +28,13 @@ struct PriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstim prior::PriorT end -PriorAbilityEstimator() = PriorAbilityEstimator(std_normal) +function PriorAbilityEstimator(; ncomp=0) + if ncomp == 0 + return PriorAbilityEstimator(std_normal) + else + return PriorAbilityEstimator(std_mv_normal(ncomp)) + end +end function pdf(est::PriorAbilityEstimator, tracked_responses::TrackedResponses) @@ -73,6 +79,21 @@ function mean_1d(integrator::AbilityIntegrator, denom) end +function mean( + integrator::AbilityIntegrator, + est::DistributionAbilityEstimator, + tracked_responses::TrackedResponses, + denom = normdenom(integrator, est, tracked_responses) +) + n = domdims(tracked_responses.item_bank) + expectation(IntegralCoeffs.id, + n, + integrator, + est, + tracked_responses, + denom) +end + function variance_given_mean(integrator::AbilityIntegrator, est::DistributionAbilityEstimator, tracked_responses::TrackedResponses, @@ -97,6 +118,36 @@ function variance(integrator::AbilityIntegrator, denom) end +function covariance_matrix_given_mean( + integrator::AbilityIntegrator, + est::DistributionAbilityEstimator, + tracked_responses::TrackedResponses, + mean, + denom = normdenom(integrator, est, tracked_responses) +) + n = domdims(tracked_responses.item_bank) + expectation(IntegralCoeffs.OuterProdDev(mean), + n, + integrator, + est, + tracked_responses, + denom) +end + +function covariance_matrix( + integrator::AbilityIntegrator, + est::DistributionAbilityEstimator, + tracked_responses::TrackedResponses, + denom = normdenom(integrator, est, tracked_responses)) + covariance_matrix_given_mean( + integrator, + est, + tracked_responses, + mean(integrator, est, tracked_responses, denom), + denom + ) +end + struct ModeAbilityEstimator{ DistEst <: DistributionAbilityEstimator, OptimizerT <: AbilityOptimizer diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index a212316..cc174d5 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -26,6 +26,7 @@ import PsychometricsBazaarBase.IntegralCoeffs using FittedItemBanks using FittedItemBanks: item_params using ..Aggregators +using ..Aggregators: covariance_matrix using QuadGK, Distributions, Optim, Base.Threads, Base.Order, StaticArrays using ConstructionBase: constructorof @@ -40,6 +41,9 @@ export catr_next_item_aliases export preallocate export compute_criteria export PointResponseExpectation, DistributionResponseExpectation +export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer +export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria +export ScalarizedStateCriteron """ $(TYPEDEF) @@ -197,6 +201,7 @@ function compute_criteria( compute_criteria(rule.criterion, responses, items) end +include("./mirt.jl") include("./aliases.jl") include("./preallocate.jl") diff --git a/src/next_item_rules/mirt.jl b/src/next_item_rules/mirt.jl new file mode 100644 index 0000000..c65c410 --- /dev/null +++ b/src/next_item_rules/mirt.jl @@ -0,0 +1,61 @@ +abstract type MatrixScalarizer end + +struct DeterminantScalarizer <: MatrixScalarizer end +(::DeterminantScalarizer)(mat) = det(mat) + +struct TraceScalarizer <: MatrixScalarizer end +(::TraceScalarizer)(mat) = tr(mat) + +abstract type StateCriteria end +abstract type ItemCriteria end + +struct AbilityCovarianceStateCriteria{ + DistEstT <: DistributionAbilityEstimator, + IntegratorT <: AbilityIntegrator +} <: StateCriteria + dist_est::DistEstT + integrator::IntegratorT + skip_zero::Bool +end + +function AbilityCovarianceStateCriteria(bits...) + skip_zero = false + @requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...) + return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero) +end + +# XXX: Should be at type level +should_minimize(::AbilityCovarianceStateCriteria) = true + +function (criteria::AbilityCovarianceStateCriteria)( + tracked_responses::TrackedResponses, + denom = normdenom(criteria.integrator, + criteria.dist_est, + tracked_responses) +) + if denom == 0.0 && criteria.skip_zero + return Inf + end + covariance_matrix( + criteria.integrator, + criteria.dist_est, + tracked_responses, + denom + ) +end + +struct ScalarizedStateCriteron{ + StateCriteriaT <: StateCriteria, + MatrixScalarizerT <: MatrixScalarizer +} <: StateCriterion + criteria::StateCriteriaT + scalarizer::MatrixScalarizerT +end + +function (ssc::ScalarizedStateCriteron)(tracked_responses) + res = ssc.criteria(tracked_responses) |> ssc.scalarizer + if !should_minimize(ssc.criteria) + res = -res + end + res +end diff --git a/src/next_item_rules/objective_function.jl b/src/next_item_rules/objective_function.jl index dfc182f..eccb292 100644 --- a/src/next_item_rules/objective_function.jl +++ b/src/next_item_rules/objective_function.jl @@ -39,8 +39,7 @@ struct AbilityVarianceStateCriterion{ skip_zero::Bool end -function AbilityVarianceStateCriterion(bits...) - skip_zero = false +function _get_dist_est_and_integrator(bits...) # XXX: Weakness in this initialisation system is showing now # This needs ot be explicitly passed dist_est and integrator, but this may # be burried within a MeanAbilityEstimator @@ -48,16 +47,18 @@ function AbilityVarianceStateCriterion(bits...) dist_est = DistributionAbilityEstimator(bits...) integrator = AbilityIntegrator(bits...) if dist_est !== nothing && integrator !== nothing - return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero) + return (dist_est, integrator) end # So let's just handle this case individually for now # (Is this going to cause a problem with this being picked over something more appropriate?) @requiresome mean_ability_est = MeanAbilityEstimator(bits...) - return AbilityVarianceStateCriterion( - mean_ability_est.dist_est, - mean_ability_est.integrator, - skip_zero - ) + return (dist_est, integrator) +end + +function AbilityVarianceStateCriterion(bits...) + skip_zero = false + @requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...) + return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero) end function (criterion::AbilityVarianceStateCriterion)(tracked_responses::TrackedResponses)::Float64 @@ -75,20 +76,12 @@ function (criterion::AbilityVarianceStateCriterion)( ::Union{OneDimContinuousDomain, DiscreteDomain}, tracked_responses::TrackedResponses, denom)::Float64 - mean = expectation(IntegralCoeffs.id, - 0, + return variance( criterion.integrator, criterion.dist_est, tracked_responses, - denom) - # XXX: This is not type stable and seems to possibly allocate. We need to - # show that mean is the same as our tracked responses. - return expectation(IntegralCoeffs.SqDev(mean), - 0, - criterion.integrator, - criterion.dist_est, - tracked_responses, - denom) + denom + ) end function (criterion::AbilityVarianceStateCriterion)(::Vector, diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index 8ca658a..33a6aa6 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -58,11 +58,39 @@ end @test ans[1] + ans[2]≈2.0 atol=0.001 end -# TODO -#= @testcase "2 dim information higher closer to current estimate" begin end @testcase "2 dim variance smaller closer to current estimate" begin + covariance_state_criterion = AbilityCovarianceStateCriteria(lh_est_2d, integrator_2d) + variance_criterion = ScalarizedStateCriteron( + covariance_state_criterion, DeterminantScalarizer()) + variance_item_criterion = ExpectationBasedItemCriterion(mle_mean_2d, variance_criterion) + + # Item closer to the current estimate (1, 1) + close_item = 5 + # Item further from the current estimate + far_item = 6 + + close_var = variance_item_criterion(tracked_responses_2d, close_item) + far_var = variance_item_criterion(tracked_responses_2d, far_item) + + @test close_var < far_var +end + +@testcase "2 dim variance is whack with trace scalarizer" begin + covariance_state_criterion = AbilityCovarianceStateCriteria(lh_est_2d, integrator_2d) + variance_criterion = ScalarizedStateCriteron( + covariance_state_criterion, TraceScalarizer()) + variance_item_criterion = ExpectationBasedItemCriterion(mle_mean_2d, variance_criterion) + + # Item closer to the current estimate (1, 1) + close_item = 5 + # Item further from the current estimate + far_item = 6 + + close_var = variance_item_criterion(tracked_responses_2d, close_item) + far_var = variance_item_criterion(tracked_responses_2d, far_item) + + @test far_var < close_var end -=# From 98c00b9364bc6953fcc6fb1654d1aa85561f3ad0 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Mon, 7 Oct 2024 15:45:55 +0300 Subject: [PATCH 2/6] Qualify usage of even_grid in dt test --- test/dt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dt.jl b/test/dt.jl index fb52c1a..0372378 100644 --- a/test/dt.jl +++ b/test/dt.jl @@ -4,7 +4,7 @@ num_questions = 20, num_testees = 1 ) -integrator = FunctionIntegrator(even_grid(-6, 6, 61)) +integrator = FunctionIntegrator(Integrators.even_grid(-6, 6, 61)) ability_estimator = MeanAbilityEstimator(LikelihoodAbilityEstimator(), integrator) get_response = auto_responder(@view true_responses[:, 1]) From bab98c2988d296356132c07e9fb37d95d983fbf6 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Mon, 7 Oct 2024 15:46:26 +0300 Subject: [PATCH 3/6] Formatting of ability_estimator --- src/aggregators/ability_estimator.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/aggregators/ability_estimator.jl b/src/aggregators/ability_estimator.jl index 070839a..87a372e 100644 --- a/src/aggregators/ability_estimator.jl +++ b/src/aggregators/ability_estimator.jl @@ -28,7 +28,7 @@ struct PriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstim prior::PriorT end -function PriorAbilityEstimator(; ncomp=0) +function PriorAbilityEstimator(; ncomp = 0) if ncomp == 0 return PriorAbilityEstimator(std_normal) else @@ -80,10 +80,10 @@ function mean_1d(integrator::AbilityIntegrator, end function mean( - integrator::AbilityIntegrator, - est::DistributionAbilityEstimator, - tracked_responses::TrackedResponses, - denom = normdenom(integrator, est, tracked_responses) + integrator::AbilityIntegrator, + est::DistributionAbilityEstimator, + tracked_responses::TrackedResponses, + denom = normdenom(integrator, est, tracked_responses) ) n = domdims(tracked_responses.item_bank) expectation(IntegralCoeffs.id, @@ -119,11 +119,11 @@ function variance(integrator::AbilityIntegrator, end function covariance_matrix_given_mean( - integrator::AbilityIntegrator, - est::DistributionAbilityEstimator, - tracked_responses::TrackedResponses, - mean, - denom = normdenom(integrator, est, tracked_responses) + integrator::AbilityIntegrator, + est::DistributionAbilityEstimator, + tracked_responses::TrackedResponses, + mean, + denom = normdenom(integrator, est, tracked_responses) ) n = domdims(tracked_responses.item_bank) expectation(IntegralCoeffs.OuterProdDev(mean), From 7bf4e17ab8361cd1a65497d1b942ebab6f8fac3d Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Tue, 29 Oct 2024 12:28:47 +0200 Subject: [PATCH 4/6] Add todo note to comparison.jl --- src/Comparison.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Comparison.jl b/src/Comparison.jl index df92482..d8716cf 100644 --- a/src/Comparison.jl +++ b/src/Comparison.jl @@ -263,6 +263,7 @@ function run_comparison(comparison::CatComparisonConfig{ReplayResponsesExecution items_answered = items_answered ) if :after_item_criteria in comparison.phases + # TOOD: Combine with next_item if possible and requested? timed_item_criteria = @timed Stateful.item_criteria(cat) measure_all( comparison, From 038ce18300e9d307d3b1c78758ebabb96a436caf Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Tue, 29 Oct 2024 13:05:52 +0200 Subject: [PATCH 5/6] Refactor dispersion around ScalarizedStateCriteron --- src/next_item_rules/NextItemRules.jl | 6 +- src/next_item_rules/aliases.jl | 8 ++- src/next_item_rules/information.jl | 37 ++++++----- src/next_item_rules/information_special.jl | 65 ++++++++++++++++++ src/next_item_rules/mirt.jl | 67 +++++++++++++++++++ src/next_item_rules/objective_function.jl | 76 +++------------------- test/ability_estimator_2d.jl | 13 ++++ 7 files changed, 186 insertions(+), 86 deletions(-) create mode 100644 src/next_item_rules/information_special.jl diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index cc174d5..418fc0f 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -34,7 +34,7 @@ import ForwardDiff export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread export NextItemRule, ItemStrategyNextItemRule -export UrryItemCriterion, InformationItemCriterion, DRuleItemCriterion, TRuleItemCriterion +export UrryItemCriterion, InformationItemCriterion export RandomNextItemRule export ExhaustiveSearch1Ply export catr_next_item_aliases @@ -43,7 +43,8 @@ export compute_criteria export PointResponseExpectation, DistributionResponseExpectation export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria -export ScalarizedStateCriteron +export InformationMatrixCriteria +export ScalarizedStateCriteron, ScalarizedItemCriteron """ $(TYPEDEF) @@ -72,6 +73,7 @@ end include("./random.jl") include("./information.jl") +include("./information_special.jl") include("./objective_function.jl") include("./expectation.jl") diff --git a/src/next_item_rules/aliases.jl b/src/next_item_rules/aliases.jl index 6990bac..fdf80a2 100644 --- a/src/next_item_rules/aliases.jl +++ b/src/next_item_rules/aliases.jl @@ -40,7 +40,13 @@ const mirtcat_next_item_aliases = Dict( # 'MEPV' for minimum expected posterior variance "MEPV" => _mirtcat_helper((bits, ability_estimator) -> ExpectationBasedItemCriterion( ability_estimator, - AbilityVarianceStateCriterion(bits...))) + AbilityVarianceStateCriterion(bits...))), + "Drule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron( + InformationMatrixCriteria(ability_estimator), + DeterminantScalarizer())), + "Trule" => _mirtcat_helper((bits, ability_estimator) -> ScalarizedItemCriteron( + InformationMatrixCriteria(ability_estimator), + TraceScalarizer())) ) # 'MLWI' for maximum likelihood weighted information diff --git a/src/next_item_rules/information.jl b/src/next_item_rules/information.jl index 4138492..c63e6af 100644 --- a/src/next_item_rules/information.jl +++ b/src/next_item_rules/information.jl @@ -3,6 +3,18 @@ using FittedItemBanks: CdfMirtItemBank, using FittedItemBanks: inner_item_response, norm_abil, y_offset, irf_size using StatsFuns: logaddexp +function log_resp_vec(ir::ItemResponse{<:TransferItemBank}, θ) + nθ = norm_abil(ir, θ) + return SVector( + logccdf(ir.item_bank.distribution, nθ), + logcdf(ir.item_bank.distribution, nθ) + ) +end + +function log_resp(ir::ItemResponse{<:TransferItemBank}, resp, θ) + logcdf(ir.item_bank.distribution, norm_abil(ir, θ)) +end + function log_resp_vec(ir::ItemResponse{<:CdfMirtItemBank}, θ) nθ = norm_abil(ir, θ) SVector(logccdf(ir.item_bank.distribution, nθ), @@ -52,26 +64,21 @@ function log_resp(ir::ItemResponse{<:AnySlipOrGuessItemBank}, val, θ) log_transform_irf_y(ir, val, log_resp(inner_item_response(ir), val, θ)) end -# How does this compare with expected_item_information. Speeds/accuracies? -# TODO: Which response models is this valid for? -# TODO: Citation/source for this equation -# TODO: Do it in log space? -function item_information(ir::ItemResponse, θ) - # irθ_prime = ForwardDiff.derivative(ir, θ) - irθ_prime = ForwardDiff.derivative(x -> resp(ir, x), θ) - irθ = resp(ir, θ) - if irθ_prime == 0.0 - return 0.0 - else - return (irθ_prime * irθ_prime) / (irθ * (1 - irθ)) - end -end - function vector_hessian(f, x, n) out = ForwardDiff.jacobian(x -> ForwardDiff.jacobian(f, x), x) return reshape(out, n, n, n) end +function double_derivative(f, x) + ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), x) +end + +function expected_item_information(ir::ItemResponse, θ::Float64) + exp_resp = resp_vec(ir, θ) + d² = double_derivative((θ -> log_resp_vec(ir, θ)), θ) + -sum(exp_resp .* d²) +end + # TODO: Unclear whether this should be implemented with ExpectationBasedItemCriterion # TODO: This is not implementing DRule but postposterior DRule function expected_item_information(ir::ItemResponse, θ::Vector{Float64}) diff --git a/src/next_item_rules/information_special.jl b/src/next_item_rules/information_special.jl new file mode 100644 index 0000000..2e75649 --- /dev/null +++ b/src/next_item_rules/information_special.jl @@ -0,0 +1,65 @@ +#= +This file contains some specialised ways to calculate information. +For some models analytical solutions are possible for information. +Most are simple applications of the chain rule +However, I haven't taken a systematic approach yet yet. +So these are just from equations in the literature. + +There aren't really any type guards on these so its up to the caller to make sure they are using the right ones. +=# + +function alt_expected_1d_item_information(ir::ItemResponse, θ) + """ + This is a special case of the expected_item_information function for + * 1-dimensional ability + * Dichotomous items + * It should be valid for at least up to the 3PL model, probably others too + + TODO: citation + """ + # irθ_prime = ForwardDiff.derivative(ir, θ) + irθ_prime = ForwardDiff.derivative(x -> resp(ir, x), θ) + irθ = resp(ir, θ) + if irθ_prime == 0.0 + return 0.0 + else + return (irθ_prime * irθ_prime) / (irθ * (1 - irθ)) + end +end + +function alt_expected_mirt_item_information(ir::ItemResponse, θ) + """ + This is a special case of the expected_item_information function for + * Multidimensional + * Dichotomous items + * It should be valid for at least up to the 3PL model, probably others too + + TODO: citation + """ + irθ_prime = ForwardDiff.gradient(x -> resp(ir, x), θ) + pθ = resp(ir, θ) + qθ = 1 - pθ + (irθ_prime * irθ_prime') / (pθ * qθ) +end + +function alt_expected_mirt_3pl_item_information(ir::ItemResponse, θ) + """ + This is a special case of the expected_item_information function for + * Multidimensional + * Dichotomous items + * 3PL model only + + Mulder J, van der Linden WJ. + Multidimensional Adaptive Testing with Optimal Design Criteria for Item Selection. + Psychometrika. 2009 Jun;74(2):273-296. doi: 10.1007/s11336-008-9097-5. + Equation 4 + """ + # XXX: Should avoid using item_params + params = item_params(ir.item_bank.discriminations, ir.index) + pθ = resp(ir, θ) + qθ = 1 - pθ + a = params.discrimination + c = params.guess + common_factor = (qθ * (pθ - c)^2) / (pθ * (1 - c)^2) + common_factor * (a * a') +end diff --git a/src/next_item_rules/mirt.jl b/src/next_item_rules/mirt.jl index c65c410..8fdc805 100644 --- a/src/next_item_rules/mirt.jl +++ b/src/next_item_rules/mirt.jl @@ -59,3 +59,70 @@ function (ssc::ScalarizedStateCriteron)(tracked_responses) end res end + +struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria + ability_estimator::AbilityEstimatorT + expected_item_information::F +end + +function InformationMatrixCriteria(ability_estimator) + InformationMatrixCriteria(ability_estimator, expected_item_information) +end + +function init_thread(item_criterion::InformationMatrixCriteria, + responses::TrackedResponses) + # TODO: No need to do this one per thread. It just need to be done once per + # θ update. + # TODO: Update this to use track!(...) mechanism + ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator) + responses_information(responses.item_bank, responses.responses, ability) +end + +function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64}, + tracked_responses::TrackedResponses, + item_idx) + # TODO: Add in information from the prior + ability = maybe_tracked_ability_estimate( + tracked_responses, item_criterion.ability_estimator) + return acc_info .+ + item_criterion.expected_item_information( + ItemResponse(tracked_responses.item_bank, item_idx), ability) +end + +should_minimize(::InformationMatrixCriteria) = false + +struct ScalarizedItemCriteron{ + ItemCriteriaT <: ItemCriteria, + MatrixScalarizerT <: MatrixScalarizer +} <: ItemCriterion + criteria::ItemCriteriaT + scalarizer::MatrixScalarizerT +end + +function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx) + res = ssc.criteria( + init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |> + ssc.scalarizer + if !should_minimize(ssc.criteria) + res = -res + end + res +end + +struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria + weights::Vector{Float64} + criteria::InnerT +end + +function (wsc::WeightedStateCriteria)(tracked_responses, item_idx) + wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights +end + +struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria + weights::Vector{Float64} + criteria::InnerT +end + +function (wsc::WeightedItemCriteria)(tracked_responses, item_idx) + wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights +end diff --git a/src/next_item_rules/objective_function.jl b/src/next_item_rules/objective_function.jl index eccb292..acc6940 100644 --- a/src/next_item_rules/objective_function.jl +++ b/src/next_item_rules/objective_function.jl @@ -126,8 +126,13 @@ function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses end # TODO: Should have Variants for point ability versus distribution ability -struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: ItemCriterion +struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: ItemCriterion ability_estimator::AbilityEstimatorT + expected_item_information::F +end + +function InformationItemCriterion(ability_estimator) + InformationItemCriterion(ability_estimator, expected_item_information) end function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedResponses, @@ -135,70 +140,5 @@ function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedRe ability = maybe_tracked_ability_estimate(tracked_responses, item_criterion.ability_estimator) ir = ItemResponse(tracked_responses.item_bank, item_idx) - return -item_information(ir, ability) -end - -abstract type InformationMatrixCriterion <: ItemCriterion end - -function init_thread(item_criterion::InformationMatrixCriterion, - responses::TrackedResponses) - # TODO: No need to do this one per thread. It just need to be done once per - # θ update. - # TODO: Update this to use track!(...) mechanism - ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator) - responses_information(responses.item_bank, responses.responses, ability) -end - -function information_matrix(ability_estimator, - acc_info, - tracked_responses::TrackedResponses, - item_idx) - # TODO: Add in information from the prior - ability = maybe_tracked_ability_estimate(tracked_responses, ability_estimator) - acc_info .+ - expected_item_information(ItemResponse(tracked_responses.item_bank, item_idx), ability) -end - -struct DRuleItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: - InformationMatrixCriterion - ability_estimator::AbilityEstimatorT -end - -function (item_criterion::DRuleItemCriterion)(acc_info::Matrix{Float64}, - tracked_responses::TrackedResponses, - item_idx) - -det(information_matrix(item_criterion.ability_estimator, - acc_info, - tracked_responses, - item_idx)) -end - -# TODO: Weighted version -struct TRuleItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: - InformationMatrixCriterion - ability_estimator::AbilityEstimatorT -end - -function (item_criterion::TRuleItemCriterion)(acc_info::Matrix{Float64}, - tracked_responses, - item_idx) - # XXX: Should not strictly need to calculate whole information matrix to get this. - # Should just be able to calculate Laplacians as we go, but ForwardDiff doesn't support this (yet?). - -tr(information_matrix(item_criterion.ability_estimator, - acc_info, - tracked_responses, - item_idx)) -end - -struct ARuleItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: ItemCriterion - ability_estimator::AbilityEstimatorT -end - -function (item_criterion::ARuleItemCriterion)(acc_info::Nothing, - tracked_responses, - item_idx) - # TODO - # Step 1. Get covariance of ability estimate - # Basically the same idea as AbilityVarianceStateCriterion - # Step 2. Get the (weighted) trace -end + return -item_criterion.expected_item_information(ir, ability) +end \ No newline at end of file diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index 33a6aa6..ca1abcb 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -59,6 +59,19 @@ end end @testcase "2 dim information higher closer to current estimate" begin + information_matrix_criteria = InformationMatrixCriteria(mle_mean_2d) + information_criterion = ScalarizedItemCriteron( + information_matrix_criteria, DeterminantScalarizer()) + + # Item closer to the current estimate (1, 1) + close_item = 5 + # Item further from the current estimate + far_item = 6 + + close_info = information_criterion(tracked_responses_2d, close_item) + far_info = information_criterion(tracked_responses_2d, far_item) + + @test close_info > far_info end @testcase "2 dim variance smaller closer to current estimate" begin From beccee522587e04a4d288be280240beb2c08ff48 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Wed, 30 Oct 2024 15:55:12 +0200 Subject: [PATCH 6/6] Apply formatting --- src/next_item_rules/objective_function.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/next_item_rules/objective_function.jl b/src/next_item_rules/objective_function.jl index acc6940..0c901ba 100644 --- a/src/next_item_rules/objective_function.jl +++ b/src/next_item_rules/objective_function.jl @@ -126,7 +126,8 @@ function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses end # TODO: Should have Variants for point ability versus distribution ability -struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: ItemCriterion +struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: + ItemCriterion ability_estimator::AbilityEstimatorT expected_item_information::F end @@ -141,4 +142,4 @@ function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedRe item_criterion.ability_estimator) ir = ItemResponse(tracked_responses.item_bank, item_idx) return -item_criterion.expected_item_information(ir, ability) -end \ No newline at end of file +end