Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ function prepare_4pls(group)
tracked_responses = TrackedResponses(BareResponses(ResponseType(item_bank)),
item_bank,
NullAbilityTracker())
group["$(est_nick)_point_mepv_bare"] = @benchmarkable ($next_item_rule)(
$tracked_responses, $item_bank)
group["$(est_nick)_point_mepv_bare"] = @benchmarkable best_item(
$next_item_rule,
$tracked_responses,
$item_bank
)
bare_responses = BareResponses(
ResponseType(item_bank),
response_idxs,
Expand All @@ -60,8 +63,11 @@ function prepare_4pls(group)
bare_responses,
item_bank,
NullAbilityTracker())
group["$(est_nick)_point_mepv_10"] = @benchmarkable ($next_item_rule)(
$tracked_responses, $item_bank)
group["$(est_nick)_point_mepv_10"] = @benchmarkable best_item(
$next_item_rule,
$tracked_responses,
$item_bank
)
end
return group
end
Expand Down
4 changes: 2 additions & 2 deletions src/Sim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType
using ..Responses
using ..CatConfig: CatLoopConfig, CatRules
using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators
using ..NextItemRules: compute_criteria
using ..NextItemRules: compute_criteria, best_item

export run_cat, prompt_response, auto_responder

Expand Down Expand Up @@ -56,7 +56,7 @@ function run_cat(cat_config::CatLoopConfig{RulesT},
"Best items"
end criteria
try
next_index = next_item(responses, item_bank)
next_index = best_item(next_item, responses, item_bank)
catch exc
if isa(exc, NextItemError)
@warn "Terminating early due to error getting next item" err=sprint(
Expand Down
4 changes: 2 additions & 2 deletions src/Stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType
using ..Aggregators: TrackedResponses, Aggregators
using ..CatConfig: CatLoopConfig, CatRules
using ..Responses: BareResponses, Response
using ..NextItemRules: compute_criteria
using ..NextItemRules: compute_criteria, best_item

## StatefulCat interface
abstract type StatefulCat end
Expand Down Expand Up @@ -73,7 +73,7 @@ function StatefulCatConfig(rules, item_bank)
end

function next_item(config::StatefulCatConfig)
return config.rules.next_item(config.tracked_responses, config.item_bank)
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank)
end

function ranked_items(config::StatefulCatConfig)
Expand Down
2 changes: 1 addition & 1 deletion src/decision_tree/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function generate_dt_cat(config::DecisionTreeGenerationConfig, item_bank)
while true
track!(responses, config.ability_tracker)
ability = config.ability_estimator(responses)
next_item = config.next_item(responses, item_bank)
next_item = best_item(config.next_item, responses, item_bank)

insert!(decision_tree_result, responses.responses, ability, next_item)
if state_tree.cur_depth == state_tree.max_depth
Expand Down
7 changes: 4 additions & 3 deletions src/next_item_rules/NextItemRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,18 @@ export RandomNextItemRule
export ExhaustiveSearch
export catr_next_item_aliases
export preallocate
export compute_criteria
export compute_criteria, compute_criterion, compute_multi_criterion,
compute_pointwise_criterion
export best_item
export PointResponseExpectation, DistributionResponseExpectation
export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria
export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion
export InformationMatrixCriteria
export ScalarizedStateCriteron, ScalarizedItemCriteron

# Prelude
include("./prelude/abstract.jl")
include("./prelude/next_item_rule.jl")
include("./prelude/strategy.jl")
include("./prelude/criteria.jl")
include("./prelude/preallocate.jl")

Expand Down
8 changes: 5 additions & 3 deletions src/next_item_rules/combinators/expectation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,16 @@ function init_thread(::ExpectationBasedItemCriterion, responses::TrackedResponse
end

function _generic_criterion(criterion::StateCriterion, tracked_responses, item_idx)
criterion(tracked_responses)
compute_criterion(criterion, tracked_responses)
end
# TODO: Support init_thread for wrapped ItemCriterion
function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx)
criterion(tracked_responses, item_idx)
compute_criterion(criterion, tracked_responses, item_idx)
end

function (item_criterion::ExpectationBasedItemCriterion)(speculator::Speculator,
function compute_criterion(
item_criterion::ExpectationBasedItemCriterion,
speculator::Speculator,
tracked_responses::TrackedResponses,
item_idx)
exp_resp = Aggregators.response_expectation(item_criterion.response_expectation,
Expand Down
45 changes: 23 additions & 22 deletions src/next_item_rules/combinators/scalarizers.jl
Original file line number Diff line number Diff line change
@@ -1,57 +1,58 @@
struct DeterminantScalarizer <: MatrixScalarizer end
(::DeterminantScalarizer)(mat) = det(mat)
scalarize(::DeterminantScalarizer, mat) = det(mat)

struct TraceScalarizer <: MatrixScalarizer end
(::TraceScalarizer)(mat) = tr(mat)
scalarize(::TraceScalarizer, mat) = tr(mat)

struct ScalarizedItemCriteron{
ItemCriteriaT <: ItemCriteria,
ItemMultiCriterionT <: ItemMultiCriterion,
MatrixScalarizerT <: MatrixScalarizer
} <: ItemCriterion
criteria::ItemCriteriaT
criteria::ItemMultiCriterionT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx)
res = ssc.criteria(
init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |>
ssc.scalarizer
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct ScalarizedStateCriteron{
StateCriteriaT <: StateCriteria,
StateMultiCriterionT <: StateMultiCriterion,
MatrixScalarizerT <: MatrixScalarizer
} <: StateCriterion
criteria::StateCriteriaT
criteria::StateMultiCriterionT
scalarizer::MatrixScalarizerT
end

function (ssc::ScalarizedStateCriteron)(tracked_responses)
res = ssc.criteria(tracked_responses) |> ssc.scalarizer
function compute_criterion(ssc::Union{ScalarizedItemCriteron, ScalarizedStateCriteron},
tracked_responses::TrackedResponses, item_idx...)
res = scalarize(
ssc.scalarizer,
compute_multi_criterion(
ssc.criteria,
init_thread(ssc.criteria, tracked_responses),
tracked_responses,
item_idx...
)
)
if !should_minimize(ssc.criteria)
res = -res
end
res
end

struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria
struct WeightedStateMultiCriterion{InnerT <: StateMultiCriterion} <: StateMultiCriterion
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedStateCriteria)(tracked_responses, item_idx)
function compute_multi_criterion(
wsc::WeightedStateMultiCriterion, tracked_responses::TrackedResponses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end

struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria
struct WeightedItemMultiCriterion{InnerT <: ItemMultiCriterion} <: ItemMultiCriterion
weights::Vector{Float64}
criteria::InnerT
end

function (wsc::WeightedItemCriteria)(tracked_responses, item_idx)
function compute_multi_criterion(
wsc::WeightedItemMultiCriterion, tracked_responses::TrackedResponses, item_idx)
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
end
9 changes: 6 additions & 3 deletions src/next_item_rules/criteria/item/information.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ function InformationItemCriterion(ability_estimator)
InformationItemCriterion(ability_estimator, expected_item_information)
end

function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedResponses,
function compute_criterion(
item_criterion::InformationItemCriterion, tracked_responses::TrackedResponses,
item_idx)
ability = maybe_tracked_ability_estimate(tracked_responses,
item_criterion.ability_estimator)
ir = ItemResponse(tracked_responses.item_bank, item_idx)
return -item_criterion.expected_item_information(ir, ability)
end

struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria
struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <:
ItemMultiCriterion
ability_estimator::AbilityEstimatorT
expected_item_information::F
end
Expand All @@ -35,7 +37,8 @@ function init_thread(item_criterion::InformationMatrixCriteria,
responses_information(responses.item_bank, responses.responses, ability)
end

function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64},
function compute_multi_criterion(
item_criterion::InformationMatrixCriteria, acc_info::Matrix{Float64},
tracked_responses::TrackedResponses,
item_idx)
# TODO: Add in information from the prior
Expand Down
3 changes: 2 additions & 1 deletion src/next_item_rules/criteria/item/urry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ function raw_difficulty(item_bank, item_idx)
item_params(item_bank, item_idx).difficulty
end

function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses, item_idx)
function compute_criterion(
item_criterion::UrryItemCriterion, tracked_responses::TrackedResponses, item_idx)
ability = maybe_tracked_ability_estimate(tracked_responses,
item_criterion.ability_estimator)
diff = raw_difficulty(tracked_responses.item_bank, item_idx)
Expand Down
28 changes: 17 additions & 11 deletions src/next_item_rules/criteria/state/ability_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,20 @@ function AbilityVarianceStateCriterion(bits...)
return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero)
end

function (criterion::AbilityVarianceStateCriterion)(tracked_responses::TrackedResponses)::Float64
function compute_criterion(criterion::AbilityVarianceStateCriterion,
tracked_responses::TrackedResponses)::Float64
# XXX: Not sure if the estimator should come from somewhere else here
denom = normdenom(criterion.integrator,
criterion.dist_est,
tracked_responses)
if denom == 0.0 && criterion.skip_zero
return Inf
end
criterion(DomainType(tracked_responses.item_bank), tracked_responses, denom)
compute_criterion(
criterion, DomainType(tracked_responses.item_bank), tracked_responses, denom)
end

function (criterion::AbilityVarianceStateCriterion)(
function compute_criterion(criterion::AbilityVarianceStateCriterion,
::Union{OneDimContinuousDomain, DiscreteDomain},
tracked_responses::TrackedResponses,
denom)::Float64
Expand All @@ -59,9 +61,12 @@ function (criterion::AbilityVarianceStateCriterion)(
)
end

function (criterion::AbilityVarianceStateCriterion)(::Vector,
function compute_criterion(
criterion::AbilityVarianceStateCriterion,
::Vector,
tracked_responses::TrackedResponses,
denom)::Float64
denom
)::Float64
# XXX: Not quite sure about this --- is it useful, the MIRT rules cover this case
mean = expectation(IntegralCoeffs.id,
ndims(tracked_responses.item_bank),
Expand All @@ -77,25 +82,26 @@ function (criterion::AbilityVarianceStateCriterion)(::Vector,
denom)
end

struct AbilityCovarianceStateCriteria{
struct AbilityCovarianceStateMultiCriterion{
DistEstT <: DistributionAbilityEstimator,
IntegratorT <: AbilityIntegrator
} <: StateCriteria
} <: StateMultiCriterion
dist_est::DistEstT
integrator::IntegratorT
skip_zero::Bool
end

function AbilityCovarianceStateCriteria(bits...)
function AbilityCovarianceStateMultiCriterion(bits...)
skip_zero = false
@requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...)
return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero)
return AbilityCovarianceStateMultiCriterion(dist_est, integrator, skip_zero)
end

# XXX: Should be at type level
should_minimize(::AbilityCovarianceStateCriteria) = true
should_minimize(::AbilityCovarianceStateMultiCriterion) = true

function (criteria::AbilityCovarianceStateCriteria)(
function compute_multi_criterion(
criteria::AbilityCovarianceStateMultiCriterion,
tracked_responses::TrackedResponses,
denom = normdenom(criteria.integrator,
criteria.dist_est,
Expand Down
25 changes: 21 additions & 4 deletions src/next_item_rules/prelude/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@ $(TYPEDEF)

Abstract base type for all item selection rules. All descendants of this type
are expected to implement the interface
`(rule::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`
`(::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`.

In practice, all adaptive rules in this package use `ItemStrategyNextItemRule`.

$(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true)

Implicit constructor for $(FUNCTIONNAME). Uses any given `NextItemRule` or
delegates to `ItemStrategyNextItemRule`.
delegates to `ItemStrategyNextItemRule` the default instance.
"""
abstract type NextItemRule <: CatConfigBase end

"""
$(TYPEDEF)

Abstract type for next item strategies, tightly coupled with `ItemStrategyNextItemRule`.
All descendants of this type are expected to implement the interface
`(rule::ItemStrategyNextItemRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses,
items) where {ItemCriterionT <: }
`(strategy::NextItemStrategy)(; parallel=true)::NextItemStrategy`
"""
abstract type NextItemStrategy <: CatConfigBase end

"""
$(TYPEDEF)

Abstract type for next item criteria
"""
abstract type ItemCriterion <: CatConfigBase end

Expand All @@ -28,6 +38,13 @@ $(TYPEDEF)
"""
abstract type StateCriterion <: CatConfigBase end

"""
$(TYPEDEF)
"""
abstract type PointwiseItemCriterion <: CatConfigBase end

abstract type PurePointwiseItemCriterion <: PointwiseItemCriterion end

abstract type MatrixScalarizer end
abstract type StateCriteria end
abstract type ItemCriteria end
abstract type StateMultiCriterion end
abstract type ItemMultiCriterion end
Loading
Loading