Skip to content

Commit b2889fd

Browse files
authored
Merge pull request #123 from Juice-jl/em-log
implement log as call back function
2 parents ae527df + 6fa9bf6 commit b2889fd

File tree

1 file changed

+93
-39
lines changed

1 file changed

+93
-39
lines changed

src/parameters/em.jl

Lines changed: 93 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,11 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
334334
batch_size, pseudocount, softness = 0, report_ll = true,
335335
mars_mem = nothing, flows_mem = nothing, node_aggr_mem = nothing,
336336
edge_aggr_mem = nothing, mine=2, maxe=32, debug = false, verbose = true,
337-
valid_x=nothing, test_x=nothing, log_iter=10)
337+
callbacks = [])
338+
339+
insert!(callbacks, 1, FullBatchLog(verbose))
340+
callbacks = CALLBACKList(callbacks)
341+
init(callbacks; batch_size, bpc)
338342

339343
num_nodes = length(bpc.nodes)
340344
num_edges = length(bpc.edge_layers_down.vectors)
@@ -344,9 +348,6 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
344348
flows = prep_memory(flows_mem, (batch_size, num_nodes), (false, true))
345349
node_aggr = prep_memory(node_aggr_mem, (num_nodes,))
346350
edge_aggr = prep_memory(edge_aggr_mem, (num_edges,))
347-
if !isnothing(valid_x) || !isnothing(test_x)
348-
valid_marginals = prep_memory(nothing, (batch_size, num_nodes), (false, true))
349-
end
350351

351352
log_likelihoods = Vector{Float32}()
352353

@@ -356,26 +357,13 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
356357
marginals, flows, node_aggr, edge_aggr,
357358
mine, maxe, debug)
358359
push!(log_likelihoods, log_likelihood)
359-
360-
if verbose
361-
print("Full-batch EM epoch $epoch; train LL $log_likelihood")
362-
if epoch % log_iter == 0 && (!isnothing(valid_x) || !isnothing(test_x))
363-
if !isnothing(valid_x)
364-
print("; valid LL ", loglikelihood(bpc, valid_x; batch_size, mine, maxe, mars_mem=valid_marginals))
365-
end
366-
if !isnothing(test_x)
367-
print("; test LL ", loglikelihood(bpc, test_x; batch_size, mine, maxe, mars_mem=valid_marginals))
368-
end
369-
end
370-
println()
371-
end
360+
call(callbacks, epoch, log_likelihood)
372361
end
373362

374363
cleanup_memory((data, raw_data), (flows, flows_mem),
375364
(node_aggr, node_aggr_mem), (edge_aggr, edge_aggr_mem))
376-
if !isnothing(valid_x) || !isnothing(test_x)
377-
CUDA.unsafe_free!(valid_marginals)
378-
end
365+
cleanup(callbacks)
366+
379367
log_likelihoods
380368
end
381369

@@ -396,14 +384,18 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
396384
softness = 0, shuffle=:each_epoch,
397385
mars_mem = nothing, flows_mem = nothing, node_aggr_mem = nothing, edge_aggr_mem = nothing,
398386
mine = 2, maxe = 32, debug = false, verbose = true,
399-
valid_x=nothing, test_x=nothing, log_iter=10)
387+
callbacks = [])
400388

401389
@assert pseudocount >= 0
402390
@assert 0 <= param_inertia <= 1
403391
@assert param_inertia <= param_inertia_end <= 1
404392
@assert 0 <= flow_memory
405393
@assert flow_memory <= flow_memory_end
406394
@assert shuffle ∈ [:once, :each_epoch, :each_batch]
395+
396+
insert!(callbacks, 1, MiniBatchLog(verbose))
397+
callbacks = CALLBACKList(callbacks)
398+
init(callbacks; batch_size, bpc)
407399

