Skip to content

Conversation

@penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Nov 8, 2025

This does lead to some improvement in performance, but not as much as I had hoped:

using Turing, Random, LinearAlgebra
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    for i in eachindex(y)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
model = eight_schools(y, sigma)

chn = sample(Xoshiro(468), model, NUTS(), 10000; progress=false)

@time returned(model, chn);
# main   :  0.760821 seconds (3.99 M allocations: 177.151 MiB, 2.56% gc time)
# This PR:  0.690771 seconds (3.15 M allocations: 90.624 MiB, 1.75% gc time)

It appears to me that most of this time is spent faffing with MCMCChains. Every time you try to get the value of @varname(mu) you have to go through the varname_to_symbol dict, etc. Even more importantly, there's an issue with theta, because that's vector-valued and when you access it you have to reconstruct the vector (with getvalue(dict, vn, dist). So I believe we are hitting a natural plateau that is caused by the data structure. Probably the performance gains with FlexiChains would be larger.

Still, I suppose it's worth putting this in because it's basically free performance, so why not? It's also cleaner code.

Also, I haven't tested this, but I'm about 95% sure that this makes predict, returned, etc. completely thread-safe if TSVI is used. That's because the combination of TSVI{<:OAVI} is completely threadsafe to everything.


A complete aside

Wouldn't it be fun if we could inspect the model, realise that the return value only involves mu and tau, realise that tau is a bits type and thus tau^2 * I cannot mutate tau, and thus optimise away all the things to do with theta and y?

@github-actions
Copy link
Contributor

github-actions bot commented Nov 8, 2025

Benchmark Report

  • this PR's head: edf4e74dedda18eb772a17fa82dfeff8b2aa46c2
  • base branch: 766f6635903c401a79d3c2427dc60225f0053dad

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬─────────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │                   │        │        t(eval) / t(ref)         │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │                   │        │ ──────────┬───────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │      base │   this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │             typed │   true │    351.32 │    484.41 │    0.73 │  15.90 │    8.75 │    1.82 │   5586.23 │   4237.70 │    1.32 │
│                   LDA │    12 │ reversediff │             typed │   true │   2850.98 │   3217.55 │    0.89 │   4.46 │    4.30 │    1.04 │  12707.99 │  13828.49 │    0.92 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │ 107739.10 │ 100809.07 │    1.07 │   3.91 │    5.39 │    0.73 │ 421206.16 │ 543357.60 │    0.78 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │   8317.30 │   8875.76 │    0.94 │   4.17 │    4.40 │    0.95 │  34702.41 │  39050.19 │    0.89 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │  30311.00 │  44967.67 │    0.67 │  10.01 │   11.28 │    0.89 │ 303403.64 │ 507402.74 │    0.60 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │   3276.02 │   4746.36 │    0.69 │   9.36 │    6.96 │    1.34 │  30653.57 │  33037.90 │    0.93 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │      3.45 │      3.82 │    0.90 │   2.72 │    2.83 │    0.96 │      9.39 │     10.79 │    0.87 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │   1187.90 │   1389.71 │    0.85 │ 120.40 │   73.56 │    1.64 │ 143020.95 │ 102228.07 │    1.40 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │       err │       err │     err │    err │     err │     err │       err │       err │     err │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │       err │       err │     err │    err │     err │     err │       err │       err │     err │
│           Smorgasbord │   201 │      enzyme │             typed │   true │       err │       err │     err │    err │     err │     err │       err │       err │     err │
│           Smorgasbord │   201 │    mooncake │             typed │   true │   1635.16 │   2331.06 │    0.70 │   5.63 │    3.73 │    1.51 │   9201.87 │   8688.16 │    1.06 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ reversediff │             typed │   true │   1976.29 │   1922.20 │    1.03 │  70.42 │   85.65 │    0.82 │ 139169.52 │ 164632.58 │    0.85 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │   1632.93 │   1935.44 │    0.84 │  55.69 │   64.15 │    0.87 │  90937.32 │ 124167.75 │    0.73 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │   1625.71 │   1925.05 │    0.84 │  59.76 │   65.30 │    0.92 │  97154.66 │ 125701.17 │    0.77 │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼───────────┼───────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │   1620.52 │   1934.44 │    0.84 │  58.22 │   67.10 │    0.87 │  94339.11 │ 129809.45 │    0.73 │
│              Submodel │     1 │    mooncake │             typed │   true │      8.00 │      8.92 │    0.90 │   4.53 │    4.34 │    1.04 │     36.27 │     38.70 │    0.94 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴───────────┴───────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@codecov
Copy link

codecov bot commented Nov 8, 2025

Codecov Report

❌ Patch coverage is 89.28571% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.60%. Comparing base (766f663) to head (edf4e74).
⚠️ Report is 1 commits behind head on breaking.

Files with missing lines Patch % Lines
src/logdensityfunction.jl 57.14% 3 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1130      +/-   ##
============================================
- Coverage     81.07%   80.60%   -0.48%     
============================================
  Files            41       41              
  Lines          3894     3861      -33     
============================================
- Hits           3157     3112      -45     
- Misses          737      749      +12     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@penelopeysm penelopeysm mentioned this pull request Nov 8, 2025
@github-actions
Copy link
Contributor

github-actions bot commented Nov 8, 2025

DynamicPPL.jl documentation for PR #1130 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1130/

Base automatically changed from py/fastinit to py/fastldf November 10, 2025 19:28
@penelopeysm penelopeysm force-pushed the py/fasteverythingelse branch from 1d79c7f to 89a5aec Compare November 25, 2025 00:50
@penelopeysm penelopeysm changed the base branch from py/fastldf to py/not-experimental November 25, 2025 00:51
@penelopeysm penelopeysm force-pushed the py/fasteverythingelse branch 3 times, most recently from f211389 to 5e335ff Compare November 25, 2025 00:53
Comment on lines +158 to +159
"""
reevaluate_with_chain(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fundamentally, all the functions in this extension really just use this under the hood.

FWIW the FlexiChains extension has a very similar structure and I believe these can be unified pretty much immediately after this PR. To be specific, DynamicPPL.reevaluate_with_chain should be implemented by each chain type in the most performant manner (FlexiChains doesn't use InitFromParams), but once that's done, the definitions of returned, logjoint, ..., pointwise_logdensities, ... can be shared.

predict can't yet be shared unfortunately, because the include_all keyword argument forces custom MCMCChains / FlexiChains code. That would require an extension of the AbstractChains API to support a subset-like operation.

@penelopeysm penelopeysm force-pushed the py/not-experimental branch 4 times, most recently from cf33cff to d1c002f Compare November 25, 2025 02:48
Base automatically changed from py/not-experimental to breaking November 25, 2025 11:41
@penelopeysm penelopeysm force-pushed the py/fasteverythingelse branch from 5e335ff to 62cd3af Compare November 25, 2025 11:45
@penelopeysm penelopeysm marked this pull request as ready for review November 25, 2025 12:01
@penelopeysm penelopeysm changed the title Implement predict, returned, logjoint, ... with fast eval Implement predict, returned, logjoint, ... with OnlyAccsVarInfo Nov 25, 2025
@penelopeysm penelopeysm requested a review from sunxd3 November 25, 2025 12:03
Copy link
Member

@sunxd3 sunxd3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

code is good, maye something to be added to changelog

params_with_stats = AbstractMCMC.to_samples(DynamicPPL.ParamsWithStats, chain)
vi = DynamicPPL.OnlyAccsVarInfo(DynamicPPL.AccumulatorTuple(accs))
return map(params_with_stats) do ps
DynamicPPL.init!!(rng, model, vi, DynamicPPL.InitFromParams(ps.params, fallback))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fallback being nothing is a change in default behavior, right? Might worth noting in changelog if true

Comment on lines -1 to -26
"""
supports_varname_indexing(chain::AbstractChains)
Return `true` if `chain` supports indexing using `VarName` in place of the
variable name index.
"""
supports_varname_indexing(::AbstractChains) = false

"""
getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx)
Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`.
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function getindex_varname end

"""
varnames(chains::AbstractChains)
Return an iterator over the varnames present in `chains`.
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function varnames end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no objection on getting rid of these, not sure if we should add an entry in changelog though. the likelihood that any of these are used somewhere is low

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, you're completely correct, and I just totally neglected the changelog in this PR. My bad.

@sunxd3
Copy link
Member

sunxd3 commented Nov 25, 2025

on

A complete aside
Wouldn't it be fun if we could inspect the model, realise that the return value only involves mu and tau, realise that tau is a bits type and thus tau^2 * I cannot mutate tau, and thus optimise away all the things to do with theta and y?

maybe we can somehow check Julia's effect analysis pipeline? (https://aviatesk.github.io/posts/effects-analysis/index.html, https://docs.julialang.org/en/v1/base/base/#Base.@assume_effects, https://github.com/JuliaLang/julia/blob/master/Compiler/src/effects.jl)

@penelopeysm
Copy link
Member Author

Hmmm, I thought that the only way to make it work would be to handle the IR (i.e. like Mooncake), but I don't know anything about the effect system!

@penelopeysm penelopeysm merged commit 8547e25 into breaking Nov 25, 2025
19 of 21 checks passed
@penelopeysm penelopeysm deleted the py/fasteverythingelse branch November 25, 2025 15:39
github-merge-queue bot pushed a commit that referenced this pull request Dec 2, 2025
* v0.39

* Update DPPL compats for benchmarks and docs

* remove merge conflict markers

* Remove `NodeTrait` (#1133)

* Remove NodeTrait

* Changelog

* Fix exports

* docs

* fix a bug

* Fix doctests

* Fix test

* tweak changelog

* FastLDF / InitContext unified (#1132)

* Fast Log Density Function

* Make it work with AD

* Optimise performance for identity VarNames

* Mark `get_range_and_linked` as having zero derivative

* Update comment

* make AD testing / benchmarking use FastLDF

* Fix tests

* Optimise away `make_evaluate_args_and_kwargs`

* const func annotation

* Disable benchmarks on non-typed-Metadata-VarInfo

* Fix `_evaluate!!` correctly to handle submodels

* Actually fix submodel evaluate

* Document thoroughly and organise code

* Support more VarInfos, make it thread-safe (?)

* fix bug in parsing ranges from metadata/VNV

* Fix get_param_eltype for TSVI

* Disable Enzyme benchmark

* Don't override _evaluate!!, that breaks ForwardDiff (sometimes)

* Move FastLDF to experimental for now

* Fix imports, add tests, etc

* More test fixes

* Fix imports / tests

* Remove AbstractFastEvalContext

* Changelog and patch bump

* Add correctness tests, fix imports

* Concretise parameter vector in tests

* Add zero-allocation tests

* Add Chairmarks as test dep

* Disable allocations tests on multi-threaded

* Fast InitContext (#1125)

* Make InitContext work with OnlyAccsVarInfo

* Do not convert NamedTuple to Dict

* remove logging

* Enable InitFromPrior and InitFromUniform too

* Fix `infer_nested_eltype` invocation

* Refactor FastLDF to use InitContext

* note init breaking change

* fix logjac sign

* workaround Mooncake segfault

* fix changelog too

* Fix get_param_eltype for context stacks

* Add a test for threaded observe

* Export init

* Remove dead code

* fix transforms for pathological distributions

* Tidy up loads of things

* fix typed_identity spelling

* fix definition order

* Improve docstrings

* Remove stray comment

* export get_param_eltype (unfortunatley)

* Add more comment

* Update comment

* Remove inlines, fix OAVI docstring

* Improve docstrings

* Simplify InitFromParams constructor

* Replace map(identity, x[:]) with [i for i in x[:]]

* Simplify implementation for InitContext/OAVI

* Add another model to allocation tests

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Revert removal of dist argument (oops)

* Format

* Update some outdated bits of FastLDF docstring

* remove underscores

---------

Co-authored-by: Markus Hauru <markus@mhauru.org>

* implement `LogDensityProblems.dimension`

* forgot about capabilities...

* use interpolation in run_ad

* Improvements to benchmark outputs (#1146)

* print output

* fix

* reenable

* add more lines to guide the eye

* reorder table

* print tgrad / trel as well

* forgot this type

* Allow generation of `ParamsWithStats` from `FastLDF` plus parameters, and also `bundle_samples` (#1129)

* Implement `ParamsWithStats` for `FastLDF`

* Add comments

* Implement `bundle_samples` for ParamsWithStats -> MCMCChains

* Remove redundant comment

* don't need Statistics?

* Make FastLDF the default (#1139)

* Make FastLDF the default

* Add miscellaneous LogDensityProblems tests

* Use `init!!` instead of `fast_evaluate!!`

* Rename files, rebalance tests

* Implement `predict`, `returned`, `logjoint`, ... with `OnlyAccsVarInfo` (#1130)

* Use OnlyAccsVarInfo for many re-evaluation functions

* drop `fast_` prefix

* Add a changelog

* Improve FastLDF type stability when all parameters are linked or unlinked (#1141)

* Improve type stability when all parameters are linked or unlinked

* fix a merge conflict

* fix enzyme gc crash (locally at least)

* Fixes from review

* Make threadsafe evaluation opt-in (#1151)

* Make threadsafe evaluation opt-in

* Reduce number of type parameters in methods

* Make `warned_warn_about_threads_threads_threads_threads` shorter

* Improve `setthreadsafe` docstring

* warn on bare `@threads` as well

* fix merge

* Fix performance issues

* Use maxthreadid() in TSVI

* Move convert_eltype code to threadsafe eval function

* Point to new Turing docs page

* Add a test for setthreadsafe

* Tidy up check_model

* Apply suggestions from code review

Fix outdated docstrings

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Improve warning message

* Export `requires_threadsafe`

* Add an actual docstring for `requires_threadsafe`

---------

Co-authored-by: Markus Hauru <markus@mhauru.org>

* Standardise `:lp` -> `:logjoint` (#1161)

* Standardise `:lp` -> `:logjoint`

* changelog

* fix a test

---------

Co-authored-by: Markus Hauru <mhauru@turing.ac.uk>
Co-authored-by: Markus Hauru <markus@mhauru.org>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants