Skip to content

Commit 17934e1

Browse files
committed
Update for newer AbstractMCMC/Turing interface
1 parent 2dd6716 commit 17934e1

File tree

10 files changed

+31
-61
lines changed

10 files changed

+31
-61
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "SliceSampling"
22
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
3-
version = "0.7.10"
3+
version = "0.7.11"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -16,12 +16,12 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1616
SliceSamplingTuringExt = ["Turing"]
1717

1818
[compat]
19-
AbstractMCMC = "4, 5"
19+
AbstractMCMC = "5.9"
2020
Distributions = "0.25"
2121
LinearAlgebra = "1"
2222
LogDensityProblems = "2"
2323
Random = "1"
24-
Turing = "0.41"
24+
Turing = "0.42"
2525
julia = "1.10"
2626

2727
[extras]

docs/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1616

1717
[compat]
18-
AbstractMCMC = "5"
18+
AbstractMCMC = "5.9"
1919
Distributions = "0.25"
2020
Documenter = "1"
2121
FillArrays = "1"
@@ -28,5 +28,5 @@ Random = "1"
2828
SliceSampling = "0.7.1"
2929
StableRNGs = "1"
3030
Statistics = "1"
31-
Turing = "0.41, 0.42"
31+
Turing = "0.42"
3232
julia = "1.10"

ext/SliceSamplingTuringExt.jl

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,6 @@ using Random
66
using SliceSampling
77
using Turing
88

9-
# Required for using the slice samplers as `externalsampler`s in Turing
10-
# begin
11-
function Turing.Inference.getparams(
12-
::Turing.DynamicPPL.Model, sample::SliceSampling.Transition
13-
)
14-
return sample.params
15-
end
16-
# end
17-
18-
# Required for using the slice samplers as `Gibbs` samplers in Turing
19-
# begin
20-
Turing.Inference.isgibbscomponent(::SliceSampling.RandPermGibbs) = true
21-
Turing.Inference.isgibbscomponent(::SliceSampling.HitAndRun) = true
22-
Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true
23-
Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true
24-
Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true
25-
26-
const SliceSamplingStates = Union{
27-
SliceSampling.UnivariateSliceState,
28-
SliceSampling.GibbsState,
29-
SliceSampling.HitAndRunState,
30-
SliceSampling.LatentSliceState,
31-
SliceSampling.GibbsPolarSliceState,
32-
}
33-
function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates)
34-
return sample.transition.params
35-
end
36-
# end
37-
389
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
3910
n_max_attempts = 1000
4011

@@ -50,7 +21,7 @@ function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDe
5021
@warn "Failed to find valid initial parameters after $(init_attempt_count) attempts; consider providing explicit initial parameters using the `initial_params` keyword"
5122
end
5223

53-
# NOTE: This will sample in the unconstrained space if ℓ.varinfo is linked
24+
# NOTE: This will sample in the unconstrained space if ℓ.varinfo is linked
5425
vi_spl = last(Turing.DynamicPPL.init!!(rng, model, vi, Turing.InitFromUniform()))
5526
ℓp =.getlogdensity(vi_spl)
5627
θ = vi_spl[:]

src/SliceSampling.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
module SliceSampling
33

44
using AbstractMCMC
5+
using Accessors: Accessors
56
using Distributions
67
using LinearAlgebra
78
using LogDensityProblems
@@ -37,6 +38,23 @@ struct Transition{P,L<:Real,I<:NamedTuple}
3738
info::I
3839
end
3940