408400
num_examples = size(raw_data)[1]
409401
num_nodes = length(bpc.nodes)
@@ -418,9 +410,6 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
418410
flows = prep_memory(flows_mem, (batch_size, num_nodes), (false, true))
419411
node_aggr = prep_memory(node_aggr_mem, (num_nodes,))
420412
edge_aggr = prep_memory(edge_aggr_mem, (num_edges,))
421-
if !isnothing(valid_x) || !isnothing(test_x)
422-
valid_marginals = prep_memory(nothing, (batch_size, num_nodes), (false, true))
423-
end
424413

425414
edge_aggr .= zero(Float32)
426415
clear_input_node_mem(bpc; rate = 0, debug)
@@ -478,18 +467,7 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
478467
end
479468
log_likelihood = sum(log_likelihoods_epoch) / batch_size / num_batches
480469
push!(log_likelihoods, log_likelihood)
481-
if verbose
482-
print("Mini-batch EM epoch $epoch; train LL $log_likelihood")
483-
if epoch % log_iter == 0 && (!isnothing(valid_x) || !isnothing(test_x))
484-
if !isnothing(valid_x)
485-
print("; valid LL ", loglikelihood(bpc, valid_x; batch_size, mine, maxe, mars_mem=valid_marginals))
486-
end
487-
if !isnothing(test_x)
488-
print("; test LL ", loglikelihood(bpc, test_x; batch_size, mine, maxe, mars_mem=valid_marginals))
489-
end
490-
end
491-
println()
492-
end
470+
call(callbacks, epoch, log_likelihood)
493471

494472
param_inertia += Δparam_inertia
495473
flow_memory += Δflow_memory
@@ -499,9 +477,85 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
499477
(node_aggr, node_aggr_mem), (edge_aggr, edge_aggr_mem))
500478
CUDA.unsafe_free!(shuffled_indices)
501479

502-
if !isnothing(valid_x) || !isnothing(test_x)
503-
CUDA.unsafe_free!(valid_marginals)
504-
end
480+
cleanup(callbacks)
505481

506482
log_likelihoods
507483
end
484+
485+
# Abstract supertype for the EM training callbacks defined in this file.
# Methods of `init`, `call`, and `cleanup` dispatch on its subtypes; the
# generic `init` and `cleanup` fallbacks for `CALLBACK` return `nothing`.
abstract type CALLBACK end
486+
487+
"""
    CALLBACKList(list)

Ordered collection of `CALLBACK` objects. The `init`, `call`, and `cleanup`
methods defined for `CALLBACKList` iterate over `list`.
"""
struct CALLBACKList
    list::Vector{CALLBACK}
end
490+
491+
"""
    call(callbacks::CALLBACKList, epoch, log_likelihood)

Run every callback in `callbacks.list`, in order, for the given `epoch` and
training `log_likelihood`, then print a newline.

The whole chain is gated on the `verbose` field of the first callback: when it
is `false`, nothing is called and nothing is printed. An empty list is a no-op.
Always returns `nothing`.
"""
function call(callbacks::CALLBACKList, epoch, log_likelihood)
    # Guard the empty case: indexing `list[1]` unconditionally would throw a
    # BoundsError when no callbacks are registered.
    isempty(callbacks.list) && return nothing
    # The first callback's `verbose` flag decides whether this epoch logs at all.
    first(callbacks.list).verbose || return nothing
    for cb in callbacks.list
        call(cb, epoch, log_likelihood)
    end
    # The per-callback methods use `print`; terminate the line here.
    println()
    return nothing
end
499+
500+
"""
    init(callbacks::CALLBACKList; kwargs...)

Call `init` on each callback in `callbacks.list`, in order, forwarding all
keyword arguments unchanged.
"""
function init(callbacks::CALLBACKList; kwargs...)
    foreach(cb -> init(cb; kwargs...), callbacks.list)
end
505+
506+
"""
    cleanup(callbacks::CALLBACKList)

Call `cleanup` on each callback in `callbacks.list`, in order.
"""
function cleanup(callbacks::CALLBACKList)
    foreach(cleanup, callbacks.list)
end
511+
512+
"""
    MiniBatchLog(verbose)

Callback that prints the per-epoch training log-likelihood line for mini-batch
EM when `verbose` is true.
"""
struct MiniBatchLog <: CALLBACK
    verbose
