Skip to content

Commit 24097c7

Browse files
committed
em minibatch update & low GPU util bug fix
1 parent bb012c1 commit 24097c7

File tree

2 files changed

+117
-23
lines changed

2 files changed

+117
-23
lines changed

src/param_bit_circuit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
export ParamBitCircuit
22

33
"A `BitCircuit` with parameters attached to the elements"
4-
struct ParamBitCircuit{V,M,W}
4+
mutable struct ParamBitCircuit{V,M,W}
55
bitcircuit::BitCircuit{V,M}
66
params::W
77
end

src/parameters.jl

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export estimate_parameters, uniform_parameters, estimate_parameters_em
1+
export estimate_parameters, uniform_parameters, estimate_parameters_em, estimate_parameters_cached!
22

33
using StatsFuns: logsumexp, logaddexp
44
using CUDA
@@ -81,9 +81,18 @@ function estimate_parameters_cached!(pc::SharedProbCircuit, bc, params, componen
8181
end
8282
nothing
8383
end
84+
"""
    estimate_parameters_cached!(pc::ProbCircuit, pbc; exp_update_factor = 0.0)

Convenience method that takes a parameterised bit circuit `pbc`, moves it to the
CPU when `isgpu(pbc)` is true, and forwards its `bitcircuit` and `params` fields
to the three-argument `estimate_parameters_cached!` method.
"""
function estimate_parameters_cached!(pc::ProbCircuit, pbc; exp_update_factor = 0.0)
    # The three-argument method is handed host-side data, so bring `pbc` back first.
    host_pbc = isgpu(pbc) ? to_cpu(pbc) : pbc
    return estimate_parameters_cached!(pc, host_pbc.bitcircuit, host_pbc.params;
                                       exp_update_factor)
end
8492
function estimate_parameters_cached!(pc::ProbCircuit, bc, params; exp_update_factor = 0.0)
8593
log_exp_factor = log(exp_update_factor)
8694
log_1_exp_factor = log(1.0 - exp_update_factor)
95+
8796
foreach_reset(pc) do pn
8897
if is⋁gate(pn)
8998
if num_children(pn) == 1
@@ -98,6 +107,7 @@ function estimate_parameters_cached!(pc::ProbCircuit, bc, params; exp_update_fac
98107
end
99108
end
100109
end
110+
101111
nothing
102112
end
103113

@@ -415,9 +425,13 @@ end
415425
Expectation maximization parameter learning given missing data
416426
"""
417427
function estimate_parameters_em(pc::ProbCircuit, data; pseudocount::Float64,
418-
use_sample_weights::Bool = true, use_gpu::Bool = false, reuse::Bool = false,
419-
reuse_v = nothing, reuse_f = nothing, reuse_counts = nothing,
420-
exp_update_factor = 0.0)
428+
use_sample_weights::Bool = true, use_gpu::Bool = false,
429+
exp_update_factor = 0.0, update_per_batch::Bool = false)
430+
if update_per_batch && isbatched(data)
431+
estimate_parameters_em_per_batch(pc, data; pseudocount, use_sample_weights, use_gpu,
432+
exp_update_factor)
433+
end
434+
421435
if isweighted(data)
422436
# `data' is weighted according to its `weight' column
423437
data, weights = split_sample_weights(data)
@@ -432,37 +446,105 @@ function estimate_parameters_em(pc::ProbCircuit, data; pseudocount::Float64,
432446
data = to_gpu(data)
433447
end
434448

435-
if reuse_counts === nothing
436-
reuse_counts = use_gpu ? (nothing, nothing) : (nothing, nothing, nothing, nothing, nothing)
437-
end
438-
439449
params = if use_gpu
440450
if !isgpu(data)
441451
data = to_gpu(data)
442452
end
443453
if use_sample_weights
444-
estimate_parameters_gpu(to_gpu(pbc), data, pseudocount; weights, reuse, reuse_v, reuse_f, reuse_counts)
454+
estimate_parameters_gpu(to_gpu(pbc), data, pseudocount; weights)
445455
else
446-
estimate_parameters_gpu(to_gpu(pbc), data, pseudocount; reuse, reuse_v, reuse_f, reuse_counts)
456+
estimate_parameters_gpu(to_gpu(pbc), data, pseudocount)
447457
end
448458
else
449459
if use_sample_weights
450-
estimate_parameters_cpu(pbc, data, pseudocount; weights, reuse, reuse_v, reuse_f, reuse_counts)
460+
estimate_parameters_cpu(pbc, data, pseudocount; weights)
451461
else
452-
estimate_parameters_cpu(pbc, data, pseudocount; reuse, reuse_v, reuse_f, reuse_counts)
462+
estimate_parameters_cpu(pbc, data, pseudocount)
453463
end
454464
end
455-
if reuse
456-
params, v, f, reuse_counts = params
457-
end
458465

459466
estimate_parameters_cached!(pc, pbc.bitcircuit, params; exp_update_factor)
460467

461-
if reuse
462-
params, v, f, reuse_counts
468+
params
469+
end
470+
"""
    estimate_parameters_em_per_batch(pc::ProbCircuit, data; pseudocount,
                                     use_sample_weights = true, use_gpu = false,
                                     exp_update_factor = 0.0)

Run EM over batched `data`, updating the parameters after every batch.

A `ParamBitCircuit` is built once from `pc` and then passed through
`estimate_parameters_em(pbc, batch; ...)` for each batch, threading the returned
`reuse_v`, `reuse_f` and `reuse_counts` values into the next call. After the last
batch the parameters held by the `ParamBitCircuit` are written back into `pc` via
`estimate_parameters_cached!(pc, pbc)`.

# Keywords
- `pseudocount`: forwarded to every per-batch estimation call.
- `use_sample_weights`: forwarded to every per-batch estimation call.
- `use_gpu`: run on the GPU; forced to `true` when `isgpu(data)` is true.
- `exp_update_factor`: forwarded to every per-batch estimation call.
"""
function estimate_parameters_em_per_batch(pc::ProbCircuit, data; pseudocount::Float64,
        use_sample_weights::Bool = true, use_gpu::Bool = false,
        exp_update_factor = 0.0)
    if isgpu(data)
        use_gpu = true
    elseif use_gpu
        data = to_gpu(data)
    end

    pbc = ParamBitCircuit(pc, data; reset=false)
    if use_gpu
        pbc = to_gpu(pbc)
    end

    # Placeholders for the reusable values; the tuple arity differs between the
    # GPU (2 entries) and CPU (5 entries) estimation paths.
    reuse_v, reuse_f = nothing, nothing
    reuse_counts = use_gpu ? (nothing, nothing) : (nothing, nothing, nothing, nothing, nothing)

    for idx = 1 : length(data)
        # `use_sample_weights` is forwarded explicitly; previously it was dropped,
        # so callers could not disable sample weighting on this path.
        pbc, reuse_v, reuse_f, reuse_counts = estimate_parameters_em(
            pbc, data[idx]; pseudocount, use_sample_weights, use_gpu,
            reuse_v, reuse_f, reuse_counts, exp_update_factor)
    end

    # The write-back uses the default `exp_update_factor = 0.0`.
    # NOTE(review): presumably because the blending with the old parameters was
    # already applied per batch above — confirm.
    return estimate_parameters_cached!(pc, pbc)
end
493+
"""
    estimate_parameters_em(pbc::ParamBitCircuit, data; pseudocount,
                           use_sample_weights = true, use_gpu = false,
                           reuse_v = nothing, reuse_f = nothing,
                           reuse_counts = nothing, exp_update_factor = 0.0)

Estimate parameters from `data` and blend them into `pbc.params` in log space:

    pbc.params = logaddexp(pbc.params + log(exp_update_factor),
                           new_params + log(1 - exp_update_factor))

The estimation itself is delegated to `estimate_parameters_gpu` or
`estimate_parameters_cpu`, always with `reuse = true`.

# Keywords
- `pseudocount`: forwarded to the estimation routine.
- `use_sample_weights`: pass the weights split off from `data`; reset to `false`
  when `data` is not weighted.
- `use_gpu`: run on the GPU; forced to `true` when `isgpu(data)` is true.
- `reuse_v`, `reuse_f`, `reuse_counts`: values from a previous call, forwarded to
  the estimation routine.
- `exp_update_factor`: weight given to the existing parameters in the blend.

# Returns
The tuple `(pbc, v, f, reuse_counts)`, so the last three can be fed into the next call.
"""
function estimate_parameters_em(pbc::ParamBitCircuit, data; pseudocount::Float64,
        use_sample_weights::Bool = true, use_gpu::Bool = false,
        reuse_v = nothing, reuse_f = nothing, reuse_counts = nothing,
        exp_update_factor = 0.0)
    if isweighted(data)
        # `data' is weighted according to its `weight' column
        data, weights = split_sample_weights(data)
    else
        use_sample_weights = false
    end

    if isgpu(data)
        use_gpu = true
    elseif use_gpu
        data = to_gpu(data)
    end

    if reuse_counts === nothing
        reuse_counts = use_gpu ? (nothing, nothing) : (nothing, nothing, nothing, nothing, nothing)
    end

    params, v, f, reuse_counts = if use_gpu
        if !isgpu(pbc)
            # NOTE(review): assigning device-side values into the existing fields
            # assumes the struct's type parameters admit them — confirm.
            pbc.bitcircuit = to_gpu(pbc.bitcircuit)
            pbc.params = to_gpu(pbc.params)
        end
        if use_sample_weights
            estimate_parameters_gpu(pbc, data, pseudocount; weights, reuse = true, reuse_v, reuse_f, reuse_counts)
        else
            estimate_parameters_gpu(pbc, data, pseudocount; reuse = true, reuse_v, reuse_f, reuse_counts)
        end
    else
        if use_sample_weights
            estimate_parameters_cpu(pbc, data, pseudocount; weights, reuse = true, reuse_v, reuse_f, reuse_counts)
        else
            estimate_parameters_cpu(pbc, data, pseudocount; reuse = true, reuse_v, reuse_f, reuse_counts)
        end
    end

    # Log-space mixing weights for the old and the freshly estimated parameters.
    log_old_factor = log(exp_update_factor)
    log_new_factor = log(1.0 - exp_update_factor)

    # Update the parameters to `pbc`
    if use_gpu # GPU
        # The unused `tempparam` host copy of `params` was removed here: it was
        # never read and only added a `to_cpu` call per invocation.
        @inbounds @views pbc.params .+= log_old_factor
        @inbounds @views params .+= log_new_factor
        # Hand-rolled logaddexp. Equal operands get delta = 0 explicitly so that
        # (-Inf) - (-Inf) never produces a NaN.
        delta = @inbounds @views @. CUDA.ifelse(pbc.params == params, CUDA.zero(params), CUDA.abs(pbc.params - params))
        @inbounds @views @. pbc.params = CUDA.max(pbc.params, params) + CUDA.log1p(CUDA.exp(-delta))

        CUDA.unsafe_free!(params)
        CUDA.unsafe_free!(delta)
    else # CPU
        @inbounds @views pbc.params = logaddexp.(pbc.params .+ log_old_factor, params .+ log_new_factor)
    end

    return pbc, v, f, reuse_counts
end
467549

468550
function estimate_parameters_cpu(pbc::ParamBitCircuit, data, pseudocount; weights = nothing, reuse::Bool = false,
@@ -602,15 +684,21 @@ function estimate_parameters_cpu(pbc::ParamBitCircuit, data, pseudocount; weight
602684
v, f = marginal_flows(pbc, data, reuse_v, reuse_f; on_node, on_edge, weights)
603685
end
604686

687+
# `edge_counts` now becomes "params"
605688
@simd for i = 1 : num_elements(bc)
606-
@inbounds params[i] = log((edge_counts[i] + pseudocount / num_elements(bc.nodes, bc.elements[1, i])) / (parent_node_counts[i] + pseudocount))
689+
num_els = num_elements(bc.nodes, bc.elements[1, i])
690+
if num_els == 1
691+
@inbounds edge_counts[i] = zero(eltype(edge_counts)) # log(1)
692+
else
693+
@inbounds edge_counts[i] = log((edge_counts[i] + pseudocount / num_elements(bc.nodes, bc.elements[1, i])) / (parent_node_counts[i] + pseudocount))
694+
end
607695
end
608696

609697
if reuse
610698
# Also return the allocated vars v, f, counts for future reuse
611-
params, v, f, (node_counts, edge_counts, parent_node_counts, buffer, log_weights)
699+
edge_counts, v, f, (node_counts, edge_counts, parent_node_counts, buffer, log_weights)
612700
else
613-
params # a.k.a. log_probs
701+
edge_counts # a.k.a. log_probs
614702
end
615703
end
616704

@@ -687,6 +775,7 @@ function estimate_parameters_gpu(pbc::ParamBitCircuit, data, pseudocount; weight
687775
if weights != nothing
688776
weights = to_gpu(weights)
689777
end
778+
690779
v, f = marginal_flows(pbc, data, reuse_v, reuse_f; on_node, on_edge, weights)
691780
end
692781

@@ -696,6 +785,11 @@ function estimate_parameters_gpu(pbc::ParamBitCircuit, data, pseudocount; weight
696785
@inbounds parent_elcount = bc.nodes[2,parents] .- bc.nodes[1,parents] .+ 1
697786
params = log.((edge_counts .+ (pseudocount ./ parent_elcount))
698787
./ (parent_counts .+ pseudocount))
788+
params = @inbounds @views @. ifelse(parent_elcount == 1, zero(params), params)
789+
790+
CUDA.unsafe_free!(parents)
791+
CUDA.unsafe_free!(parent_counts)
792+
CUDA.unsafe_free!(parent_elcount)
699793

700794
# Only free the memory if the reuse memory is not provided
701795
if !reuse
@@ -707,7 +801,7 @@ function estimate_parameters_gpu(pbc::ParamBitCircuit, data, pseudocount; weight
707801

708802
if reuse
709803
# Also return the allocated vars v, f, counts for future reuse
710-
to_cpu(params), v, f, (node_counts, edge_counts)
804+
params, v, f, (node_counts, edge_counts)
711805
else
712806
to_cpu(params) # a.k.a. log_probs
713807
end

0 commit comments

Comments
 (0)