From 18fb8ecd307c0b46ec490cb839f9d318f6ec51ec Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 14 Dec 2025 16:16:02 +0000 Subject: [PATCH] Update for newer AbstractMCMC/Turing interface --- .gitignore | 4 ++-- Project.toml | 8 +++++--- docs/Project.toml | 4 ++-- ext/SliceSamplingTuringExt.jl | 29 ----------------------------- src/SliceSampling.jl | 17 +++++++++++++++++ src/multivariate/gibbspolar.jl | 3 ++- src/multivariate/hitandrun.jl | 9 +-------- src/multivariate/latent.jl | 2 +- src/multivariate/randpermgibbs.jl | 7 +------ src/univariate/univariate.jl | 9 +-------- test/Project.toml | 4 ++-- 11 files changed, 34 insertions(+), 62 deletions(-) diff --git a/.gitignore b/.gitignore index bd9827c..71cac9f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,5 @@ /docs/Manifest.toml /docs/build/ **/*~ -Maniest.toml -test/Manifest.toml \ No newline at end of file +Manifest.toml +test/Manifest.toml diff --git a/Project.toml b/Project.toml index f22eea5..3f63267 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "SliceSampling" uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf" -version = "0.7.10" +version = "0.7.11" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @@ -16,12 +17,13 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" SliceSamplingTuringExt = ["Turing"] [compat] -AbstractMCMC = "4, 5" +AbstractMCMC = "5.9" +Accessors = "0.1" Distributions = "0.25" LinearAlgebra = "1" LogDensityProblems = "2" Random = "1" -Turing = "0.41" +Turing = "0.42" julia = "1.10" [extras] diff --git a/docs/Project.toml b/docs/Project.toml index 0969cfe..904d95d 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,7 +15,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -AbstractMCMC = "5" +AbstractMCMC = "5.9" Distributions = "0.25" Documenter = "1" FillArrays = "1" @@ -28,5 +28,5 @@ Random = "1" SliceSampling = "0.7.1" StableRNGs = "1" Statistics = "1" -Turing = "0.41, 0.42" +Turing = "0.42" julia = "1.10" diff --git a/ext/SliceSamplingTuringExt.jl b/ext/SliceSamplingTuringExt.jl index 3eba4bd..c181d3d 100644 --- a/ext/SliceSamplingTuringExt.jl +++ b/ext/SliceSamplingTuringExt.jl @@ -6,35 +6,6 @@ using Random using SliceSampling using Turing -# Required for using the slice samplers as `externalsampler`s in Turing -# begin -function Turing.Inference.getparams( - ::Turing.DynamicPPL.Model, sample::SliceSampling.Transition -) - return sample.params -end -# end - -# Required for using the slice samplers as `Gibbs` samplers in Turing -# begin -Turing.Inference.isgibbscomponent(::SliceSampling.RandPermGibbs) = true -Turing.Inference.isgibbscomponent(::SliceSampling.HitAndRun) = true -Turing.Inference.isgibbscomponent(::SliceSampling.Slice) = true -Turing.Inference.isgibbscomponent(::SliceSampling.SliceSteppingOut) = true -Turing.Inference.isgibbscomponent(::SliceSampling.SliceDoublingOut) = true - -const SliceSamplingStates = Union{ - SliceSampling.UnivariateSliceState, - SliceSampling.GibbsState, - SliceSampling.HitAndRunState, - SliceSampling.LatentSliceState, - SliceSampling.GibbsPolarSliceState, -} -function Turing.Inference.getparams(::Turing.DynamicPPL.Model, sample::SliceSamplingStates) - return sample.transition.params -end -# end - function SliceSampling.initial_sample(rng::Random.AbstractRNG, ℓ::Turing.LogDensityFunction) n_max_attempts = 1000 diff --git a/src/SliceSampling.jl b/src/SliceSampling.jl index c43875c..92568cf 100644 --- a/src/SliceSampling.jl +++ b/src/SliceSampling.jl @@ -2,6 +2,7 @@ module SliceSampling using AbstractMCMC +using Accessors: Accessors using Distributions using LinearAlgebra using LogDensityProblems @@ -37,6 +38,22 @@ struct Transition{P,L<:Real,I<:NamedTuple} info::I end +""" + abstract type AbstractStateWithTransition + +Base type for MCMC states that contain a `Transition` stored in the `transition` field. +""" +abstract type AbstractStateWithTransition end +AbstractMCMC.getparams(state::AbstractStateWithTransition) = state.transition.params +AbstractMCMC.getstats(state::AbstractStateWithTransition) = state.transition.info +function AbstractMCMC.setparams!!( + model::AbstractMCMC.LogDensityModel, state::AbstractStateWithTransition, params +) + new_lp = LogDensityProblems.logdensity(model.logdensity, params) + new_transition = Transition(params, new_lp, NamedTuple()) + return Accessors.@set state.transition = new_transition +end + """ initial_sample(rng, model) diff --git a/src/multivariate/gibbspolar.jl b/src/multivariate/gibbspolar.jl index 7872ee4..714be85 100644 --- a/src/multivariate/gibbspolar.jl +++ b/src/multivariate/gibbspolar.jl @@ -29,7 +29,8 @@ function GibbsPolarSlice(w::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS) return GibbsPolarSlice(w, max_proposals) end -struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector} +struct GibbsPolarSliceState{T<:Transition,R<:Real,D<:AbstractVector} <: + AbstractStateWithTransition "Current [`Transition`](@ref)." transition::T diff --git a/src/multivariate/hitandrun.jl b/src/multivariate/hitandrun.jl index 8fb945c..7b38895 100644 --- a/src/multivariate/hitandrun.jl +++ b/src/multivariate/hitandrun.jl @@ -12,18 +12,11 @@ struct HitAndRun{S<:AbstractUnivariateSliceSampling} <: AbstractMultivariateSlic unislice::S end -struct HitAndRunState{T<:Transition} +struct HitAndRunState{T<:Transition} <: AbstractStateWithTransition "Current [`Transition`](@ref)." transition::T end -function AbstractMCMC.setparams!!( - model::AbstractMCMC.LogDensityModel, state::HitAndRunState, params -) - lp = LogDensityProblems.logdensity(model.logdensity, params) - return HitAndRunState(Transition(params, lp, NamedTuple())) -end - struct HitAndRunTarget{Model,Vec<:AbstractVector} model :: Model direction :: Vec diff --git a/src/multivariate/latent.jl b/src/multivariate/latent.jl index 9f0e288..012278b 100644 --- a/src/multivariate/latent.jl +++ b/src/multivariate/latent.jl @@ -20,7 +20,7 @@ function LatentSlice(beta::Real; max_proposals::Int=DEFAULT_MAX_PROPOSALS) return LatentSlice(beta, max_proposals) end -struct LatentSliceState{T<:Transition,S<:AbstractVector} +struct LatentSliceState{T<:Transition,S<:AbstractVector} <: AbstractStateWithTransition "Current [`Transition`](@ref)." transition::T diff --git a/src/multivariate/randpermgibbs.jl b/src/multivariate/randpermgibbs.jl index 05b4f77..9926d4c 100644 --- a/src/multivariate/randpermgibbs.jl +++ b/src/multivariate/randpermgibbs.jl @@ -20,16 +20,11 @@ struct RandPermGibbs{ unislice::S end -struct GibbsState{T<:Transition} +struct GibbsState{T<:Transition} <: AbstractStateWithTransition "Current [`Transition`](@ref)." transition::T end -function AbstractMCMC.setparams!!(model::AbstractMCMC.LogDensityModel, ::GibbsState, params) - lp = LogDensityProblems.logdensity(model.logdensity, params) - return GibbsState(Transition(params, lp, NamedTuple())) -end - struct GibbsTarget{Model,Idx<:Integer,Vec<:AbstractVector} model :: Model idx :: Idx diff --git a/src/univariate/univariate.jl b/src/univariate/univariate.jl index 13f136b..adb70db 100644 --- a/src/univariate/univariate.jl +++ b/src/univariate/univariate.jl @@ -24,18 +24,11 @@ function slice_sampling_univariate( return exceeded_max_prop(max_prop) end -struct UnivariateSliceState{T<:Transition} +struct UnivariateSliceState{T<:Transition} <: AbstractStateWithTransition "Current [`Transition`](@ref)." transition::T end -function AbstractMCMC.setparams!!( - model::AbstractMCMC.LogDensityModel, state::UnivariateSliceState, params -) - lp = LogDensityProblems.logdensity(model.logdensity, params) - return UnivariateSliceState(Transition(params, lp, NamedTuple())) -end - function AbstractMCMC.step( rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, diff --git a/test/Project.toml b/test/Project.toml index 52478e8..38c23bc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -AbstractMCMC = "4, 5" +AbstractMCMC = "5.9" Accessors = "0.1" Distributions = "0.25" LogDensityProblems = "2" @@ -18,5 +18,5 @@ MCMCTesting = "0.3" Random = "1" StableRNGs = "1" Test = "1" -Turing = "0.41, 0.42" +Turing = "0.42" julia = "1.10"