@@ -334,7 +334,11 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
334334 batch_size, pseudocount, softness = 0 , report_ll = true ,
335335 mars_mem = nothing , flows_mem = nothing , node_aggr_mem = nothing ,
336336 edge_aggr_mem = nothing , mine= 2 , maxe= 32 , debug = false , verbose = true ,
337- valid_x= nothing , test_x= nothing , log_iter= 10 )
337+ callbacks = [])
338+
339+ insert! (callbacks, 1 , FullBatchLog (verbose))
340+ callbacks = CALLBACKList (callbacks)
341+ init (callbacks; batch_size, bpc)
338342
339343 num_nodes = length (bpc. nodes)
340344 num_edges = length (bpc. edge_layers_down. vectors)
@@ -344,9 +348,6 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
344348 flows = prep_memory (flows_mem, (batch_size, num_nodes), (false , true ))
345349 node_aggr = prep_memory (node_aggr_mem, (num_nodes,))
346350 edge_aggr = prep_memory (edge_aggr_mem, (num_edges,))
347- if ! isnothing (valid_x) || ! isnothing (test_x)
348- valid_marginals = prep_memory (nothing , (batch_size, num_nodes), (false , true ))
349- end
350351
351352 log_likelihoods = Vector {Float32} ()
352353
@@ -356,26 +357,13 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
356357 marginals, flows, node_aggr, edge_aggr,
357358 mine, maxe, debug)
358359 push! (log_likelihoods, log_likelihood)
359-
360- if verbose
361- print (" Full-batch EM epoch $epoch ; train LL $log_likelihood " )
362- if epoch % log_iter == 0 && (! isnothing (valid_x) || ! isnothing (test_x))
363- if ! isnothing (valid_x)
364- print (" ; valid LL " , loglikelihood (bpc, valid_x; batch_size, mine, maxe, mars_mem= valid_marginals))
365- end
366- if ! isnothing (test_x)
367- print (" ; test LL " , loglikelihood (bpc, test_x; batch_size, mine, maxe, mars_mem= valid_marginals))
368- end
369- end
370- println ()
371- end
360+ call (callbacks, epoch, log_likelihood)
372361 end
373362
374363 cleanup_memory ((data, raw_data), (flows, flows_mem),
375364 (node_aggr, node_aggr_mem), (edge_aggr, edge_aggr_mem))
376- if ! isnothing (valid_x) || ! isnothing (test_x)
377- CUDA. unsafe_free! (valid_marginals)
378- end
365+ cleanup (callbacks)
366+
379367 log_likelihoods
380368end
381369
@@ -396,14 +384,18 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
396384 softness = 0 , shuffle= :each_epoch ,
397385 mars_mem = nothing , flows_mem = nothing , node_aggr_mem = nothing , edge_aggr_mem = nothing ,
398386 mine = 2 , maxe = 32 , debug = false , verbose = true ,
399- valid_x = nothing , test_x = nothing , log_iter = 10 )
387+ callbacks = [] )
400388
401389 @assert pseudocount >= 0
402390 @assert 0 <= param_inertia <= 1
403391 @assert param_inertia <= param_inertia_end <= 1
404392 @assert 0 <= flow_memory
405393 @assert flow_memory <= flow_memory_end
406394 @assert shuffle ∈ [:once , :each_epoch , :each_batch ]
395+
396+ insert! (callbacks, 1 , MiniBatchLog (verbose))
397+ callbacks = CALLBACKList (callbacks)
398+ init (callbacks; batch_size, bpc)
407399
408400 num_examples = size (raw_data)[1 ]
409401 num_nodes = length (bpc. nodes)
@@ -418,9 +410,6 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
418410 flows = prep_memory (flows_mem, (batch_size, num_nodes), (false , true ))
419411 node_aggr = prep_memory (node_aggr_mem, (num_nodes,))
420412 edge_aggr = prep_memory (edge_aggr_mem, (num_edges,))
421- if ! isnothing (valid_x) || ! isnothing (test_x)
422- valid_marginals = prep_memory (nothing , (batch_size, num_nodes), (false , true ))
423- end
424413
425414 edge_aggr .= zero (Float32)
426415 clear_input_node_mem (bpc; rate = 0 , debug)
@@ -478,18 +467,7 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
478467 end
479468 log_likelihood = sum (log_likelihoods_epoch) / batch_size / num_batches
480469 push! (log_likelihoods, log_likelihood)
481- if verbose
482- print (" Mini-batch EM epoch $epoch ; train LL $log_likelihood " )
483- if epoch % log_iter == 0 && (! isnothing (valid_x) || ! isnothing (test_x))
484- if ! isnothing (valid_x)
485- print (" ; valid LL " , loglikelihood (bpc, valid_x; batch_size, mine, maxe, mars_mem= valid_marginals))
486- end
487- if ! isnothing (test_x)
488- print (" ; test LL " , loglikelihood (bpc, test_x; batch_size, mine, maxe, mars_mem= valid_marginals))
489- end
490- end
491- println ()
492- end
470+ call (callbacks, epoch, log_likelihood)
493471
494472 param_inertia += Δparam_inertia
495473 flow_memory += Δflow_memory
@@ -499,9 +477,85 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
499477 (node_aggr, node_aggr_mem), (edge_aggr, edge_aggr_mem))
500478 CUDA. unsafe_free! (shuffled_indices)
501479
502- if ! isnothing (valid_x) || ! isnothing (test_x)
503- CUDA. unsafe_free! (valid_marginals)
504- end
480+ cleanup (callbacks)
505481
506482 log_likelihoods
507483end
484+
485+ abstract type CALLBACK end
486+
487+ struct CALLBACKList
488+ list:: Vector{CALLBACK}
489+ end
490+
491+ function call (callbacks:: CALLBACKList , epoch, log_likelihood)
492+ if callbacks. list[1 ]. verbose
493+ for x in callbacks. list
494+ call (x, epoch, log_likelihood)
495+ end
496+ println ()
497+ end
498+ end
499+
500+ function init (callbacks:: CALLBACKList ; kwargs... )
501+ for x in callbacks. list
502+ init (x; kwargs... )
503+ end
504+ end
505+
506+ function cleanup (callbacks:: CALLBACKList )
507+ for x in callbacks. list
508+ cleanup (x)
509+ end
510+ end
511+
512+ struct MiniBatchLog <: CALLBACK
513+ verbose
514+ end
515+
516+ struct FullBatchLog <: CALLBACK
517+ verbose
518+ end
519+
520+ mutable struct LikelihoodsLog <: CALLBACK
521+ valid_x
522+ test_x
523+ iter
524+ bpc
525+ batch_size
526+ mars_mem
527+ LikelihoodsLog (valid_x, test_x, iter) = begin
528+ new (valid_x, test_x, iter, nothing , nothing , nothing )
529+ end
530+ end
531+
532+ init (caller:: CALLBACK ; kwargs... ) = nothing
533+ init (caller:: LikelihoodsLog ; bpc, batch_size) = begin
534+ caller. bpc = bpc
535+ caller. batch_size = batch_size
536+ caller. mars_mem = prep_memory (nothing , (batch_size, length (bpc. nodes)), (false , true ))
537+ end
538+
539+ call (caller:: MiniBatchLog , epoch, log_likelihood) = begin
540+ caller. verbose && print (" Mini-batch EM epoch $epoch ; train LL $log_likelihood " )
541+ end
542+ call (caller:: FullBatchLog , epoch, log_likelihood) = begin
543+ caller. verbose && print (" Full-batch EM epoch $epoch ; train LL $log_likelihood " )
544+ end
545+ call (caller:: LikelihoodsLog , epoch, log_likelihood) = begin
546+ if epoch % caller. iter == 0 && (! isnothing (caller. valid_x) || ! isnothing (caller. test_x))
547+ if ! isnothing (caller. valid_x)
548+ print (" ; valid LL " , loglikelihood (caller. bpc, caller. valid_x;
549+ batch_size= caller. batch_size,mars_mem= caller. mars_mem))
550+ end
551+ if ! isnothing (caller. test_x)
552+ print (" ; test LL " , loglikelihood (caller. bpc, caller. test_x;
553+ batch_size= caller. batch_size,mars_mem= caller. mars_mem))
554+ end
555+ end
556+ end
557+
558+ cleanup (caller:: CALLBACK ) = nothing
559+ cleanup (caller:: LikelihoodsLog ) = begin
560+ CUDA. unsafe_free! (caller. mars_mem)
561+ end
0 commit comments