41+
"""
42+
abstract type AbstractStateWithTransition
43+
44+
Base type for MCMC states that contain a `Transition` stored in the `transition` field.
45+
"""
46+
abstract type AbstractStateWithTransition end
47+
get_transition(state::AbstractStateWithTransition) = state.transition
48+
AbstractMCMC.getparams(state::AbstractStateWithTransition) = get_transition(state).params
49+
AbstractMCMC.getstats(state::AbstractStateWithTransition) = get_transition(state).info
50+
function AbstractMCMC.setparams(
51+
model::AbstractMCMC.LogDensityModel, state::AbstractStateWithTransition, params
52+
)
53+
new_lp = LogDensityProblems.logdensity(model.logdensity, params)
54+
new_transition = Transition(params, new_lp, NamedTuple())
55+
return Accessors.@set state.transition = new_transition
56+
end
57+
4058
"""
4159
initial_sample(rng, model)
4260

src/multivariate/gibbspolar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function GibbsPolarSlice(w::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
2929
return GibbsPolarSlice(w, max_proposals)
3030
end
3131

32-
struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector}
32+
struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector} <: AbstractStateWithTransition
3333
"Current [`Transition`](@ref)."
3434
transition::T
3535

src/multivariate/hitandrun.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,11 @@ struct HitAndRun{S<:AbstractUnivariateSliceSampling} <: AbstractMultivariateSlic
1212
unislice::S
1313
end
1414

15-
struct HitAndRunState{T<:Transition}
15+
struct HitAndRunState{T<:Transition} <: AbstractStateWithTransition
1616
"Current [`Transition`](@ref)."
1717
transition::T
1818
end
1919

20-
function AbstractMCMC.setparams!!(
21-
model::AbstractMCMC.LogDensityModel, state::HitAndRunState, params
22-
)
23-
lp = LogDensityProblems.logdensity(model.logdensity, params)
24-
return HitAndRunState(Transition(params, lp, NamedTuple()))
25-
end
26-
2720
struct HitAndRunTarget{Model,Vec<:AbstractVector}
2821
model :: Model
2922
direction :: Vec

src/multivariate/latent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ function LatentSlice(beta::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
2020
return LatentSlice(beta, max_proposals)
2121
end
2222

23-
struct LatentSliceState{T<:Transition,S<:AbstractVector}
23+
struct LatentSliceState{T<:Transition,S<:AbstractVector} <: AbstractStateWithTransition
2424
"Current [`Transition`](@ref)."
2525
transition::T
2626

src/multivariate/randpermgibbs.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,11 @@ struct RandPermGibbs{
2020
unislice::S
2121
end
2222

23-
struct GibbsState{T<:Transition}
23+
struct GibbsState{T<:Transition} <: AbstractStateWithTransition
2424
"Current [`Transition`](@ref)."
2525
transition::T
2626
end
2727

28-
function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, ::GibbsState, params)
29-
lp = LogDensityProblems.logdensity(model.logdensity, params)
30-
return GibbsState(Transition(params, lp, NamedTuple()))
31-
end
32-
3328
struct GibbsTarget{Model,Idx<:Integer,Vec<:AbstractVector}
3429
model :: Model
3530
idx :: Idx

src/univariate/univariate.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,11 @@ function slice_sampling_univariate(
2424
return exceeded_max_prop(max_prop)
2525
end
2626

27-
struct UnivariateSliceState{T<:Transition}
27+
struct UnivariateSliceState{T<:Transition} <: AbstractStateWithTransition
2828
"Current [`Transition`](@ref)."
2929
transition::T
3030
end
3131

32-
function AbstractMCMC.setparams!!(
33-
model::AbstractMCMC.LogDensityModel, state::UnivariateSliceState, params
34-
)
35-
lp = LogDensityProblems.logdensity(model.logdensity, params)
36-
return UnivariateSliceState(Transition(params, lp, NamedTuple()))
37-
end
38-
3932
function AbstractMCMC.step(
4033
rng::Random.AbstractRNG,
4134
model::AbstractMCMC.LogDensityModel,

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1010
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1111

1212
[compat]
13-
AbstractMCMC = "4, 5"
13+
AbstractMCMC = "5.9"
1414
Accessors = "0.1"
1515
Distributions = "0.25"
1616
LogDensityProblems = "2"
1717
MCMCTesting = "0.3"
1818
Random = "1"
1919
StableRNGs = "1"
2020
Test = "1"
21-
Turing = "0.41, 0.42"
21+
Turing = "0.42"
2222
julia = "1.10"

0 commit comments

Comments
 (0)