-
Notifications
You must be signed in to change notification settings - Fork 3
Update for newer AbstractMCMC/Turing interface #49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,5 +4,5 @@ | |
| /docs/Manifest.toml | ||
| /docs/build/ | ||
| **/*~ | ||
| Maniest.toml | ||
| test/Manifest.toml | ||
| Manifest.toml | ||
| test/Manifest.toml | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,35 +6,6 @@ using Random | |
| using SliceSampling | ||
| using Turing | ||
|
|
||
| # Required for using the slice samplers as `externalsampler`s in Turing | ||
| # begin | ||
| function Turing.Inference.getparams( | ||
| ::Turing.DynamicPPL.Model, sample::SliceSampling.Transition | ||
| ) | ||
| return sample.params | ||
| end | ||
| # end | ||
|
|
||
| # Required for using the slice samplers as `Gibbs` samplers in Turing | ||
| # begin | ||
| Turing.Inference.isgibbscomponent(::SliceSampling.RandPermGibbs) = true | ||
| Turing.Inference.isgibbscomponent(::SliceSampling.HitAndRun) = true | ||
| Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true | ||
| Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true | ||
| Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true | ||
|
Comment on lines
-20
to
-24
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| const SliceSamplingStates = Union{ | ||
| SliceSampling.UnivariateSliceState, | ||
| SliceSampling.GibbsState, | ||
| SliceSampling.HitAndRunState, | ||
| SliceSampling.LatentSliceState, | ||
| SliceSampling.GibbsPolarSliceState, | ||
| } | ||
| function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates) | ||
| return sample.transition.params | ||
| end | ||
| # end | ||
|
|
||
| function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction) | ||
| n_max_attempts = 1000 | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| module SliceSampling | ||
|
|
||
| using AbstractMCMC | ||
| using Accessors: Accessors | ||
| using Distributions | ||
| using LinearAlgebra | ||
| using LogDensityProblems | ||
|
|
@@ -37,6 +38,22 @@ struct Transition{P,L<:Real,I<:NamedTuple} | |
| info::I | ||
| end | ||
|
|
||
| """ | ||
| abstract type AbstractStateWithTransition | ||
|
|
||
| Base type for MCMC states that contain a `Transition` stored in the `transition` field. | ||
| """ | ||
| abstract type AbstractStateWithTransition end | ||
| AbstractMCMC.getparams(state::AbstractStateWithTransition) = state.transition.params | ||
| AbstractMCMC.getstats(state::AbstractStateWithTransition) = state.transition.info | ||
| function AbstractMCMC.setparams!!( | ||
| model::AbstractMCMC.LogDensityModel, state::AbstractStateWithTransition, params | ||
| ) | ||
| new_lp = LogDensityProblems.logdensity(model.logdensity, params) | ||
| new_transition = Transition(params, new_lp, NamedTuple()) | ||
| return Accessors.@set state.transition = new_transition | ||
| end | ||
|
Comment on lines
+46
to
+55
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the definitions of these functions are the same for all states in this package, I thought it would be cleaner to just define the behaviour on an abstract type. It does necessitate an extra dep on Accessors, but that's fairly lightweight. |
||
|
|
||
| """ | ||
| initial_sample(rng, model) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes to the external sampler interface are described in https://github.com/TuringLang/Turing.jl/releases/tag/v0.42.0 -- the general aim is that you should not need to overload Turing internal functions (getparams was actually not exported afaik) and shifting this to AbstractMCMC means that it's easier for other packages to make use of this info.
Turing.Inference.getparamsis gone now, it's replaced withAbstractMCMC.getparams(but called on the state instead of the transition).