Skip to content

Commit ae94c1d

Browse files
authored
Merge pull request #125 from Juice-jl/callback
early stopping as a call back function
2 parents fe10122 + 7c8cdb3 commit ae94c1d

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

src/parameters/em.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,4 +569,68 @@ end
569569
cleanup(caller::CALLBACK) = nothing
570570
cleanup(caller::LikelihoodsLog) = begin
571571
CUDA.unsafe_free!(caller.mars_mem)
572+
end
573+
574+
# early stopping
575+
mutable struct EarlyStopPC <: CALLBACK
576+
likelihoods_log
577+
patience
578+
warmup
579+
val
580+
581+
best_value
582+
best_iter
583+
best_bpc
584+
585+
n_increase
586+
iter
587+
EarlyStopPC(likelihoods_log; patience, warmup=1, val=:valid_x) = begin
588+
@assert val == :valid_x
589+
@assert patience % likelihoods_log.iter == 0
590+
@assert !isnothing(likelihoods_log.valid_x)
591+
new(likelihoods_log, Int(ceil(patience / likelihoods_log.iter)),
592+
warmup, val, -Inf, 0, nothing, 0, 0)
593+
end
594+
end
595+
596+
init(caller::EarlyStopPC; args...) = begin
597+
init(caller.likelihoods_log; args...)
598+
bpc = caller.likelihoods_log.bpc
599+
best_bpc = (edge_layers_up = deepcopy(bpc.edge_layers_up), heap = deepcopy(bpc.heap))
600+
caller.best_bpc = best_bpc
601+
end
602+
603+
call(caller::EarlyStopPC, epoch, log_likelihood) = begin
604+
valid_ll, test_ll = call(caller.likelihoods_log, epoch, log_likelihood)
605+
caller.iter += 1
606+
flag = false
607+
if isnothing(valid_ll) || caller.iter < caller.warmup
608+
flag = false
609+
elseif valid_ll >= caller.best_value
610+
caller.n_increase = 0
611+
caller.best_value = valid_ll
612+
caller.best_iter = epoch
613+
copy_bpc!(caller.best_bpc, caller.likelihoods_log.bpc)
614+
flag = false
615+
elseif valid_ll < caller.best_value
616+
caller.n_increase += 1
617+
if caller.n_increase > caller.patience
618+
copy_bpc!(caller.likelihoods_log.bpc, caller.best_bpc)
619+
flag = true
620+
else
621+
flag = false
622+
end
623+
else
624+
error("")
625+
end
626+
return flag
627+
end
628+
629+
copy_bpc!(dst, src) = begin
630+
copyto!(dst.edge_layers_up.vectors, src.edge_layers_up.vectors)
631+
copyto!(dst.heap, src.heap)
632+
end
633+
634+
cleanup(caller::EarlyStopPC) = begin
635+
cleanup(caller.likelihoods_log)
572636
end

0 commit comments

Comments
 (0)