end
515+
516+
"""
    FullBatchLog(verbose)

Callback that prints the per-epoch training log-likelihood line for full-batch
EM when `verbose` is true.
"""
struct FullBatchLog <: CALLBACK
    verbose
end
519+
520+
"""
    LikelihoodsLog(valid_x, test_x, iter)

Callback that prints validation and/or test log-likelihoods every `iter`
epochs. Either dataset may be `nothing`, in which case it is skipped.

The `bpc`, `batch_size`, and `mars_mem` fields start as `nothing` and are
filled in later by `init`.
"""
mutable struct LikelihoodsLog <: CALLBACK
    valid_x
    test_x
    iter
    bpc
    batch_size
    mars_mem

    function LikelihoodsLog(valid_x, test_x, iter)
        return new(valid_x, test_x, iter, nothing, nothing, nothing)
    end
end
531+
532+
# Fallback `init` for any callback: accepts and ignores all keyword arguments
# and returns `nothing`. Subtypes that need setup define a more specific method.
init(caller::CALLBACK; kwargs...) = nothing
533+
"""
    init(caller::LikelihoodsLog; bpc, batch_size)

Store `bpc` and `batch_size` on `caller` and obtain a
`(batch_size, length(bpc.nodes))` buffer from `prep_memory`, kept in
`caller.mars_mem` and later passed to `loglikelihood` as `mars_mem`.
"""
function init(caller::LikelihoodsLog; bpc, batch_size)
    caller.bpc = bpc
    caller.batch_size = batch_size
    buffer_dims = (batch_size, length(bpc.nodes))
    caller.mars_mem = prep_memory(nothing, buffer_dims, (false, true))
end
538+
539+
"""
    call(caller::MiniBatchLog, epoch, log_likelihood)

Print the mini-batch EM progress prefix (epoch number and training
log-likelihood) without a trailing newline, but only when `caller.verbose`
is true.
"""
function call(caller::MiniBatchLog, epoch, log_likelihood)
    return caller.verbose && print("Mini-batch EM epoch $epoch; train LL $log_likelihood")
end
542+
"""
    call(caller::FullBatchLog, epoch, log_likelihood)

Print the full-batch EM progress prefix (epoch number and training
log-likelihood) without a trailing newline, but only when `caller.verbose`
is true.
"""
function call(caller::FullBatchLog, epoch, log_likelihood)
    return caller.verbose && print("Full-batch EM epoch $epoch; train LL $log_likelihood")
end
545+
"""
    call(caller::LikelihoodsLog, epoch, log_likelihood)

On epochs that are a multiple of `caller.iter`, print the validation and then
the test log-likelihood for each dataset that is not `nothing`. Nothing is
printed on other epochs. The `log_likelihood` argument is not used here.
"""
function call(caller::LikelihoodsLog, epoch, log_likelihood)
    epoch % caller.iter == 0 || return nothing
    datasets = (("; valid LL ", caller.valid_x), ("; test LL ", caller.test_x))
    for (label, x) in datasets
        isnothing(x) && continue
        print(label, loglikelihood(caller.bpc, x;
                                   batch_size=caller.batch_size,
                                   mars_mem=caller.mars_mem))
    end
    return nothing
end
557+
558+
# Fallback `cleanup` for any callback: does nothing and returns `nothing`.
# Subtypes that hold resources define a more specific method.
cleanup(caller::CALLBACK) = nothing
559+
"""
    cleanup(caller::LikelihoodsLog)

Release the buffer held in `caller.mars_mem`, if any, and reset the field to
`nothing`. Safe to call when `init` never ran or when `cleanup` has already
been called. Returns `nothing`.
"""
function cleanup(caller::LikelihoodsLog)
    # `mars_mem` is `nothing` until `init` runs; freeing `nothing` would throw.
    if !isnothing(caller.mars_mem)
        CUDA.unsafe_free!(caller.mars_mem)
        # Drop the reference so the freed buffer cannot be reused or freed twice.
        caller.mars_mem = nothing
    end
    return nothing
end

0 commit comments

Comments
 (0)