569569cleanup (caller:: CALLBACK ) = nothing
570570cleanup (caller:: LikelihoodsLog ) = begin
571571 CUDA. unsafe_free! (caller. mars_mem)
572+ end
573+
574+ # early stopping
575+ mutable struct EarlyStopPC <: CALLBACK
576+ likelihoods_log
577+ patience
578+ warmup
579+ val
580+
581+ best_value
582+ best_iter
583+ best_bpc
584+
585+ n_increase
586+ iter
587+ EarlyStopPC (likelihoods_log; patience, warmup= 1 , val= :valid_x ) = begin
588+ @assert val == :valid_x
589+ @assert patience % likelihoods_log. iter == 0
590+ @assert ! isnothing (likelihoods_log. valid_x)
591+ new (likelihoods_log, Int (ceil (patience / likelihoods_log. iter)),
592+ warmup, val, - Inf , 0 , nothing , 0 , 0 )
593+ end
594+ end
595+
596+ init (caller:: EarlyStopPC ; args... ) = begin
597+ init (caller. likelihoods_log; args... )
598+ bpc = caller. likelihoods_log. bpc
599+ best_bpc = (edge_layers_up = deepcopy (bpc. edge_layers_up), heap = deepcopy (bpc. heap))
600+ caller. best_bpc = best_bpc
601+ end
602+
603+ call (caller:: EarlyStopPC , epoch, log_likelihood) = begin
604+ valid_ll, test_ll = call (caller. likelihoods_log, epoch, log_likelihood)
605+ caller. iter += 1
606+ flag = false
607+ if isnothing (valid_ll) || caller. iter < caller. warmup
608+ flag = false
609+ elseif valid_ll >= caller. best_value
610+ caller. n_increase = 0
611+ caller. best_value = valid_ll
612+ caller. best_iter = epoch
613+ copy_bpc! (caller. best_bpc, caller. likelihoods_log. bpc)
614+ flag = false
615+ elseif valid_ll < caller. best_value
616+ caller. n_increase += 1
617+ if caller. n_increase > caller. patience
618+ copy_bpc! (caller. likelihoods_log. bpc, caller. best_bpc)
619+ flag = true
620+ else
621+ flag = false
622+ end
623+ else
624+ error (" " )
625+ end
626+ return flag
627+ end
628+
629+ copy_bpc! (dst, src) = begin
630+ copyto! (dst. edge_layers_up. vectors, src. edge_layers_up. vectors)
631+ copyto! (dst. heap, src. heap)
632+ end
633+
634+ cleanup (caller:: EarlyStopPC ) = begin
635+ cleanup (caller. likelihoods_log)
572636end
0 commit comments