1- export estimate_parameters, uniform_parameters, estimate_parameters_em
1+ export estimate_parameters, uniform_parameters, estimate_parameters_em, estimate_parameters_cached!
22
33using StatsFuns: logsumexp, logaddexp
44using CUDA
@@ -81,9 +81,18 @@ function estimate_parameters_cached!(pc::SharedProbCircuit, bc, params, componen
8181 end
8282 nothing
8383end
84+ function estimate_parameters_cached! (pc:: ProbCircuit , pbc; exp_update_factor = 0.0 )
85+ if isgpu (pbc)
86+ pbc = to_cpu (pbc)
87+ end
88+ bc = pbc. bitcircuit
89+ params = pbc. params
90+ estimate_parameters_cached! (pc, bc, params; exp_update_factor)
91+ end
8492function estimate_parameters_cached! (pc:: ProbCircuit , bc, params; exp_update_factor = 0.0 )
8593 log_exp_factor = log (exp_update_factor)
8694 log_1_exp_factor = log (1.0 - exp_update_factor)
95+
8796 foreach_reset (pc) do pn
8897 if is⋁gate (pn)
8998 if num_children (pn) == 1
@@ -98,6 +107,7 @@ function estimate_parameters_cached!(pc::ProbCircuit, bc, params; exp_update_fac
98107 end
99108 end
100109 end
110+
101111 nothing
102112end
103113
415425Expectation maximization parameter learning given missing data
416426"""
417427function estimate_parameters_em (pc:: ProbCircuit , data; pseudocount:: Float64 ,
418- use_sample_weights:: Bool = true , use_gpu:: Bool = false , reuse:: Bool = false ,
419- reuse_v = nothing , reuse_f = nothing , reuse_counts = nothing ,
420- exp_update_factor = 0.0 )
428+ use_sample_weights:: Bool = true , use_gpu:: Bool = false ,
429+ exp_update_factor = 0.0 , update_per_batch:: Bool = false )
430+ if update_per_batch && isbatched (data)
431+ estimate_parameters_em_per_batch (pc, data; pseudocount, use_sample_weights, use_gpu,
432+ exp_update_factor)
433+ end
434+
421435 if isweighted (data)
422436 # `data' is weighted according to its `weight' column
423437 data, weights = split_sample_weights (data)
@@ -432,37 +446,105 @@ function estimate_parameters_em(pc::ProbCircuit, data; pseudocount::Float64,
432446 data = to_gpu (data)
433447 end
434448
435- if reuse_counts === nothing
436- reuse_counts = use_gpu ? (nothing , nothing ) : (nothing , nothing , nothing , nothing , nothing )
437- end
438-
439449 params = if use_gpu
440450 if ! isgpu (data)
441451 data = to_gpu (data)
442452 end
443453 if use_sample_weights
444- estimate_parameters_gpu (to_gpu (pbc), data, pseudocount; weights, reuse, reuse_v, reuse_f, reuse_counts )
454+ estimate_parameters_gpu (to_gpu (pbc), data, pseudocount; weights)
445455 else
446- estimate_parameters_gpu (to_gpu (pbc), data, pseudocount; reuse, reuse_v, reuse_f, reuse_counts )
456+ estimate_parameters_gpu (to_gpu (pbc), data, pseudocount)
447457 end
448458 else
449459 if use_sample_weights
450- estimate_parameters_cpu (pbc, data, pseudocount; weights, reuse, reuse_v, reuse_f, reuse_counts )
460+ estimate_parameters_cpu (pbc, data, pseudocount; weights)
451461 else
452- estimate_parameters_cpu (pbc, data, pseudocount; reuse, reuse_v, reuse_f, reuse_counts )
462+ estimate_parameters_cpu (pbc, data, pseudocount)
453463 end
454464 end
455- if reuse
456- params, v, f, reuse_counts = params
457- end
458465
459466 estimate_parameters_cached! (pc, pbc. bitcircuit, params; exp_update_factor)
460467
461- if reuse
462- params, v, f, reuse_counts
468+ params
469+ end
470+ function estimate_parameters_em_per_batch (pc:: ProbCircuit , data; pseudocount:: Float64 ,
471+ use_sample_weights:: Bool = true , use_gpu:: Bool = false ,
472+ exp_update_factor = 0.0 )
473+ if isgpu (data)
474+ use_gpu = true
475+ elseif use_gpu && ! isgpu (data)
476+ data = to_gpu (data)
477+ end
478+
479+ pbc = ParamBitCircuit (pc, data; reset= false )
480+ if use_gpu
481+ pbc = to_gpu (pbc)
482+ end
483+
484+ reuse_v, reuse_f = nothing , nothing
485+ reuse_counts = use_gpu ? (nothing , nothing ) : (nothing , nothing , nothing , nothing , nothing )
486+
487+ for idx = 1 : length (data)
488+ pbc, reuse_v, reuse_f, reuse_counts = estimate_parameters_em (pbc, data[idx]; pseudocount, use_gpu, reuse_v, reuse_f, reuse_counts, exp_update_factor)
489+ end
490+
491+ estimate_parameters_cached! (pc, pbc)
492+ end
493+ function estimate_parameters_em (pbc:: ParamBitCircuit , data; pseudocount:: Float64 ,
494+ use_sample_weights:: Bool = true , use_gpu:: Bool = false ,
495+ reuse_v = nothing , reuse_f = nothing , reuse_counts = nothing ,
496+ exp_update_factor = 0.0 )
497+ if isweighted (data)
498+ # `data' is weighted according to its `weight' column
499+ data, weights = split_sample_weights (data)
463500 else
464- params
501+ use_sample_weights = false
465502 end
503+
504+ if isgpu (data)
505+ use_gpu = true
506+ elseif use_gpu && ! isgpu (data)
507+ data = to_gpu (data)
508+ end
509+
510+ if reuse_counts === nothing
511+ reuse_counts = use_gpu ? (nothing , nothing ) : (nothing , nothing , nothing , nothing , nothing )
512+ end
513+
514+ params, v, f, reuse_counts = if use_gpu
515+ if ! isgpu (pbc)
516+ pbc. bitcircuit = to_gpu (pbc. bitcircuit)
517+ pbc. params = to_gpu (pbc. params)
518+ end
519+ if use_sample_weights
520+ estimate_parameters_gpu (pbc, data, pseudocount; weights, reuse = true , reuse_v, reuse_f, reuse_counts)
521+ else
522+ estimate_parameters_gpu (pbc, data, pseudocount; reuse = true , reuse_v, reuse_f, reuse_counts)
523+ end
524+ else
525+ if use_sample_weights
526+ estimate_parameters_cpu (pbc, data, pseudocount; weights, reuse = true , reuse_v, reuse_f, reuse_counts)
527+ else
528+ estimate_parameters_cpu (pbc, data, pseudocount; reuse = true , reuse_v, reuse_f, reuse_counts)
529+ end
530+ end
531+
532+ # Update the parameters to `pbc`
533+ if use_gpu # GPU
534+ tempparam = Vector {Float64} (undef, length (params))
535+ tempparam .= to_cpu (params)
536+ @inbounds @views pbc. params .+ = log (exp_update_factor)
537+ @inbounds @views params .+ = log (1.0 - exp_update_factor)
538+ delta = @inbounds @views @. CUDA. ifelse (pbc. params == params, CUDA. zero (params), CUDA. abs (pbc. params - params))
539+ @inbounds @views @. pbc. params = CUDA. max (pbc. params, params) + CUDA. log1p (CUDA. exp (- delta))
540+
541+ CUDA. unsafe_free! (params)
542+ CUDA. unsafe_free! (delta)
543+ else # CPU
544+ @inbounds @views pbc. params = logaddexp .(pbc. params .+ log (exp_update_factor), params .+ log (1.0 - exp_update_factor))
545+ end
546+
547+ pbc, v, f, reuse_counts
466548end
467549
468550function estimate_parameters_cpu (pbc:: ParamBitCircuit , data, pseudocount; weights = nothing , reuse:: Bool = false ,
@@ -602,15 +684,21 @@ function estimate_parameters_cpu(pbc::ParamBitCircuit, data, pseudocount; weight
602684 v, f = marginal_flows (pbc, data, reuse_v, reuse_f; on_node, on_edge, weights)
603685 end
604686
687+ # `edge_counts` now becomes "params"
605688 @simd for i = 1 : num_elements (bc)
606- @inbounds params[i] = log ((edge_counts[i] + pseudocount / num_elements (bc. nodes, bc. elements[1 , i])) / (parent_node_counts[i] + pseudocount))
689+ num_els = num_elements (bc. nodes, bc. elements[1 , i])
690+ if num_els == 1
691+ @inbounds edge_counts[i] = zero (eltype (edge_counts)) # log(1)
692+ else
693+ @inbounds edge_counts[i] = log ((edge_counts[i] + pseudocount / num_elements (bc. nodes, bc. elements[1 , i])) / (parent_node_counts[i] + pseudocount))
694+ end
607695 end
608696
609697 if reuse
610698 # Also return the allocated vars v, f, counts for future reuse
611- params , v, f, (node_counts, edge_counts, parent_node_counts, buffer, log_weights)
699+ edge_counts , v, f, (node_counts, edge_counts, parent_node_counts, buffer, log_weights)
612700 else
613- params # a.k.a. log_probs
701+ edge_counts # a.k.a. log_probs
614702 end
615703end
616704
@@ -687,6 +775,7 @@ function estimate_parameters_gpu(pbc::ParamBitCircuit, data, pseudocount; weight
687775 if weights != nothing
688776 weights = to_gpu (weights)
689777 end
778+
690779 v, f = marginal_flows (pbc, data, reuse_v, reuse_f; on_node, on_edge, weights)
691780 end
692781
@@ -696,6 +785,11 @@ function estimate_parameters_gpu(pbc::ParamBitCircuit, data, pseudocount; weight
696785 @inbounds parent_elcount = bc. nodes[2 ,parents] .- bc. nodes[1 ,parents] .+ 1
697786 params = log .((edge_counts .+ (pseudocount ./ parent_elcount))
698787 ./ (parent_counts .+ pseudocount))
788+ params = @inbounds @views @. ifelse (parent_elcount == 1 , zero (params), params)
789+
790+ CUDA. unsafe_free! (parents)
791+ CUDA. unsafe_free! (parent_counts)
792+ CUDA. unsafe_free! (parent_elcount)
699793
700794 # Only free the memory if the reuse memory is not provided
701795 if ! reuse
@@ -707,7 +801,7 @@ function estimate_parameters_gpu(pbc::ParamBitCircuit, data, pseudocount; weight
707801
708802 if reuse
709803 # Also return the allocated vars v, f, counts for future reuse
710- to_cpu ( params) , v, f, (node_counts, edge_counts)
804+ params, v, f, (node_counts, edge_counts)
711805 else
712806 to_cpu (params) # a.k.a. log_probs
713807 end
0 commit comments