diff --git a/Project.toml b/Project.toml index 7af36c4e58..0db429253d 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc" AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 73f190dcc6..1969f7fc95 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -23,6 +23,7 @@ using AbstractMCMC: AbstractModel, AbstractSampler using DocStringExtensions: TYPEDEF, TYPEDFIELDS using DataStructures: OrderedSet using Setfield: Setfield +using CSV import AbstractMCMC import AdvancedHMC; const AHMC = AdvancedHMC @@ -249,7 +250,7 @@ function AbstractMCMC.sample( chain_type=MCMCChains.Chains, progress=PROGRESS[], kwargs... -) +) return AbstractMCMC.mcmcsample(rng, model, sampler, ensemble, N, n_chains; chain_type=chain_type, progress=progress, kwargs...) end @@ -488,6 +489,23 @@ end # Utilities # ############## +function SaveCSV(rng::AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler, + transition::HMCTransition, + state::HMCState, + iteration::Int64; + kwargs... +) + vii = deepcopy(state.vi) + DynamicPPL.invlink!!(vii, model) + θ = vii[sampler] + # it would be good to have the param names as in the chain + chain_name = get(kwargs, :chain_name, "chain") + CSV.write(string(chain_name,".csv"), Dict("params"=>[θ]); + append=true, delim=";") +end + DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg) DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg)) diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 5ba3831481..39b131a042 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -100,7 +100,7 @@ function AbstractMCMC.sample( discard_adapt=true, discard_initial=-1, kwargs... -) +) if resume_from === nothing # If `nadapts` is `-1`, then the user called a convenience # constructor like `NUTS()` or `NUTS(0.65)`,