Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
/docs/Manifest.toml
/docs/build/
**/*~
Maniest.toml
test/Manifest.toml
Manifest.toml
test/Manifest.toml
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "SliceSampling"
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
version = "0.7.10"
version = "0.7.11"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Expand All @@ -16,12 +17,13 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
SliceSamplingTuringExt = ["Turing"]

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5.9"
Accessors = "0.1"
Distributions = "0.25"
LinearAlgebra = "1"
LogDensityProblems = "2"
Random = "1"
Turing = "0.41"
Turing = "0.42"
julia = "1.10"

[extras]
Expand Down
4 changes: 2 additions & 2 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AbstractMCMC = "5"
AbstractMCMC = "5.9"
Distributions = "0.25"
Documenter = "1"
FillArrays = "1"
Expand All @@ -28,5 +28,5 @@ Random = "1"
SliceSampling = "0.7.1"
StableRNGs = "1"
Statistics = "1"
Turing = "0.41, 0.42"
Turing = "0.42"
julia = "1.10"
29 changes: 0 additions & 29 deletions ext/SliceSamplingTuringExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,6 @@ using Random
using SliceSampling
using Turing

# Required for using the slice samplers as `externalsampler`s in Turing
# begin
function Turing.Inference.getparams(
::Turing.DynamicPPL.Model, sample::SliceSampling.Transition
)
return sample.params
end
# end
Comment on lines -9 to -16
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to the external sampler interface are described in https://github.com/TuringLang/Turing.jl/releases/tag/v0.42.0 -- the general aim is that you should not need to overload Turing internal functions (getparams was actually not exported afaik) and shifting this to AbstractMCMC means that it's easier for other packages to make use of this info.

Turing.Inference.getparams is gone now, it's replaced with AbstractMCMC.getparams (but called on the state instead of the transition).


# Required for using the slice samplers as `Gibbs` samplers in Turing
# begin
Turing.Inference.isgibbscomponent(::SliceSampling.RandPermGibbs) = true
Turing.Inference.isgibbscomponent(::SliceSampling.HitAndRun) = true
Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true
Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true
Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true
Comment on lines -20 to -24
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isgibbscomponent now defaults to true so it no longer needs to be overridden


const SliceSamplingStates = Union{
SliceSampling.UnivariateSliceState,
SliceSampling.GibbsState,
SliceSampling.HitAndRunState,
SliceSampling.LatentSliceState,
SliceSampling.GibbsPolarSliceState,
}
function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates)
return sample.transition.params
end
# end

function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
n_max_attempts = 1000

Expand Down
17 changes: 17 additions & 0 deletions src/SliceSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
module SliceSampling

using AbstractMCMC
using Accessors: Accessors
using Distributions
using LinearAlgebra
using LogDensityProblems
Expand Down Expand Up @@ -37,6 +38,22 @@ struct Transition{P,L<:Real,I<:NamedTuple}
info::I
end

"""
abstract type AbstractStateWithTransition

Base type for MCMC states that contain a `Transition` stored in the `transition` field.
"""
abstract type AbstractStateWithTransition end
AbstractMCMC.getparams(state::AbstractStateWithTransition) = state.transition.params
AbstractMCMC.getstats(state::AbstractStateWithTransition) = state.transition.info
function AbstractMCMC.setparams!!(
model::AbstractMCMC.LogDensityModel, state::AbstractStateWithTransition, params
)
new_lp = LogDensityProblems.logdensity(model.logdensity, params)
new_transition = Transition(params, new_lp, NamedTuple())
return Accessors.@set state.transition = new_transition
end
Comment on lines +46 to +55
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the definitions of these functions are the same for all states in this package, I thought it would be cleaner to just define the behaviour on an abstract type. It does necessitate an extra dep on Accessors, but that's fairly lightweight.


"""
initial_sample(rng, model)

Expand Down
3 changes: 2 additions & 1 deletion src/multivariate/gibbspolar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ function GibbsPolarSlice(w::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
return GibbsPolarSlice(w, max_proposals)
end

struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector}
struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector} <:
AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T

Expand Down
9 changes: 1 addition & 8 deletions src/multivariate/hitandrun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,11 @@ struct HitAndRun{S<:AbstractUnivariateSliceSampling} <: AbstractMultivariateSlic
unislice::S
end

struct HitAndRunState{T<:Transition}
struct HitAndRunState{T<:Transition} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T
end

function AbstractMCMC.setparams!!(
model::AbstractMCMC.LogDensityModel, state::HitAndRunState, params
)
lp = LogDensityProblems.logdensity(model.logdensity, params)
return HitAndRunState(Transition(params, lp, NamedTuple()))
end

struct HitAndRunTarget{Model,Vec<:AbstractVector}
model :: Model
direction :: Vec
Expand Down
2 changes: 1 addition & 1 deletion src/multivariate/latent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function LatentSlice(beta::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
return LatentSlice(beta, max_proposals)
end

struct LatentSliceState{T<:Transition,S<:AbstractVector}
struct LatentSliceState{T<:Transition,S<:AbstractVector} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T

Expand Down
7 changes: 1 addition & 6 deletions src/multivariate/randpermgibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,11 @@ struct RandPermGibbs{
unislice::S
end

struct GibbsState{T<:Transition}
struct GibbsState{T<:Transition} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T
end

function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, ::GibbsState, params)
lp = LogDensityProblems.logdensity(model.logdensity, params)
return GibbsState(Transition(params, lp, NamedTuple()))
end

struct GibbsTarget{Model,Idx<:Integer,Vec<:AbstractVector}
model :: Model
idx :: Idx
Expand Down
9 changes: 1 addition & 8 deletions src/univariate/univariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,11 @@ function slice_sampling_univariate(
return exceeded_max_prop(max_prop)
end

struct UnivariateSliceState{T<:Transition}
struct UnivariateSliceState{T<:Transition} <: AbstractStateWithTransition
"Current [`Transition`](@ref)."
transition::T
end

function AbstractMCMC.setparams!!(
model::AbstractMCMC.LogDensityModel, state::UnivariateSliceState, params
)
lp = LogDensityProblems.logdensity(model.logdensity, params)
return UnivariateSliceState(Transition(params, lp, NamedTuple()))
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5.9"
Accessors = "0.1"
Distributions = "0.25"
LogDensityProblems = "2"
MCMCTesting = "0.3"
Random = "1"
StableRNGs = "1"
Test = "1"
Turing = "0.41, 0.42"
Turing = "0.42"
julia = "1.10"
Loading