From eb97a2b048c45d6ef88dfae966ec8d6635a49e0d Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 28 Nov 2024 16:06:08 +0200 Subject: [PATCH 1/4] More next item rules refactoring * Criteria => MultiCriterion * Get rid of most functors and convert into named methods * Introduce abstract types for pointwise item criteria --- src/Sim.jl | 4 +- src/decision_tree/DecisionTree.jl | 2 +- src/next_item_rules/NextItemRules.jl | 7 +-- .../combinators/expectation.jl | 8 +-- .../combinators/scalarizers.jl | 45 +++++++++-------- .../criteria/item/information.jl | 9 ++-- src/next_item_rules/criteria/item/urry.jl | 3 +- .../criteria/state/ability_variance.jl | 28 +++++++---- src/next_item_rules/prelude/abstract.jl | 25 ++++++++-- src/next_item_rules/prelude/criteria.jl | 50 ++++++++++++++++--- src/next_item_rules/prelude/next_item_rule.jl | 44 +++++++++++++++- src/next_item_rules/prelude/strategy.jl | 42 ---------------- src/next_item_rules/strategies/exhaustive.jl | 9 ++-- src/next_item_rules/strategies/random.jl | 2 +- test/ability_estimator_1d.jl | 24 ++++----- test/ability_estimator_2d.jl | 18 ++++--- 16 files changed, 196 insertions(+), 124 deletions(-) delete mode 100644 src/next_item_rules/prelude/strategy.jl diff --git a/src/Sim.jl b/src/Sim.jl index 8bc2f1f..f29fdd2 100644 --- a/src/Sim.jl +++ b/src/Sim.jl @@ -5,7 +5,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType using ..Responses using ..CatConfig: CatLoopConfig, CatRules using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators -using ..NextItemRules: compute_criteria +using ..NextItemRules: compute_criteria, best_item export run_cat, prompt_response, auto_responder @@ -56,7 +56,7 @@ function run_cat(cat_config::CatLoopConfig{RulesT}, "Best items" end criteria try - next_index = next_item(responses, item_bank) + next_index = best_item(next_item, responses, item_bank) catch exc if isa(exc, NextItemError) @warn "Terminating early due to error getting next item" err=sprint( diff --git a/src/decision_tree/DecisionTree.jl b/src/decision_tree/DecisionTree.jl index 29ab7ea..a15e2b1 100644 --- a/src/decision_tree/DecisionTree.jl +++ b/src/decision_tree/DecisionTree.jl @@ -128,7 +128,7 @@ function generate_dt_cat(config::DecisionTreeGenerationConfig, item_bank) while true track!(responses, config.ability_tracker) ability = config.ability_estimator(responses) - next_item = config.next_item(responses, item_bank) + next_item = best_item(config.next_item, responses, item_bank) insert!(decision_tree_result, responses.responses, ability, next_item) if state_tree.cur_depth == state_tree.max_depth diff --git a/src/next_item_rules/NextItemRules.jl b/src/next_item_rules/NextItemRules.jl index e028c4e..acf69f2 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/next_item_rules/NextItemRules.jl @@ -39,17 +39,18 @@ export RandomNextItemRule export ExhaustiveSearch export catr_next_item_aliases export preallocate -export compute_criteria +export compute_criteria, compute_criterion, compute_multi_criterion, + compute_pointwise_criterion +export best_item export PointResponseExpectation, DistributionResponseExpectation export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer -export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria +export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion export InformationMatrixCriteria export ScalarizedStateCriteron, ScalarizedItemCriteron # Prelude include("./prelude/abstract.jl") include("./prelude/next_item_rule.jl") -include("./prelude/strategy.jl") include("./prelude/criteria.jl") include("./prelude/preallocate.jl") diff --git a/src/next_item_rules/combinators/expectation.jl b/src/next_item_rules/combinators/expectation.jl index 3e58153..61ac76f 100644 --- a/src/next_item_rules/combinators/expectation.jl +++ b/src/next_item_rules/combinators/expectation.jl @@ -96,14 +96,16 @@ function init_thread(::ExpectationBasedItemCriterion, responses::TrackedResponse end function _generic_criterion(criterion::StateCriterion, tracked_responses, item_idx) - criterion(tracked_responses) + compute_criterion(criterion, tracked_responses) end # TODO: Support init_thread for wrapped ItemCriterion function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx) - criterion(tracked_responses, item_idx) + compute_criterion(criterion, tracked_responses, item_idx) end -function (item_criterion::ExpectationBasedItemCriterion)(speculator::Speculator, +function compute_criterion( + item_criterion::ExpectationBasedItemCriterion, + speculator::Speculator, tracked_responses::TrackedResponses, item_idx) exp_resp = Aggregators.response_expectation(item_criterion.response_expectation, diff --git a/src/next_item_rules/combinators/scalarizers.jl b/src/next_item_rules/combinators/scalarizers.jl index 7bdf723..9326a6d 100644 --- a/src/next_item_rules/combinators/scalarizers.jl +++ b/src/next_item_rules/combinators/scalarizers.jl @@ -1,57 +1,58 @@ struct DeterminantScalarizer <: MatrixScalarizer end -(::DeterminantScalarizer)(mat) = det(mat) +scalarize(::DeterminantScalarizer, mat) = det(mat) struct TraceScalarizer <: MatrixScalarizer end -(::TraceScalarizer)(mat) = tr(mat) +scalarize(::TraceScalarizer, mat) = tr(mat) struct ScalarizedItemCriteron{ - ItemCriteriaT <: ItemCriteria, + ItemMultiCriterionT <: ItemMultiCriterion, MatrixScalarizerT <: MatrixScalarizer } <: ItemCriterion - criteria::ItemCriteriaT + criteria::ItemMultiCriterionT scalarizer::MatrixScalarizerT end -function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx) - res = ssc.criteria( - init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |> - ssc.scalarizer - if !should_minimize(ssc.criteria) - res = -res - end - res -end - struct ScalarizedStateCriteron{ - StateCriteriaT <: StateCriteria, + StateMultiCriterionT <: StateMultiCriterion, MatrixScalarizerT <: MatrixScalarizer } <: StateCriterion - criteria::StateCriteriaT + criteria::StateMultiCriterionT scalarizer::MatrixScalarizerT end -function (ssc::ScalarizedStateCriteron)(tracked_responses) - res = ssc.criteria(tracked_responses) |> ssc.scalarizer +function compute_criterion(ssc::Union{ScalarizedItemCriteron, ScalarizedStateCriteron}, + tracked_responses::TrackedResponses, item_idx...) + res = scalarize( + ssc.scalarizer, + compute_multi_criterion( + ssc.criteria, + init_thread(ssc.criteria, tracked_responses), + tracked_responses, + item_idx... + ) + ) if !should_minimize(ssc.criteria) res = -res end res end -struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria +struct WeightedStateMultiCriterion{InnerT <: StateMultiCriterion} <: StateMultiCriterion weights::Vector{Float64} criteria::InnerT end -function (wsc::WeightedStateCriteria)(tracked_responses, item_idx) +function compute_multi_criterion( + wsc::WeightedStateMultiCriterion, tracked_responses::TrackedResponses, item_idx) wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights end -struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria +struct WeightedItemMultiCriterion{InnerT <: ItemMultiCriterion} <: ItemMultiCriterion weights::Vector{Float64} criteria::InnerT end -function (wsc::WeightedItemCriteria)(tracked_responses, item_idx) +function compute_multi_criterion( + wsc::WeightedItemMultiCriterion, tracked_responses::TrackedResponses, item_idx) wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights end diff --git a/src/next_item_rules/criteria/item/information.jl b/src/next_item_rules/criteria/item/information.jl index b20484c..04987e4 100644 --- a/src/next_item_rules/criteria/item/information.jl +++ b/src/next_item_rules/criteria/item/information.jl @@ -9,7 +9,8 @@ function InformationItemCriterion(ability_estimator) InformationItemCriterion(ability_estimator, expected_item_information) end -function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedResponses, +function compute_criterion( + item_criterion::InformationItemCriterion, tracked_responses::TrackedResponses, item_idx) ability = maybe_tracked_ability_estimate(tracked_responses, item_criterion.ability_estimator) @@ -17,7 +18,8 @@ function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedRe return -item_criterion.expected_item_information(ir, ability) end -struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria +struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: + ItemMultiCriterion ability_estimator::AbilityEstimatorT expected_item_information::F end @@ -35,7 +37,8 @@ function init_thread(item_criterion::InformationMatrixCriteria, responses_information(responses.item_bank, responses.responses, ability) end -function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64}, +function compute_multi_criterion( + item_criterion::InformationMatrixCriteria, acc_info::Matrix{Float64}, tracked_responses::TrackedResponses, item_idx) # TODO: Add in information from the prior diff --git a/src/next_item_rules/criteria/item/urry.jl b/src/next_item_rules/criteria/item/urry.jl index dbf2de3..e71a82b 100644 --- a/src/next_item_rules/criteria/item/urry.jl +++ b/src/next_item_rules/criteria/item/urry.jl @@ -14,7 +14,8 @@ function raw_difficulty(item_bank, item_idx) item_params(item_bank, item_idx).difficulty end -function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses, item_idx) +function compute_criterion( + item_criterion::UrryItemCriterion, tracked_responses::TrackedResponses, item_idx) ability = maybe_tracked_ability_estimate(tracked_responses, item_criterion.ability_estimator) diff = raw_difficulty(tracked_responses.item_bank, item_idx) diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/next_item_rules/criteria/state/ability_variance.jl index b685e1b..5e13a7a 100644 --- a/src/next_item_rules/criteria/state/ability_variance.jl +++ b/src/next_item_rules/criteria/state/ability_variance.jl @@ -36,7 +36,8 @@ function AbilityVarianceStateCriterion(bits...) return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero) end -function (criterion::AbilityVarianceStateCriterion)(tracked_responses::TrackedResponses)::Float64 +function compute_criterion(criterion::AbilityVarianceStateCriterion, + tracked_responses::TrackedResponses)::Float64 # XXX: Not sure if the estimator should come from somewhere else here denom = normdenom(criterion.integrator, criterion.dist_est, @@ -44,10 +45,11 @@ function (criterion::AbilityVarianceStateCriterion)(tracked_responses::TrackedRe if denom == 0.0 && criterion.skip_zero return Inf end - criterion(DomainType(tracked_responses.item_bank), tracked_responses, denom) + compute_criterion( + criterion, DomainType(tracked_responses.item_bank), tracked_responses, denom) end -function (criterion::AbilityVarianceStateCriterion)( +function compute_criterion(criterion::AbilityVarianceStateCriterion, ::Union{OneDimContinuousDomain, DiscreteDomain}, tracked_responses::TrackedResponses, denom)::Float64 @@ -59,9 +61,12 @@ function (criterion::AbilityVarianceStateCriterion)( ) end -function (criterion::AbilityVarianceStateCriterion)(::Vector, +function compute_criterion( + criterion::AbilityVarianceStateCriterion, + ::Vector, tracked_responses::TrackedResponses, - denom)::Float64 + denom +)::Float64 # XXX: Not quite sure about this --- is it useful, the MIRT rules cover this case mean = expectation(IntegralCoeffs.id, ndims(tracked_responses.item_bank), @@ -77,25 +82,26 @@ function (criterion::AbilityVarianceStateCriterion)(::Vector, denom) end -struct AbilityCovarianceStateCriteria{ +struct AbilityCovarianceStateMultiCriterion{ DistEstT <: DistributionAbilityEstimator, IntegratorT <: AbilityIntegrator -} <: StateCriteria +} <: StateMultiCriterion dist_est::DistEstT integrator::IntegratorT skip_zero::Bool end -function AbilityCovarianceStateCriteria(bits...) +function AbilityCovarianceStateMultiCriterion(bits...) skip_zero = false @requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...) - return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero) + return AbilityCovarianceStateMultiCriterion(dist_est, integrator, skip_zero) end # XXX: Should be at type level -should_minimize(::AbilityCovarianceStateCriteria) = true +should_minimize(::AbilityCovarianceStateMultiCriterion) = true -function (criteria::AbilityCovarianceStateCriteria)( +function compute_multi_criterion( + criteria::AbilityCovarianceStateMultiCriterion, tracked_responses::TrackedResponses, denom = normdenom(criteria.integrator, criteria.dist_est, diff --git a/src/next_item_rules/prelude/abstract.jl b/src/next_item_rules/prelude/abstract.jl index e00dba4..fdc68c9 100644 --- a/src/next_item_rules/prelude/abstract.jl +++ b/src/next_item_rules/prelude/abstract.jl @@ -4,22 +4,32 @@ $(TYPEDEF) Abstract base type for all item selection rules. All descendants of this type are expected to implement the interface -`(rule::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int` +`(::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`. + +In practice, all adaptive rules in this package use `ItemStrategyNextItemRule`. $(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true) Implicit constructor for $(FUNCTIONNAME). Uses any given `NextItemRule` or -delegates to `ItemStrategyNextItemRule`. +delegates to `ItemStrategyNextItemRule` the default instance. """ abstract type NextItemRule <: CatConfigBase end """ $(TYPEDEF) + +Abstract type for next item strategies, tightly coupled with `ItemStrategyNextItemRule`. +All descendants of this type are expected to implement the interface +`(rule::ItemStrategyNextItemRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses, + items) where {ItemCriterionT <: } +`(strategy::NextItemStrategy)(; parallel=true)::NextItemStrategy` """ abstract type NextItemStrategy <: CatConfigBase end """ $(TYPEDEF) + +Abstract type for next item criteria """ abstract type ItemCriterion <: CatConfigBase end @@ -28,6 +38,13 @@ $(TYPEDEF) """ abstract type StateCriterion <: CatConfigBase end +""" +$(TYPEDEF) +""" +abstract type PointwiseItemCriterion <: CatConfigBase end + +abstract type PurePointwiseItemCriterion <: PointwiseItemCriterion end + abstract type MatrixScalarizer end -abstract type StateCriteria end -abstract type ItemCriteria end +abstract type StateMultiCriterion end +abstract type ItemMultiCriterion end diff --git a/src/next_item_rules/prelude/criteria.jl b/src/next_item_rules/prelude/criteria.jl index 4c9986c..118a76e 100644 --- a/src/next_item_rules/prelude/criteria.jl +++ b/src/next_item_rules/prelude/criteria.jl @@ -10,25 +10,34 @@ function ItemCriterion(bits...; ability_estimator = nothing, ability_tracker = n ability_tracker = ability_tracker) end +function StateCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing) + @returnsome find1_instance(StateCriterion, bits) + @returnsome find1_type(StateCriterion, bits) typ->typ() +end + function init_thread(::ItemCriterion, ::TrackedResponses) nothing end -function StateCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing) - @returnsome find1_instance(StateCriterion, bits) - @returnsome find1_type(StateCriterion, bits) typ->typ() +function init_thread(::StateCriterion, ::TrackedResponses) + nothing end -function (item_criterion::ItemCriterion)(::Nothing, tracked_responses, item_idx) - item_criterion(tracked_responses, item_idx) +function compute_criterion( + item_criterion::ItemCriterion, ::Nothing, tracked_responses, item_idx) + compute_criterion(item_criterion, tracked_responses, item_idx) end -function (item_criterion::ItemCriterion)(tracked_responses, item_idx) +function compute_criterion(item_criterion::ItemCriterion, tracked_responses, item_idx) criterion_state = init_thread(item_criterion, tracked_responses) if criterion_state === nothing error("Tried to run an state-requiring item criterion $(typeof(item_criterion)), but init_thread(...) returned nothing") end - item_criterion(criterion_state, tracked_responses, item_idx) + compute_criterion(item_criterion, criterion_state, tracked_responses, item_idx) +end + +function compute_criterion(state_criterion::StateCriterion, ::Nothing, tracked_responses) + compute_criterion(state_criterion, tracked_responses) end function compute_criteria( @@ -48,3 +57,30 @@ function compute_criteria( ) where {StrategyT, ItemCriterionT <: ItemCriterion} compute_criteria(rule.criterion, responses, items) end + +function compute_pointwise_criterion( + ppic::PurePointwiseItemCriterion, tracked_responses, item_idx) + compute_pointwise_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx)) +end + +struct PurePointwiseItemCriterionFunction{PointwiseItemCriterionT <: PointwiseItemCriterion} + item_response::ItemResponse +end + +function init_thread(::ItemMultiCriterion, ::TrackedResponses) + nothing +end + +function init_thread(::StateMultiCriterion, ::TrackedResponses) + nothing +end + +function compute_multi_criterion( + item_criterion::ItemMultiCriterion, ::Nothing, tracked_responses, item_idx) + compute_multi_criterion(item_criterion, tracked_responses, item_idx) +end + +function compute_multi_criterion( + state_criterion::StateMultiCriterion, ::Nothing, tracked_responses) + compute_multi_criterion(state_criterion, tracked_responses) +end diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/next_item_rules/prelude/next_item_rule.jl index 244efdd..cfd8263 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/next_item_rules/prelude/next_item_rule.jl @@ -1,4 +1,3 @@ - function NextItemRule(bits...; ability_estimator = nothing, ability_tracker = nothing, @@ -9,3 +8,46 @@ function NextItemRule(bits...; ability_tracker = ability_tracker, parallel = parallel) end + +function NextItemStrategy(; parallel = true) + ExhaustiveSearch(parallel) +end + +function NextItemStrategy(bits...; parallel = true) + @returnsome find1_instance(NextItemStrategy, bits) + @returnsome find1_type(NextItemStrategy, bits) typ->typ(; parallel = parallel) + @returnsome NextItemStrategy(; parallel = parallel) +end + +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +`ItemStrategyNextItemRule` which together with a `NextItemStrategy` acts as an +adapter by which an `ItemCriterion` can serve as a `NextItemRule`. + + $(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true) + +Implicit constructor for $(FUNCTIONNAME). Will default to +`ExhaustiveSearch` when no `NextItemStrategy` is given. +""" +struct ItemStrategyNextItemRule{ + NextItemStrategyT <: NextItemStrategy, + ItemCriterionT <: ItemCriterion +} <: NextItemRule + strategy::NextItemStrategyT + criterion::ItemCriterionT +end + +function ItemStrategyNextItemRule(bits...; + parallel = true, + ability_estimator = nothing, + ability_tracker = nothing) + strategy = NextItemStrategy(bits...; parallel = parallel) + criterion = ItemCriterion(bits...; + ability_estimator = ability_estimator, + ability_tracker = ability_tracker) + if strategy !== nothing && criterion !== nothing + return ItemStrategyNextItemRule(strategy, criterion) + end +end diff --git a/src/next_item_rules/prelude/strategy.jl b/src/next_item_rules/prelude/strategy.jl deleted file mode 100644 index 995831f..0000000 --- a/src/next_item_rules/prelude/strategy.jl +++ /dev/null @@ -1,42 +0,0 @@ -function NextItemStrategy(; parallel = true) - ExhaustiveSearch(parallel) -end - -function NextItemStrategy(bits...; parallel = true) - @returnsome find1_instance(NextItemStrategy, bits) - @returnsome find1_type(NextItemStrategy, bits) typ->typ(; parallel = parallel) - @returnsome NextItemStrategy(; parallel = parallel) -end - -""" -$(TYPEDEF) -$(TYPEDFIELDS) - -`ItemStrategyNextItemRule` which together with a `NextItemStrategy` acts as an -adapter by which an `ItemCriterion` can serve as a `NextItemRule`. - - $(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true) - -Implicit constructor for $(FUNCTIONNAME). Will default to -`ExhaustiveSearch` when no `NextItemStrategy` is given. -""" -struct ItemStrategyNextItemRule{ - NextItemStrategyT <: NextItemStrategy, - ItemCriterionT <: ItemCriterion -} <: NextItemRule - strategy::NextItemStrategyT - criterion::ItemCriterionT -end - -function ItemStrategyNextItemRule(bits...; - parallel = true, - ability_estimator = nothing, - ability_tracker = nothing) - strategy = NextItemStrategy(bits...; parallel = parallel) - criterion = ItemCriterion(bits...; - ability_estimator = ability_estimator, - ability_tracker = ability_tracker) - if strategy !== nothing && criterion !== nothing - return ItemStrategyNextItemRule(strategy, criterion) - end -end diff --git a/src/next_item_rules/strategies/exhaustive.jl b/src/next_item_rules/strategies/exhaustive.jl index 7f20895..8337ffb 100644 --- a/src/next_item_rules/strategies/exhaustive.jl +++ b/src/next_item_rules/strategies/exhaustive.jl @@ -15,7 +15,7 @@ function exhaustive_search(objective::ItemCriterionT, continue end - obj_val = objective(objective_state, responses, item_idx) + obj_val = compute_criterion(objective, objective_state, responses, item_idx) if obj_val <= min_obj_val min_obj_val = obj_val @@ -34,8 +34,11 @@ $(TYPEDFIELDS) parallel::Bool = false end -function (rule::ItemStrategyNextItemRule{ExhaustiveSearch, ItemCriterionT})(responses, - items) where {ItemCriterionT <: ItemCriterion} +function best_item( + rule::ItemStrategyNextItemRule{ExhaustiveSearch, ItemCriterionT}, + responses::TrackedResponses, + items +) where {ItemCriterionT <: ItemCriterion} #, rule.strategy.parallel exhaustive_search(rule.criterion, responses, items)[1] end diff --git a/src/next_item_rules/strategies/random.jl b/src/next_item_rules/strategies/random.jl index 9da6f8c..4f5965d 100644 --- a/src/next_item_rules/strategies/random.jl +++ b/src/next_item_rules/strategies/random.jl @@ -22,7 +22,7 @@ function RandomNextItemRule(bits...) end =# -function (rule::RandomNextItemRule)(responses::TrackedResponses, items) +function best_item(rule::RandomNextItemRule, responses::TrackedResponses, items) # TODO: This is not efficient item_idxes = Set(1:length(items)) available = setdiff(item_idxes, Set(responses.responses.indices)) diff --git a/test/ability_estimator_1d.jl b/test/ability_estimator_1d.jl index 3e9f40c..2ae2dde 100644 --- a/test/ability_estimator_1d.jl +++ b/test/ability_estimator_1d.jl @@ -58,16 +58,16 @@ information_item_criterion = InformationItemCriterion(mle_mean_1d) @testcase "1 dim neg information smaller closer to current estimate" begin @test ( - information_item_criterion(tracked_responses_1d, 5) < - information_item_criterion(tracked_responses_1d, 6) + compute_criterion(information_item_criterion, tracked_responses_1d, 5) < + compute_criterion(information_item_criterion, tracked_responses_1d, 6) ) end @testcase "1 dim neg information smaller with igher discrimination" begin @test ( - information_item_criterion(tracked_responses_1d, 7) < - information_item_criterion(tracked_responses_1d, 5) < - information_item_criterion(tracked_responses_1d, 8) + compute_criterion(information_item_criterion, tracked_responses_1d, 7) < + compute_criterion(information_item_criterion, tracked_responses_1d, 5) < + compute_criterion(information_item_criterion, tracked_responses_1d, 8) ) end @@ -79,23 +79,23 @@ ability_variance_item_criterion = ExpectationBasedItemCriterion( @testcase "postposterior 1 dim variance smaller closer to current estimate" begin @test ( - ability_variance_item_criterion(tracked_responses_1d, 5) < - ability_variance_item_criterion(tracked_responses_1d, 6) + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 5) < + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 6) ) end @testcase "postposterior 1 dim variance smaller with higher discrimination" begin @test ( - ability_variance_item_criterion(tracked_responses_1d, 7) < - ability_variance_item_criterion(tracked_responses_1d, 5) < - ability_variance_item_criterion(tracked_responses_1d, 8) + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 7) < + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 5) < + compute_criterion(ability_variance_item_criterion, tracked_responses_1d, 8) ) end @testcase "1 dim variance decreases with new responses" begin - orig_var = ability_variance_state_criterion(tracked_responses_1d) + orig_var = compute_criterion(ability_variance_state_criterion, tracked_responses_1d) next_responses = deepcopy(tracked_responses_1d) add_response!(next_responses, Response(ResponseType(item_bank_1d), 5, 0)) - new_var = ability_variance_state_criterion(next_responses) + new_var = compute_criterion(ability_variance_state_criterion, next_responses) @test new_var < orig_var end diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index ca1abcb..ab2f465 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -68,14 +68,15 @@ end # Item further from the current estimate far_item = 6 - close_info = information_criterion(tracked_responses_2d, close_item) - far_info = information_criterion(tracked_responses_2d, far_item) + close_info = compute_criterion(information_criterion, tracked_responses_2d, close_item) + far_info = compute_criterion(information_criterion, tracked_responses_2d, far_item) @test close_info > far_info end @testcase "2 dim variance smaller closer to current estimate" begin - covariance_state_criterion = AbilityCovarianceStateCriteria(lh_est_2d, integrator_2d) + covariance_state_criterion = AbilityCovarianceStateMultiCriterion( + lh_est_2d, integrator_2d) variance_criterion = ScalarizedStateCriteron( covariance_state_criterion, DeterminantScalarizer()) variance_item_criterion = ExpectationBasedItemCriterion(mle_mean_2d, variance_criterion) @@ -85,14 +86,15 @@ end # Item further from the current estimate far_item = 6 - close_var = variance_item_criterion(tracked_responses_2d, close_item) - far_var = variance_item_criterion(tracked_responses_2d, far_item) + close_var = compute_criterion(variance_item_criterion, tracked_responses_2d, close_item) + far_var = compute_criterion(variance_item_criterion, tracked_responses_2d, far_item) @test close_var < far_var end @testcase "2 dim variance is whack with trace scalarizer" begin - covariance_state_criterion = AbilityCovarianceStateCriteria(lh_est_2d, integrator_2d) + covariance_state_criterion = AbilityCovarianceStateMultiCriterion( + lh_est_2d, integrator_2d) variance_criterion = ScalarizedStateCriteron( covariance_state_criterion, TraceScalarizer()) variance_item_criterion = ExpectationBasedItemCriterion(mle_mean_2d, variance_criterion) @@ -102,8 +104,8 @@ end # Item further from the current estimate far_item = 6 - close_var = variance_item_criterion(tracked_responses_2d, close_item) - far_var = variance_item_criterion(tracked_responses_2d, far_item) + close_var = compute_criterion(variance_item_criterion, tracked_responses_2d, close_item) + far_var = compute_criterion(variance_item_criterion, tracked_responses_2d, far_item) @test far_var < close_var end From afbbfa36896e1d081f3891458ab3dd10ad1cf70c Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 28 Nov 2024 23:35:49 +0200 Subject: [PATCH 2/4] Add tests for Stateful --- test/stateful.jl | 58 +++++++++++++++++++++++++++++++++++++++++++++++ test/tests_top.jl | 5 ++++ 2 files changed, 63 insertions(+) create mode 100644 test/stateful.jl diff --git a/test/stateful.jl b/test/stateful.jl new file mode 100644 index 0000000..4e90f1d --- /dev/null +++ b/test/stateful.jl @@ -0,0 +1,58 @@ +@testcase "Stateful" begin + rng = Random.default_rng(42) + + # Create test data + (item_bank, abilities, true_responses) = dummy_full( + rng, + SimpleItemBankSpec(StdModel3PL(), OneDimContinuousDomain(), BooleanResponse()); + num_questions = 4, + num_testees = 2 + ) + + @testset "StatefulCatConfig basic usage" begin + rules = CatRules( + FixedItemsTerminationCondition(2), + Dummy.DummyAbilityEstimator(0), + RandomNextItemRule() + ) + + # Initialize config + cat_config = Stateful.StatefulCatConfig(rules, item_bank) + + # Test initialization state + @test isempty(Stateful.get_responses(cat_config)) + + # Add responses and check state + Stateful.add_response!(cat_config, 1, true) + Stateful.add_response!(cat_config, 2, false) + + @test length(Stateful.get_responses(cat_config).indices) == 2 + + # Test ability estimation + ability, _ = Stateful.get_ability(cat_config) + @test ability isa Real + + # Test reset + Stateful.reset!(cat_config) + @test isempty(Stateful.get_responses(cat_config)) + end + + @testset "Stateful next item selection" begin + rules = CatRules( + FixedItemsTerminationCondition(2), + Dummy.DummyAbilityEstimator(0), + RandomNextItemRule() + ) + cat_config = Stateful.StatefulCatConfig(rules, item_bank) + + # Test first item selection + first_item = Stateful.next_item(cat_config) + @test 1 <= first_item <= 4 + + # Add response and test next item + Stateful.add_response!(cat_config, first_item, true) + second_item = Stateful.next_item(cat_config) + @test 1 <= second_item <= 4 + @test second_item != first_item # Should select different item + end +end diff --git a/test/tests_top.jl b/test/tests_top.jl index 64e887c..81e2511 100644 --- a/test/tests_top.jl +++ b/test/tests_top.jl @@ -12,6 +12,7 @@ using ComputerAdaptiveTesting.Sim using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.Optimizers using ComputerAdaptiveTesting.DecisionTree +using ComputerAdaptiveTesting: Stateful using Distributions using Distributions: ZeroMeanIsoNormal, Zeros, ScalMat using Optim @@ -47,6 +48,10 @@ end include("./dt.jl") end +@testset "stateful" begin + include("./stateful.jl") +end + @testset "format" begin include("./format.jl") end From c247598f35f60326f5546f0b667a57a3058aa090 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Thu, 28 Nov 2024 23:35:56 +0200 Subject: [PATCH 3/4] Use best_item in StatefulCatConfig --- src/Stateful.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Stateful.jl b/src/Stateful.jl index 4092eef..1a7563f 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -4,7 +4,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType using ..Aggregators: TrackedResponses, Aggregators using ..CatConfig: CatLoopConfig, CatRules using ..Responses: BareResponses, Response -using ..NextItemRules: compute_criteria +using ..NextItemRules: compute_criteria, best_item ## StatefulCat interface abstract type StatefulCat end @@ -73,7 +73,7 @@ function StatefulCatConfig(rules, item_bank) end function next_item(config::StatefulCatConfig) - return config.rules.next_item(config.tracked_responses, config.item_bank) + return best_item(config.rules.next_item, config.tracked_responses, config.item_bank) end function ranked_items(config::StatefulCatConfig) From 5da1f4aa87ffd835287465af40ffc5c2b26e8bd9 Mon Sep 17 00:00:00 2001 From: Frankie Robertson Date: Fri, 29 Nov 2024 00:28:54 +0200 Subject: [PATCH 4/4] Fix up benchmark --- benchmark/benchmarks.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index a1c8597..ff1c61e 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -48,8 +48,11 @@ function prepare_4pls(group) tracked_responses = TrackedResponses(BareResponses(ResponseType(item_bank)), item_bank, NullAbilityTracker()) - group["$(est_nick)_point_mepv_bare"] = @benchmarkable ($next_item_rule)( - $tracked_responses, $item_bank) + group["$(est_nick)_point_mepv_bare"] = @benchmarkable best_item( + $next_item_rule, + $tracked_responses, + $item_bank + ) bare_responses = BareResponses( ResponseType(item_bank), response_idxs, @@ -60,8 +63,11 @@ function prepare_4pls(group) bare_responses, item_bank, NullAbilityTracker()) - group["$(est_nick)_point_mepv_10"] = @benchmarkable ($next_item_rule)( - $tracked_responses, $item_bank) + group["$(est_nick)_point_mepv_10"] = @benchmarkable best_item( + $next_item_rule, + $tracked_responses, + $item_bank + ) end return group end