Skip to content

Commit bccfdf0

Browse files
committed
Merge branch 'breaking' into mhauru/vnt-for-fastldf
2 parents f114e40 + 1bb97ae commit bccfdf0

File tree

7 files changed

+37
-3
lines changed

7 files changed

+37
-3
lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
# DynamicPPL Changelog
22

3+
## 0.40
4+
5+
## 0.39.1
6+
7+
`LogDensityFunction` now allows you to call `logdensity_and_gradient(ldf, x)` with `AbstractVector`s `x` that are not plain Vectors (they will be converted internally before calculating the gradient).
8+
39
## 0.39.0
410

511
### Breaking changes

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.39.0"
3+
version = "0.40"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

benchmarks/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ DynamicPPL = {path = "../"}
2424
ADTypes = "1.14.0"
2525
Chairmarks = "1.3.1"
2626
Distributions = "0.25.117"
27-
DynamicPPL = "0.39"
27+
DynamicPPL = "0.40"
2828
Enzyme = "0.13"
2929
ForwardDiff = "1"
3030
JSON = "1.3.0"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Accessors = "0.1"
2121
Distributions = "0.25"
2222
Documenter = "1"
2323
DocumenterMermaid = "0.1, 0.2"
24-
DynamicPPL = "0.39"
24+
DynamicPPL = "0.40"
2525
FillArrays = "0.13, 1"
2626
ForwardDiff = "0.10, 1"
2727
JET = "0.9, 0.10, 0.11"

src/logdensityfunction.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ struct LogDensityFunction{
144144
F<:Function,
145145
VNT<:VarNamedTuple,
146146
ADP<:Union{Nothing,DI.GradientPrep},
147+
# type of the vector passed to logdensity functions
148+
X<:AbstractVector,
147149
}
148150
model::M
149151
adtype::AD
@@ -192,12 +194,17 @@ struct LogDensityFunction{
192194
typeof(getlogdensity),
193195
typeof(all_ranges),
194196
typeof(prep),
197+
typeof(x),
195198
}(
196199
model, adtype, getlogdensity, all_ranges, prep, dim
197200
)
198201
end
199202
end
200203

204+
function _get_input_vector_type(::LogDensityFunction{T,M,A,G,I,P,X}) where {T,M,A,G,I,P,X}
205+
return X
206+
end
207+
201208
###################################
202209
# LogDensityProblems.jl interface #
203210
###################################
@@ -245,6 +252,7 @@ end
245252
function LogDensityProblems.logdensity_and_gradient(
246253
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
247254
) where {Tlink}
255+
params = convert(_get_input_vector_type(ldf), params)
248256
return DI.value_and_gradient(
249257
LogDensityAt{Tlink}(ldf.model, ldf._getlogdensity, ldf._varname_ranges),
250258
ldf._adprep,

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Accessors = "0.1"
3939
Aqua = "0.8"
4040
BangBang = "0.4"
4141
Bijectors = "0.15.1"
42+
Chairmarks = "1"
4243
Combinatorics = "1"
4344
DifferentiationInterface = "0.6.41, 0.7"
4445
Distributions = "0.25"

test/logdensityfunction.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,25 @@ end
182182
end
183183
end
184184

185+
@testset "logdensity_and_gradient with views" begin
186+
# This test ensures that you can call `logdensity_and_gradient` with an array
187+
# type that isn't the same as the one used in the gradient preparation.
188+
@model function f()
189+
x ~ Normal()
190+
return y ~ Normal()
191+
end
192+
@testset "$adtype" for adtype in test_adtypes
193+
x = randn(2)
194+
ldf = LogDensityFunction(f(); adtype)
195+
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
196+
logp_view, grad_view = LogDensityProblems.logdensity_and_gradient(
197+
ldf, (@view x[:])
198+
)
199+
@test logp == logp_view
200+
@test grad == grad_view
201+
end
202+
end
203+
185204
# Test that various different ways of specifying array types as arguments work with all
186205
# ADTypes.
187206
@testset "Array argument types" begin

0 commit comments

Comments
 (0)