@@ -357,7 +357,10 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
357357 marginals, flows, node_aggr, edge_aggr,
358358 mine, maxe, debug)
359359 push! (log_likelihoods, log_likelihood)
360- call (callbacks, epoch, log_likelihood)
360+ done = call (callbacks, epoch, log_likelihood)
361+ if ! isnothing (done) && done[end ] == true
362+ break
363+ end
361364 end
362365
363366 cleanup_memory ((data, raw_data), (flows, flows_mem),
@@ -467,10 +470,13 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
467470 end
468471 log_likelihood = sum (log_likelihoods_epoch) / batch_size / num_batches
469472 push! (log_likelihoods, log_likelihood)
470- call (callbacks, epoch, log_likelihood)
473+ done = call (callbacks, epoch, log_likelihood)
471474
472475 param_inertia += Δparam_inertia
473476 flow_memory += Δflow_memory
477+ if ! isnothing (done) && done[end ] == true
478+ break
479+ end
474480 end
475481
476482 cleanup_memory ((data, raw_data), (flows, flows_mem),
@@ -490,10 +496,11 @@ end
490496
491497function call (callbacks:: CALLBACKList , epoch, log_likelihood)
492498 if callbacks. list[1 ]. verbose
493- for x in callbacks. list
499+ done = map ( callbacks. list) do x
494500 call (x, epoch, log_likelihood)
495501 end
496502 println ()
503+ done
497504 end
498505end
499506
@@ -543,16 +550,20 @@ call(caller::FullBatchLog, epoch, log_likelihood) = begin
543550 caller. verbose && print (" Full-batch EM epoch $epoch ; train LL $log_likelihood " )
544551end
545552call (caller:: LikelihoodsLog , epoch, log_likelihood) = begin
553+ valid_ll, test_ll = nothing , nothing
546554 if epoch % caller. iter == 0 && (! isnothing (caller. valid_x) || ! isnothing (caller. test_x))
547555 if ! isnothing (caller. valid_x)
548- print (" ; valid LL " , loglikelihood (caller. bpc, caller. valid_x;
549- batch_size= caller. batch_size,mars_mem= caller. mars_mem))
556+ valid_ll = loglikelihood (caller. bpc, caller. valid_x;
557+ batch_size= caller. batch_size,mars_mem= caller. mars_mem)
558+ print (" ; valid LL " , valid_ll)
550559 end
551560 if ! isnothing (caller. test_x)
552- print (" ; test LL " , loglikelihood (caller. bpc, caller. test_x;
553- batch_size= caller. batch_size,mars_mem= caller. mars_mem))
561+ test_ll = loglikelihood (caller. bpc, caller. test_x;
562+ batch_size= caller. batch_size,mars_mem= caller. mars_mem)
563+ print (" ; test LL " , test_ll)
554564 end
555565 end
566+ valid_ll, test_ll
556567end
557568
558569cleanup (caller:: CALLBACK ) = nothing
0 commit comments