Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ authors = ["Tor Erlend Fjelde <tor.erlend95@gmail.com> and contributors"]
version = "0.3.1"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Expand Down
24 changes: 10 additions & 14 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@ using TuringCallbacks
using Documenter

# Build the documentation site for TuringCallbacks.
# (Reconstructed: the diff paste contained both old and new versions of these
# keyword lines; this is the post-change formatting.)
makedocs(;
    modules = [TuringCallbacks],
    authors = "Tor",
    repo = "https://github.com/TuringLang/TuringCallbacks.jl/blob/{commit}{path}#L{line}",
    sitename = "TuringCallbacks.jl",
    format = Documenter.HTML(;
        # Pretty URLs only on CI; local builds keep plain file links for browsing.
        prettyurls = get(ENV, "CI", "false") == "true",
        canonical = "https://turinglang.github.io/TuringCallbacks.jl",
        assets = String[],
    ),
    pages = ["Home" => "index.md"],
)

# Push the built docs to the gh-pages branch of the repository.
# (Deduplicated: the diff paste showed both the old multi-line and new
# single-line call; this keeps the new form.)
deploydocs(; repo = "github.com/TuringLang/TuringCallbacks.jl")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is all unrelated to the proposed changes, no?

8 changes: 6 additions & 2 deletions ext/TuringCallbacksTuringExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ end
const TuringTransition = Union{Turing.Inference.Transition,Turing.Inference.HMCTransition}

# Lazily iterate `(name::String, value)` pairs for the parameters of a Turing
# transition. Uses `_params_to_array` to flatten the transition into parallel
# name/value collections, then zips them back together.
# (Deduplicated: the diff paste contained both the old one-line and new
# multi-line `Iterators.map` call.)
function TuringCallbacks.params_and_values(transition::TuringTransition; kwargs...)
    return Iterators.map(
        zip(Turing.Inference._params_to_array([transition])...),
    ) do (ksym, val)
        return string(ksym), val
    end
end

# Lazily iterate `(name::String, value)` pairs for the sampler statistics
# (log density, step size, etc.) attached to a Turing transition.
# (Deduplicated: the diff paste contained both the old one-line and new
# multi-line `Iterators.map` call.)
function TuringCallbacks.extras(transition::TuringTransition; kwargs...)
    return Iterators.map(
        zip(Turing.Inference.get_transition_extras([transition])...),
    ) do (ksym, val)
        return string(ksym), val
    end
end
Expand Down
12 changes: 9 additions & 3 deletions src/TuringCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using Reexport
using LinearAlgebra
using Logging
using DocStringExtensions
using DynamicPPL: Model, Sampler, AbstractVarInfo, invlink!!
using CSV: write
using Random: AbstractRNG

@reexport using OnlineStats # used to compute different statistics on-the-fly

Expand All @@ -17,16 +20,19 @@ using DataStructures: DefaultDict
using Requires
end

export DefaultDict, WindowStat, Thin, Skip, TensorBoardCallback, MultiCallback, SaveCSV

include("stats.jl")
include("tensorboardlogger.jl")
include("callbacks/multicallback.jl")
include("callbacks/save.jl")
include("callbacks/tensorboard.jl")

# Pre-1.9 Julia has no package extensions; fall back to Requires.jl to load
# the Turing glue code lazily when Turing itself is loaded.
@static if !isdefined(Base, :get_extension)
    function __init__()
        @require Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" include(
            "../ext/TuringCallbacksTuringExt.jl",
        )
    end
end

