diff --git a/Project.toml b/Project.toml
index 437f5216..09a91c2f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.6"
+version = "0.6.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/src/algorithms/subsampledobjective.jl b/src/algorithms/subsampledobjective.jl
index b71e7ab4..a12ba66e 100644
--- a/src/algorithms/subsampledobjective.jl
+++ b/src/algorithms/subsampledobjective.jl
@@ -25,12 +25,22 @@ function init(
     adtype::ADTypes.AbstractADType,
     q_init,
     prob,
-    params,
-    restructure,
+    ::Any,
+    ::Any,
 )
     (; objective, subsampling) = subobj
     sub_st = init(rng, subsampling)
-    obj_st = AdvancedVI.init(rng, objective, adtype, q_init, prob, params, restructure)
+
+    # This is necessary to ensure that `init` sees the type "conditioned" on a minibatch
+    # when calling `DifferentiationInterface.prepare_*` inside it.
+    batch, _, _ = step(rng, subsampling, sub_st, true)
+    prob_sub = subsample(prob, batch)
+    q_init_sub = subsample(q_init, batch)
+    params_sub, re_sub = Optimisers.destructure(q_init_sub)
+
+    obj_st = AdvancedVI.init(
+        rng, objective, adtype, q_init_sub, prob_sub, params_sub, re_sub
+    )
     return SubsampledObjectiveState(prob, sub_st, obj_st)
 end
 
@@ -66,7 +76,7 @@ function estimate_gradient!(
     (; prob, sub_st, obj_st) = state
     q = restructure(params)
 
-    batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
+    batch, sub_st′, sub_inf = step(rng, subsampling, sub_st, true)
     prob_sub = subsample(prob, batch)
     q_sub = subsample(q, batch)
     params_sub, re_sub = Optimisers.destructure(q_sub)
diff --git a/src/reshuffling.jl b/src/reshuffling.jl
index a1ffdece..e0e50cfb 100644
--- a/src/reshuffling.jl
+++ b/src/reshuffling.jl
@@ -20,7 +20,9 @@ struct ReshufflingBatchSubsamplingState{It}
     iterator::It
 end
 
-Base.length(sub::ReshufflingBatchSubsampling) = ceil(Int, length(sub.dataset)/sub.batchsize)
+function Base.length(sub::ReshufflingBatchSubsampling)
+    return ceil(Int, length(sub.dataset) / sub.batchsize)
+end
 
 function reshuffle_batches(rng::Random.AbstractRNG, sub::ReshufflingBatchSubsampling)
     (; dataset, batchsize) = sub
@@ -37,15 +39,22 @@ function step(
     rng::Random.AbstractRNG,
     sub::ReshufflingBatchSubsampling,
     state::ReshufflingBatchSubsamplingState,
+    drop_trailing_batch_if_too_small::Bool=false,
 )
     (; epoch, iterator) = state
-    (sub_step, batch), batch_it′ = Iterators.peel(iterator)
-    epoch′, iterator′′ = if isempty(batch_it′)
-        epoch + 1, reshuffle_batches(rng, sub)
-    else
-        epoch, batch_it′
+    (sub_step, batch), iterator = Iterators.peel(iterator)
+    if isempty(iterator)
+        iterator = reshuffle_batches(rng, sub)
+        if drop_trailing_batch_if_too_small && length(batch) < sub.batchsize
+            # Ignore the trailing batch if its size is smaller than `batchsize`.
+            # This should only be used when estimating gradients during optimization.
+            # This is necessary to ensure that all batches have the same size.
+            # Otherwise, `DifferentiationInterface.prepare_*` behaves incorrectly.
+            (sub_step, batch), iterator = Iterators.peel(iterator)
+        end
+        epoch = epoch + 1
     end
     info = (epoch=epoch, step=sub_step)
-    state′ = ReshufflingBatchSubsamplingState(epoch′, iterator′′)
-    return batch, state′, info
+    state = ReshufflingBatchSubsamplingState(epoch, iterator)
+    return batch, state, info
 end
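
Why the fixed batch size matters: `DifferentiationInterface.prepare_*` specializes its preparation object on the exact size and type of the arguments it is given, so a preparation built against one minibatch is only valid for minibatches of the same size. A minimal sketch of this contract, using `ForwardDiff` as a stand-in backend (none of the names below are part of this PR):

```julia
using ADTypes: AutoForwardDiff
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)
backend = AutoForwardDiff()

# The preparation is tied to the size of the sample input it was built with.
prep = prepare_gradient(f, backend, randn(4))
gradient(f, prep, backend, randn(4))   # fine: same size as at preparation time

# Reusing `prep` on a differently sized input (e.g., a smaller trailing
# minibatch) violates the preparation contract and may error or return
# incorrect results:
# gradient(f, prep, backend, randn(2))
```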
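
A usage sketch of the new `drop_trailing_batch_if_too_small` argument, illustrating the epoch-boundary behavior. The constructor argument order below is an assumption (only the `dataset` and `batchsize` fields appear in this diff), and `init`/`step` are the internal methods shown above:

```julia
using Random

rng = Random.default_rng()
# Hypothetical constructor order; each epoch yields batches of sizes 4, 4, 2.
sub = ReshufflingBatchSubsampling(1:10, 4)
st = init(rng, sub)

# With the flag set to `true` (as `estimate_gradient!` and the objective's
# `init` now do), the 2-element trailing batch is skipped at each epoch
# boundary, so every returned batch has exactly `batchsize` elements:
for _ in 1:6
    batch, st, info = step(rng, sub, st, true)
    @assert length(batch) == 4
end
```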