Skip to content

Commit 18fb8ec

Browse files
committed
Update for newer AbstractMCMC/Turing interface
1 parent 2dd6716 commit 18fb8ec

File tree

11 files changed

+34
-62
lines changed

11 files changed

+34
-62
lines changed

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@
44
/docs/Manifest.toml
55
/docs/build/
66
**/*~
7-
Maniest.toml
8-
test/Manifest.toml
7+
Manifest.toml
8+
test/Manifest.toml

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
name = "SliceSampling"
22
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
3-
version = "0.7.10"
3+
version = "0.7.11"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
7+
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -16,12 +17,13 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1617
SliceSamplingTuringExt = ["Turing"]
1718

1819
[compat]
19-
AbstractMCMC = "4, 5"
20+
AbstractMCMC = "5.9"
21+
Accessors = "0.1"
2022
Distributions = "0.25"
2123
LinearAlgebra = "1"
2224
LogDensityProblems = "2"
2325
Random = "1"
24-
Turing = "0.41"
26+
Turing = "0.42"
2527
julia = "1.10"
2628

2729
[extras]

docs/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1515
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
1616

1717
[compat]
18-
AbstractMCMC = "5"
18+
AbstractMCMC = "5.9"
1919
Distributions = "0.25"
2020
Documenter = "1"
2121
FillArrays = "1"
@@ -28,5 +28,5 @@ Random = "1"
2828
SliceSampling = "0.7.1"
2929
StableRNGs = "1"
3030
Statistics = "1"
31-
Turing = "0.41, 0.42"
31+
Turing = "0.42"
3232
julia = "1.10"

ext/SliceSamplingTuringExt.jl

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,6 @@ using Random
66
using SliceSampling
77
using Turing
88

9-
# Required for using the slice samplers as `externalsampler`s in Turing
10-
# begin
11-
function Turing.Inference.getparams(
12-
::Turing.DynamicPPL.Model, sample::SliceSampling.Transition
13-
)
14-
return sample.params
15-
end
16-
# end
17-
18-
# Required for using the slice samplers as `Gibbs` samplers in Turing
19-
# begin
20-
Turing.Inference.isgibbscomponent(::SliceSampling.RandPermGibbs) = true
21-
Turing.Inference.isgibbscomponent(::SliceSampling.HitAndRun) = true
22-
Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true
23-
Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true
24-
Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true
25-
26-
const SliceSamplingStates = Union{
27-
SliceSampling.UnivariateSliceState,
28-
SliceSampling.GibbsState,
29-
SliceSampling.HitAndRunState,
30-
SliceSampling.LatentSliceState,
31-
SliceSampling.GibbsPolarSliceState,
32-
}
33-
function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates)
34-
return sample.transition.params
35-
end
36-
# end
37-
389
function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction)
3910
n_max_attempts = 1000
4011

src/SliceSampling.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
module SliceSampling
33

44
using AbstractMCMC
5+
using Accessors: Accessors
56
using Distributions
67
using LinearAlgebra
78
using LogDensityProblems
@@ -37,6 +38,22 @@ struct Transition{P,L<:Real,I<:NamedTuple}
3738
info::I
3839
end
3940

41+
"""
42+
abstract type AbstractStateWithTransition
43+
44+
Base type for MCMC states that contain a `Transition` stored in the `transition` field.
45+
"""
46+
abstract type AbstractStateWithTransition end
47+
AbstractMCMC.getparams(state::AbstractStateWithTransition) = state.transition.params
48+
AbstractMCMC.getstats(state::AbstractStateWithTransition) = state.transition.info
49+
function AbstractMCMC.setparams!!(
50+
model::AbstractMCMC.LogDensityModel, state::AbstractStateWithTransition, params
51+
)
52+
new_lp = LogDensityProblems.logdensity(model.logdensity, params)
53+
new_transition = Transition(params, new_lp, NamedTuple())
54+
return Accessors.@set state.transition = new_transition
55+
end
56+
4057
"""
4158
initial_sample(rng, model)
4259

src/multivariate/gibbspolar.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ function GibbsPolarSlice(w::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
2929
return GibbsPolarSlice(w, max_proposals)
3030
end
3131

32-
struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector}
32+
struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector} <:
33+
AbstractStateWithTransition
3334
"Current [`Transition`](@ref)."
3435
transition::T
3536

src/multivariate/hitandrun.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,11 @@ struct HitAndRun{S<:AbstractUnivariateSliceSampling} <: AbstractMultivariateSlic
1212
unislice::S
1313
end
1414

15-
struct HitAndRunState{T<:Transition}
15+
struct HitAndRunState{T<:Transition} <: AbstractStateWithTransition
1616
"Current [`Transition`](@ref)."
1717
transition::T
1818
end
1919

20-
function AbstractMCMC.setparams!!(
21-
model::AbstractMCMC.LogDensityModel, state::HitAndRunState, params
22-
)
23-
lp = LogDensityProblems.logdensity(model.logdensity, params)
24-
return HitAndRunState(Transition(params, lp, NamedTuple()))
25-
end
26-
2720
struct HitAndRunTarget{Model,Vec<:AbstractVector}
2821
model :: Model
2922
direction :: Vec

src/multivariate/latent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ function LatentSlice(beta::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS)
2020
return LatentSlice(beta, max_proposals)
2121
end
2222

23-
struct LatentSliceState{T<:Transition,S<:AbstractVector}
23+
struct LatentSliceState{T<:Transition,S<:AbstractVector} <: AbstractStateWithTransition
2424
"Current [`Transition`](@ref)."
2525
transition::T
2626

src/multivariate/randpermgibbs.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,11 @@ struct RandPermGibbs{
2020
unislice::S
2121
end
2222

23-
struct GibbsState{T<:Transition}
23+
struct GibbsState{T<:Transition} <: AbstractStateWithTransition
2424
"Current [`Transition`](@ref)."
2525
transition::T
2626
end
2727

28-
function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, ::GibbsState, params)
29-
lp = LogDensityProblems.logdensity(model.logdensity, params)
30-
return GibbsState(Transition(params, lp, NamedTuple()))
31-
end
32-
3328
struct GibbsTarget{Model,Idx<:Integer,Vec<:AbstractVector}
3429
model :: Model
3530
idx :: Idx

src/univariate/univariate.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,11 @@ function slice_sampling_univariate(
2424
return exceeded_max_prop(max_prop)
2525
end
2626

27-
struct UnivariateSliceState{T<:Transition}
27+
struct UnivariateSliceState{T<:Transition} <: AbstractStateWithTransition
2828
"Current [`Transition`](@ref)."
2929
transition::T
3030
end
3131

32-
function AbstractMCMC.setparams!!(
33-
model::AbstractMCMC.LogDensityModel, state::UnivariateSliceState, params
34-
)
35-
lp = LogDensityProblems.logdensity(model.logdensity, params)
36-
return UnivariateSliceState(Transition(params, lp, NamedTuple()))
37-
end
38-
3932
function AbstractMCMC.step(
4033
rng::Random.AbstractRNG,
4134
model::AbstractMCMC.LogDensityModel,

0 commit comments

Comments
 (0)