Skip to content

Commit dac0c6b

Browse files
author
Frankie Robertson
committed
Add watchdog, likelihood sampling, callback skipping to comparison
1 parent e04a19f commit dac0c6b

File tree

2 files changed

+321
-92
lines changed

2 files changed

+321
-92
lines changed
Lines changed: 167 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy
2323
export ReplayResponsesExecutionStrategy
2424
export CatComparisonConfig
2525

26+
include("./watchdog.jl")
27+
2628
struct RandomCatComparison
2729
true_abilities::Array{Float64}
2830
rand_abilities::Array{Float64, 3}
@@ -82,8 +84,7 @@ end
8284

8385
abstract type CatComparisonExecutionStrategy end
8486

85-
struct CatComparisonConfig{
86-
StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple}
87+
struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple}
8788
"""
8889
A named tuple with the (named) CatRules (or compatable) to be compared
8990
"""
@@ -102,13 +103,26 @@ struct CatComparisonConfig{
102103
The phases to run, optionally paired with a callback
103104
"""
104105
phases::PhasesT
106+
"""
107+
Where to sample for likelihood
108+
"""
109+
sample_points::Union{Vector{Float64}, Nothing}
110+
"""
111+
Skips
112+
"""
113+
skip_callback
114+
"""
115+
Watchdog timeout
116+
"""
117+
timeout::Float64
105118
end
106119

107120
"""
108121
CatComparisonConfig(;
109122
rules::NamedTuple{Symbol, StatefulCat},
110123
strategy::CatComparisonExecutionStrategy,
111124
phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}},
125+
skips::Set{Tuple{Symbol, Symbol}},
112126
callback::Callable
113127
) -> CatComparisonConfig
114128
@@ -123,18 +137,24 @@ no callback is provided.
123137
124138
The exact phases depend on the strategy used. See their individual documentation for more.
125139
"""
126-
function CatComparisonConfig(; rules, strategy, phases = nothing, callback = nothing)
140+
function CatComparisonConfig(; rules, strategy, phases = nothing, skip_callback = ((_, _, _) -> false), sample_points = nothing, callback = nothing, timeout = Inf)
127141
if callback === nothing
128142
callback = (info; kwargs...) -> nothing
129143
end
130144
if phases === nothing
131145
phases = (:before_next_item, :after_next_item)
132146
end
133-
# TODO: normalize phases into named tuple
134147
if !(phases isa NamedTuple)
135148
phases = NamedTuple((phase => callback for phase in phases))
136149
end
137-
CatComparisonConfig(rules, strategy, phases)
150+
CatComparisonConfig(
151+
rules,
152+
strategy,
153+
phases,
154+
sample_points,
155+
skip_callback,
156+
timeout
157+
)
138158
end
139159

140160
# Comparison scenarios:
@@ -158,7 +178,6 @@ end
158178

159179
#phase_func=nothing;
160180
function measure_all(comparison, system, cat, phase; kwargs...)
161-
@info "measure_all" phase system kwargs
162181
if !(phase in keys(comparison.phases))
163182
return
164183
end
@@ -273,7 +292,6 @@ function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExec
273292
num_items=size,
274293
system_name=name
275294
)
276-
@info "next_item" name timed_next_item.time strategy.time_limit
277295
if timed_next_item.time < strategy.time_limit
278296
push!(next_current_cats, name => cat)
279297
end
@@ -300,108 +318,165 @@ end
300318

301319
struct ReplayResponsesExecutionStrategy <: CatComparisonExecutionStrategy
302320
responses::BareResponses
321+
time_limit::Float64
322+
end
323+
324+
ReplayResponsesExecutionStrategy(responses) = ReplayResponsesExecutionStrategy(responses, Inf)
325+
326+
function should_run(comparison, name, cat, phase)
327+
return phase in keys(comparison.phases) &&
328+
!comparison.skip_callback(name, cat, phase)
303329
end
304330

305331
# Which questions to ask: Specified
306332
# Which answer to use: From response memory
307333
function run_comparison(comparison::CatComparisonConfig{ReplayResponsesExecutionStrategy})
308334
strategy = comparison.strategy
309-
for (items_answered, response) in zip(
310-
Iterators.countfrom(0), Iterators.flatten((strategy.responses, [nothing])))
311-
for (name, cat) in pairs(comparison.rules)
312-
if :before_item_criteria in comparison.phases
313-
timed_item_criteria = @timed Stateful.item_criteria(cat)
314-
measure_all(
315-
comparison,
316-
name,
317-
cat,
318-
:before_item_criteria,
319-
items_answered = items_answered,
320-
item_criteria = timed_item_criteria.value,
321-
timing = timed_item_criteria
322-
)
323-
end
324-
if :before_ranked_items in comparison.phases
325-
timed_ranked_items = @timed Stateful.ranked_items(cat)
326-
measure_all(
327-
comparison,
328-
name,
329-
cat,
330-
:before_ranked_items,
331-
items_answered = items_answered,
332-
ranked_items = timed_ranked_items.value,
333-
timing = timed_ranked_items
334-
)
335-
end
336-
if :before_ability in comparison.phases
337-
timed_get_ability = @timed Stateful.get_ability(cat)
338-
measure_all(
339-
comparison,
340-
name,
341-
cat,
342-
:before_ability,
343-
items_answered = items_answered,
344-
ability = timed_get_ability.value,
345-
timing = timed_get_ability
346-
)
347-
end
348-
measure_all(
349-
comparison,
350-
name,
351-
cat,
352-
:before_next_item,
353-
items_answered = items_answered
354-
)
355-
timed_next_item = @timed Stateful.next_item(cat)
356-
next_item = timed_next_item.value
357-
measure_all(
358-
comparison,
359-
name,
360-
cat,
361-
:after_next_item,
362-
next_item = next_item,
363-
timing = timed_next_item,
364-
items_answered = items_answered
365-
)
366-
if :after_item_criteria in comparison.phases
367-
# TOOD: Combine with next_item if possible and requested?
368-
timed_item_criteria = @timed Stateful.item_criteria(cat)
369-
measure_all(
370-
comparison,
371-
name,
372-
cat,
373-
:after_item_criteria,
374-
items_answered = items_answered,
375-
item_criteria = timed_item_criteria.value,
376-
timing = timed_item_criteria
377-
)
335+
current_cats = Dict(pairs(comparison.rules))
336+
function check_time(name, timer)
337+
if timer.time >= strategy.time_limit
338+
if name in keys(current_cats)
339+
@info "Time limit exceeded" name timer.time
340+
delete!(current_cats, name)
378341
end
379-
if :after_ranked_items in comparison.phases
380-
timed_ranked_items = @timed Stateful.ranked_items(cat)
342+
end
343+
end
344+
watchdog = WatchdogTask(comparison.timeout)
345+
start!(watchdog) do
346+
for (items_answered, response) in zip(
347+
Iterators.countfrom(0), Iterators.flatten((strategy.responses, [nothing])))
348+
for (name, cat) in pairs(current_cats)
349+
println("")
350+
println("Starting $name for $items_answered")
351+
flush(stdout)
352+
if should_run(comparison, name, cat, :before_item_criteria)
353+
reset!(watchdog, "$name item_criteria")
354+
timed_item_criteria = @timed Stateful.item_criteria(cat)
355+
check_time(name, timed_item_criteria)
356+
measure_all(
357+
comparison,
358+
name,
359+
cat,
360+
:before_item_criteria,
361+
items_answered = items_answered,
362+
item_criteria = timed_item_criteria.value,
363+
timing = timed_item_criteria
364+
)
365+
end
366+
if should_run(comparison, name, cat, :before_ranked_items)
367+
reset!(watchdog, "$name ranked_items")
368+
timed_ranked_items = @timed Stateful.ranked_items(cat)
369+
check_time(name, timed_ranked_items)
370+
measure_all(
371+
comparison,
372+
name,
373+
cat,
374+
:before_ranked_items,
375+
items_answered = items_answered,
376+
ranked_items = timed_ranked_items.value,
377+
timing = timed_ranked_items
378+
)
379+
end
380+
if should_run(comparison, name, cat, :before_ability)
381+
reset!(watchdog, "$name get_ability")
382+
timed_get_ability = @timed Stateful.get_ability(cat)
383+
check_time(name, timed_get_ability)
384+
measure_all(
385+
comparison,
386+
name,
387+
cat,
388+
:before_ability,
389+
items_answered = items_answered,
390+
ability = timed_get_ability.value,
391+
timing = timed_get_ability
392+
)
393+
end
381394
measure_all(
382395
comparison,
383396
name,
384397
cat,
385-
:after_ranked_items,
386-
items_answered = items_answered,
387-
ranked_items = timed_ranked_items.value,
388-
timing = timed_ranked_items
398+
:before_next_item,
399+
items_answered = items_answered
389400
)
390-
end
391-
if :after_ability in comparison.phases
392-
timed_get_ability = @timed Stateful.get_ability(cat)
401+
reset!(watchdog, "$name next_item")
402+
timed_next_item = @timed Stateful.next_item(cat)
403+
check_time(name, timed_next_item)
404+
next_item = timed_next_item.value
393405
measure_all(
394406
comparison,
395407
name,
396408
cat,
397-
:after_ability,
398-
items_answered = items_answered,
399-
ability = timed_get_ability.value,
400-
timing = timed_get_ability
409+
:after_next_item,
410+
next_item = next_item,
411+
timing = timed_next_item,
412+
items_answered = items_answered
401413
)
402-
end
403-
if response !== nothing
404-
Stateful.add_response!(cat, response.index, response.value)
414+
if should_run(comparison, name, cat, :after_item_criteria)
415+
# TOOD: Combine with next_item if possible and requested?
416+
reset!(watchdog, "$name item_criteria")
417+
timed_item_criteria = @timed Stateful.item_criteria(cat)
418+
check_time(name, timed_item_criteria)
419+
if timed_item_criteria.value !== nothing
420+
measure_all(
421+
comparison,
422+
name,
423+
cat,
424+
:after_item_criteria,
425+
items_answered = items_answered,
426+
item_criteria = timed_item_criteria.value,
427+
timing = timed_item_criteria
428+
)
429+
end
430+
end
431+
if should_run(comparison, name, cat, :after_ranked_items)
432+
reset!(watchdog, "$name ranked_items")
433+
timed_ranked_items = @timed Stateful.ranked_items(cat)
434+
check_time(name, timed_ranked_items)
435+
if timed_ranked_items.value !== nothing
436+
measure_all(
437+
comparison,
438+
name,
439+
cat,
440+
:after_ranked_items,
441+
items_answered = items_answered,
442+
ranked_items = timed_ranked_items.value,
443+
timing = timed_ranked_items
444+
)
445+
end
446+
end
447+
if should_run(comparison, name, cat, :after_likelihood)
448+
reset!(watchdog, "$name likelihood")
449+
timed_likelihood = @timed Stateful.likelihood.(Ref(cat), comparison.sample_points)
450+
check_time(name, timed_likelihood)
451+
measure_all(
452+
comparison,
453+
name,
454+
cat,
455+
:after_likelihood,
456+
items_answered = items_answered,
457+
sample_points = comparison.sample_points,
458+
likelihood = timed_likelihood.value,
459+
timing = timed_likelihood
460+
)
461+
462+
end
463+
if should_run(comparison, name, cat, :after_ability)
464+
reset!(watchdog, "$name get_ability")
465+
timed_get_ability = @timed Stateful.get_ability(cat)
466+
check_time(name, timed_get_ability)
467+
measure_all(
468+
comparison,
469+
name,
470+
cat,
471+
:after_ability,
472+
items_answered = items_answered,
473+
ability = timed_get_ability.value,
474+
timing = timed_get_ability
475+
)
476+
end
477+
if response !== nothing
478+
Stateful.add_response!(cat, response.index, response.value)
479+
end
405480
end
406481
end
407482
end

0 commit comments

Comments
 (0)