Skip to content

Commit fe10122

Browse files
authored
Merge pull request #124 from Juice-jl/em-log
Em log and data softening
2 parents e063c82 + 0c9fe8a commit fe10122

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

src/nodes/indicator_dist.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ bits(d::Indicator, _ = nothing) = d
2323
unbits(d::Indicator, _ = nothing) = d
2424

2525
loglikelihood(d::Indicator, value, _ = nothing) =
26-
(d.value == value) ? zero(Float32) : -Inf32
26+
if value isa AbstractFloat && d isa Literal
27+
(d.value) ? log(value) : log1p(-value)
28+
else
29+
(d.value == value) ? zero(Float32) : -Inf32
30+
end
2731

2832
init_params(d::Indicator, _) = d
2933

src/parameters/em.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,10 @@ function full_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
357357
marginals, flows, node_aggr, edge_aggr,
358358
mine, maxe, debug)
359359
push!(log_likelihoods, log_likelihood)
360-
call(callbacks, epoch, log_likelihood)
360+
done = call(callbacks, epoch, log_likelihood)
361+
if !isnothing(done) && done[end] == true
362+
break
363+
end
361364
end
362365

363366
cleanup_memory((data, raw_data), (flows, flows_mem),
@@ -467,10 +470,13 @@ function mini_batch_em(bpc::CuBitsProbCircuit, raw_data::CuArray, num_epochs;
467470
end
468471
log_likelihood = sum(log_likelihoods_epoch) / batch_size / num_batches
469472
push!(log_likelihoods, log_likelihood)
470-
call(callbacks, epoch, log_likelihood)
473+
done = call(callbacks, epoch, log_likelihood)
471474

472475
param_inertia += Δparam_inertia
473476
flow_memory += Δflow_memory
477+
if !isnothing(done) && done[end] == true
478+
break
479+
end
474480
end
475481

476482
cleanup_memory((data, raw_data), (flows, flows_mem),
@@ -490,10 +496,11 @@ end
490496

491497
function call(callbacks::CALLBACKList, epoch, log_likelihood)
492498
if callbacks.list[1].verbose
493-
for x in callbacks.list
499+
done = map(callbacks.list) do x
494500
call(x, epoch, log_likelihood)
495501
end
496502
println()
503+
done
497504
end
498505
end
499506

@@ -543,16 +550,20 @@ call(caller::FullBatchLog, epoch, log_likelihood) = begin
543550
caller.verbose && print("Full-batch EM epoch $epoch; train LL $log_likelihood")
544551
end
545552
call(caller::LikelihoodsLog, epoch, log_likelihood) = begin
553+
valid_ll, test_ll = nothing, nothing
546554
if epoch % caller.iter == 0 && (!isnothing(caller.valid_x) || !isnothing(caller.test_x))
547555
if !isnothing(caller.valid_x)
548-
print("; valid LL ", loglikelihood(caller.bpc, caller.valid_x;
549-
batch_size=caller.batch_size,mars_mem=caller.mars_mem))
556+
valid_ll = loglikelihood(caller.bpc, caller.valid_x;
557+
batch_size=caller.batch_size,mars_mem=caller.mars_mem)
558+
print("; valid LL ", valid_ll)
550559
end
551560
if !isnothing(caller.test_x)
552-
print("; test LL ", loglikelihood(caller.bpc, caller.test_x;
553-
batch_size=caller.batch_size,mars_mem=caller.mars_mem))
561+
test_ll = loglikelihood(caller.bpc, caller.test_x;
562+
batch_size=caller.batch_size,mars_mem=caller.mars_mem)
563+
print("; test LL ", test_ll)
554564
end
555565
end
566+
valid_ll, test_ll
556567
end
557568

558569
cleanup(caller::CALLBACK) = nothing

0 commit comments

Comments
 (0)