Expand Down
3 changes: 2 additions & 1 deletion src/callbacks/multicallback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ MultiCallback(callbacks...) = MultiCallback(callbacks)
Add a callback to the list of callbacks, mutating if possible.
"""
# Tuple-backed callbacks are immutable, so a new MultiCallback is returned;
# array-backed callbacks are mutated in place and the same object returned.
# (Deduplicated: the diff paste contained the array method twice — old
# one-line and new two-line form.)
push!!(c::MultiCallback{<:Tuple}, callback) = MultiCallback((c.callbacks..., callback))
push!!(c::MultiCallback{<:AbstractArray}, callback) =
    (push!(c.callbacks, callback); return c)
37 changes: 37 additions & 0 deletions src/callbacks/save.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
###############################
### Saves samples on the go ###
###############################

"""
SaveCSV

A callback saves samples to .csv file during sampling
"""
function SaveCSV(
rng::AbstractRNG,
model::Model,
sampler::Sampler,
transition,
state,
iteration::Int64;
kwargs...,
)
SaveCSV(rng, model, sampler, transition, state.vi, iteration; kwargs...)
end

"""
    SaveCSV(rng, model, sampler, transition, vi::AbstractVarInfo, iteration; kwargs...)

Append the current parameter values in `vi` (in the constrained/original space)
to `<chain_name>.csv`, where `chain_name` is taken from `kwargs` (default
`"chain"`). Values are `;`-delimited and appended on each call.
"""
function SaveCSV(
    rng::AbstractRNG,
    model::Model,
    sampler::Sampler,
    transition,
    vi::AbstractVarInfo,
    iteration::Int64;
    kwargs...,
)
    # Work on a copy so the sampler's own varinfo is left untouched.
    vii = deepcopy(vi)
    # BUGFIX: `invlink!!` follows the BangBang `!!` convention — it may return
    # a *new* object (e.g. for immutable varinfo types) instead of mutating in
    # place, so the result must be re-bound rather than discarded.
    vii = invlink!!(vii, model)
    θ = vii[sampler]
    # it would be good to have the param names as in the chain
    chain_name = get(kwargs, :chain_name, "chain")
    return write(string(chain_name, ".csv"), Dict("params" => [θ]); append = true, delim = ";")
end
39 changes: 29 additions & 10 deletions src/callbacks/tensorboard.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function TensorBoardCallback(args...; comment = "", directory = nothing, kwargs.
end

# Set up the logger
lg = TBLogger(log_dir, min_level=Logging.Info; step_increment=0)
lg = TBLogger(log_dir, min_level = Logging.Info; step_increment = 0)

return TensorBoardCallback(lg, args...; kwargs...)
end
Expand All @@ -87,27 +87,35 @@ function TensorBoardCallback(
filter = nothing,
param_prefix::String = "",
extras_prefix::String = "extras/",
kwargs...
kwargs...,
)
# Lookups: create default ones if not given
stats_lookup = if stats isa OnlineStat
# Warn the user if they've provided a non-empty `OnlineStat`
OnlineStats.nobs(stats) > 0 && @warn("using statistic with observations as a base: $(stats)")
OnlineStats.nobs(stats) > 0 &&
@warn("using statistic with observations as a base: $(stats)")
let o = stats
DefaultDict{String, typeof(o)}(() -> deepcopy(o))
DefaultDict{String,typeof(o)}(() -> deepcopy(o))
end
elseif !isnothing(stats)
# If it's not an `OnlineStat` nor `nothing`, assume user knows what they're doing
stats
else
# This is default
let o = OnlineStats.Series(Mean(), Variance(), KHist(num_bins))
DefaultDict{String, typeof(o)}(() -> deepcopy(o))
DefaultDict{String,typeof(o)}(() -> deepcopy(o))
end
end

return TensorBoardCallback(
lg, stats_lookup, filter, include, exclude, include_extras, param_prefix, extras_prefix
lg,
stats_lookup,
filter,
include,
exclude,
include_extras,
param_prefix,
extras_prefix,
)
end

Expand All @@ -133,7 +141,8 @@ function filter_param_and_value(cb::TensorBoardCallback, param, value)
# Otherwise we return `true` by default.
return true
end
# Convenience overload: splat a `(param, value)` tuple into the two-argument
# filter method. (Deduplicated: the diff paste contained both the old one-line
# and new two-line definition of this method.)
filter_param_and_value(cb::TensorBoardCallback, param_and_value::Tuple) =
    filter_param_and_value(cb, param_and_value...)

"""
default_param_names_for_values(x)
Expand All @@ -150,7 +159,8 @@ default_param_names_for_values(x) = ("θ[$i]" for i = 1:length(x))
Return an iterator over parameter names and values from a `transition`.
"""
# Fallback forwarders: by default the parameter iteration depends only on the
# `transition`; extra arguments (`state`, `model`, `sampler`) are dropped so
# samplers that need them can overload the richer signatures.
# (Deduplicated: the four-argument method appeared twice in the diff paste.)
params_and_values(transition, state; kwargs...) = params_and_values(transition; kwargs...)
params_and_values(model, sampler, transition, state; kwargs...) =
    params_and_values(transition, state; kwargs...)

"""
extras(transition[, state]; kwargs...)
Expand All @@ -167,14 +177,23 @@ extras(model, sampler, transition, state; kwargs...) = extras(transition, state;
# Advance the TensorBoard global step counter by `Δ_Step`; thin wrapper around
# `TensorBoardLogger.increment_step!` so callbacks can dispatch on other logger types.
increment_step!(lg::TensorBoardLogger.TBLogger, Δ_Step) =
    TensorBoardLogger.increment_step!(lg, Δ_Step)

function (cb::TensorBoardCallback)(rng, model, sampler, transition, state, iteration; kwargs...)
function (cb::TensorBoardCallback)(
rng,
model,
sampler,
transition,
state,
iteration;
kwargs...,
)
stats = cb.stats
lg = cb.logger
filterf = Base.Fix1(filter_param_and_value, cb)

# TODO: Should we use the explicit interface for TensorBoardLogger?
with_logger(lg) do
for (k, val) in Iterators.filter(filterf, params_and_values(transition, state; kwargs...))
for (k, val) in
Iterators.filter(filterf, params_and_values(transition, state; kwargs...))
stat = stats[k]

# Log the raw value
Expand Down
29 changes: 12 additions & 17 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ $(TYPEDEF)

Skips the first `b` observations before passing them on to `stat`.
"""
mutable struct Skip{T, O<:OnlineStat{T}} <: OnlineStat{T}
mutable struct Skip{T,O<:OnlineStat{T}} <: OnlineStat{T}
b::Int
current_index::Int
stat::O
Expand All @@ -29,10 +29,8 @@ function OnlineStats._fit!(o::Skip, x::Real)
return o
end

# Compact display of a `Skip` wrapper: burn-in size, position, wrapped stat.
# (Deduplicated from old/new diff lines; also removed the stray unbalanced
# trailing backtick that ended the output string.)
Base.show(io::IO, o::Skip) =
    print(io, "Skip ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)")

"""
$(TYPEDEF)
Expand All @@ -43,7 +41,7 @@ $(TYPEDEF)

Thins `stat` with an interval `b`, i.e. only passes every b-th observation to `stat`.
"""
mutable struct Thin{T, O<:OnlineStat{T}} <: OnlineStat{T}
mutable struct Thin{T,O<:OnlineStat{T}} <: OnlineStat{T}
b::Int
current_index::Int
stat::O
Expand All @@ -62,10 +60,8 @@ function OnlineStats._fit!(o::Thin, x::Real)
return o
end

# Compact display of a `Thin` wrapper: thinning interval, position, wrapped stat.
# (Deduplicated from old/new diff lines; also removed the stray unbalanced
# trailing backtick that ended the output string.)
Base.show(io::IO, o::Thin) =
    print(io, "Thin ($(o.b)): current_index=$(o.current_index) | stat=$(o.stat)")

"""
$(TYPEDEF)
Expand All @@ -80,27 +76,26 @@ $(TYPEDEF)
`stat`, which is *only* fitted on the batched data contained in the `MovingWindow`.

"""
# A stat evaluated over a moving window: observations are stored in `window`,
# and `stat` is (re)fitted on the window contents only when `value` is called.
# (Deduplicated: the diff paste contained both old and new struct headers.)
struct WindowStat{T,O} <: OnlineStat{T}
    window::MovingWindow{T}
    stat::O
end

# Convenience constructors: either give the element type explicitly, or infer
# it from the wrapped `OnlineStat`'s type parameter.
# (Deduplicated: the diff paste contained both old and new definitions.)
WindowStat(b::Int, T::Type, o) = WindowStat{T,typeof(o)}(MovingWindow(b, T), o)
WindowStat(b::Int, o::OnlineStat{T}) where {T} =
    WindowStat{T,typeof(o)}(MovingWindow(b, T), o)

# Proxy methods to the window: observation count and fitting go straight to the
# `MovingWindow`; the wrapped `stat` is only touched in `value`.
OnlineStats.nobs(o::WindowStat) = OnlineStats.nobs(o.window)
OnlineStats._fit!(o::WindowStat, x) = OnlineStats._fit!(o.window, x)

# Materialize the stat over the current window: fit a *fresh copy* of `o.stat`
# on the windowed values so repeated `value` calls don't accumulate observations.
# (Deduplicated: the diff paste contained both old and new function headers.)
function OnlineStats.value(o::WindowStat{<:Any,<:OnlineStat})
    stat_new = deepcopy(o.stat)
    fit!(stat_new, OnlineStats.value(o.window))
    return stat_new
end

function OnlineStats.value(o::WindowStat{<:Any, <:Function})
function OnlineStats.value(o::WindowStat{<:Any,<:Function})
stat_new = o.stat()
fit!(stat_new, OnlineStats.value(o.window))
return stat_new
Expand Down
19 changes: 10 additions & 9 deletions src/tensorboardlogger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ end

# Log each autocorrelation lag of an `AutoCov` stat as a separate scalar.
# (Deduplicated: the diff paste contained both old and new spellings of the
# loop range and index expressions.)
function TBL.preprocess(name, stat::AutoCov, data)
    autocor = OnlineStats.autocor(stat)
    for b = 1:(stat.lag.b-1)
        # `autocor[i]` corresponds to the lag of size `i - 1` and `autocor[1] = 1.0`
        bname = tb_name(stat, b)
        TBL.preprocess(tb_name(name, bname), autocor[b+1], data)
    end
end

Expand All @@ -60,22 +60,23 @@ function TBL.preprocess(name, hist::KHist, data)
# Creates a NORMALIZED histogram
edges = OnlineStats.edges(hist)
cnts = OnlineStats.counts(hist)
TBL.preprocess(
name, (edges, cnts ./ sum(cnts)), data
)
TBL.preprocess(name, (edges, cnts ./ sum(cnts)), data)
end
end

# Unlike the `preprocess` overload, this allows us to specify if we want to normalize
# Unlike the `preprocess` overload, this allows us to specify if we want to normalize
function TBL.log_histogram(
    logger::AbstractLogger,
    name::AbstractString,
    hist::OnlineStats.HistogramStat;
    step = nothing,
    normalize = false,
)
    # BUGFIX: must be qualified as `OnlineStats.edges`. The original
    # `edges = edges(hist)` makes `edges` a local variable for the whole
    # function body, so the right-hand call references it before assignment
    # and throws an `UndefVarError`. (Matches the qualified usage in the
    # `preprocess(::KHist)` method.)
    edges = OnlineStats.edges(hist)
    cnts = Float64.(OnlineStats.counts(hist))
    if normalize
        return TBL.log_histogram(logger, name, (edges, cnts ./ sum(cnts)); step = step)
    else
        return TBL.log_histogram(logger, name, (edges, cnts); step = step)
    end
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
Expand Down
Loading