Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TuringCallbacks"
uuid = "ea0860ee-d0ef-45ef-82e6-cc37d6be2f9c"
authors = ["Tor Erlend Fjelde <tor.erlend95@gmail.com> and contributors"]
version = "0.3.1"
version = "0.4.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -26,8 +26,8 @@ DocStringExtensions = "0.8, 0.9"
OnlineStats = "1.5"
Reexport = "0.2, 1.0"
Requires = "1"
TensorBoardLogger = "0.1"
Turing = "0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22"
TensorBoardLogger = "0.1.22"
Turing = "0.29"
julia = "1"

[extras]
Expand Down
65 changes: 54 additions & 11 deletions ext/TuringCallbacksTuringExt.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,69 @@
module TuringCallbacksTuringExt

if isdefined(Base, :get_extension)
using Turing: Turing
using Turing: Turing, DynamicPPL
using TuringCallbacks: TuringCallbacks
else
# Requires compatible.
using ..Turing: Turing
using ..Turing: Turing, DynamicPPL
using ..TuringCallbacks: TuringCallbacks
end

const TuringTransition = Union{Turing.Inference.Transition,Turing.Inference.HMCTransition}
const TuringTransition = Union{
Turing.Inference.Transition,
Turing.Inference.SMCTransition,
Turing.Inference.PGTransition
}

function TuringCallbacks.params_and_values(transition::TuringTransition; kwargs...)
return Iterators.map(zip(Turing.Inference._params_to_array([transition])...)) do (ksym, val)
return string(ksym), val
end
function TuringCallbacks.params_and_values(
model::DynamicPPL.Model,
transition::TuringTransition;
kwargs...
)
vns, vals = Turing.Inference._params_to_array(model, [transition])
return zip(Iterators.map(string, vns), vals)
end

function TuringCallbacks.extras(transition::TuringTransition; kwargs...)
return Iterators.map(zip(Turing.Inference.get_transition_extras([transition])...)) do (ksym, val)
return string(ksym), val
end
function TuringCallbacks.extras(
model::DynamicPPL.Model, transition::TuringTransition;
kwargs...
)
names, vals = Turing.Inference.get_transition_extras([transition])
return zip(string.(names), vec(vals))
end

default_hyperparams(sampler::DynamicPPL.Sampler) = default_hyperparams(sampler.alg)
default_hyperparams(alg::Turing.Inference.InferenceAlgorithm) = (
string(f) => getfield(alg, f) for f in fieldnames(typeof(alg))
)

const AlgsWithDefaultHyperparams = Union{
Turing.Inference.HMC,
Turing.Inference.HMCDA,
Turing.Inference.NUTS,
Turing.Inference.SGHMC,

}

function TuringCallbacks.hyperparams(
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:AlgsWithDefaultHyperparams};
kwargs...
)
return default_hyperparams(sampler)
end

function TuringCallbacks.hyperparam_metrics(
model,
sampler::Turing.Sampler{<:Turing.Inference.NUTS}
)
return [
"extras/acceptance_rate/stat/Mean",
"extras/max_hamiltonian_energy_error/stat/Mean",
"extras/lp/stat/Mean",
"extras/n_steps/stat/Mean",
"extras/tree_depth/stat/Mean"
]
end

end
1 change: 1 addition & 0 deletions src/TuringCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ end

export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback

include("utils.jl")
include("stats.jl")
include("tensorboardlogger.jl")
include("callbacks/tensorboard.jl")
Expand Down
37 changes: 37 additions & 0 deletions src/callbacks/save.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
###############################
### Saves samples on the go ###
###############################

"""
SaveCSV

A callback saves samples to .csv file during sampling
"""
function SaveCSV(
rng::AbstractRNG,
model::Model,
sampler::Sampler,
transition,
state,
iteration::Int64;
kwargs...,
)
SaveCSV(rng, model, sampler, transition, state.vi, iteration; kwargs...)
end

function SaveCSV(
rng::AbstractRNG,
model::Model,
sampler::Sampler,
transition,
vi::AbstractVarInfo,
iteration::Int64;
kwargs...,
)
vii = deepcopy(vi)
invlink!!(vii, model)
θ = vii[sampler]
# it would be good to have the param names as in the chain
chain_name = get(kwargs, :chain_name, "chain")
write(string(chain_name, ".csv"), Dict("params" => [θ]); append = true, delim = ";")
end
Loading