Remove Jacobian type instability #212

Open
gdalle wants to merge 2 commits into bifurcationkit:master from gdalle:gd/di

gdalle commented Mar 13, 2025

Here's an MWE removing the Jacobian type instability thanks to DifferentiationInterface's preparation mechanism

Member

rveltz commented Mar 13, 2025

Hi,

Thanks a lot for the PR! You were fast....

I am not sure it removes type instability, at least for in place functions.

function TMvf!(dz, z, p, t = 0)
    (;J, α, E0, τ, τD, τF, U0) = p
    E, x, u = z
    SS0 = J * u * x * E + E0
    SS1 = α * log(1 + exp(SS0 / α))
    dz[1] = (-E + SS1) / τ
    dz[2] = (1 - x) / τD - u * x * E
    dz[3] = (U0 - u) / τF +  U0 * (1 - u) * E
    dz
end

par_tm = (α = 1.5, τ = 0.013, J = 3.07, E0 = -2.0, τD = 0.200, U0 = 0.3, τF = 1.5, τS = 0.007)
z0 = [0.238616, 0.982747, 0.367876 ]

prob = BifurcationProblem(TMvf!, z0, par_tm, (@optic _.E0); record_from_solution = (x, p; k...) -> (E = x[1], x = x[2], u = x[3]))

BK.jacobian(prob, z0, par_tm)
@code_warntype BK.jacobian(prob, z0, par_tm)

returns

...
 p::@NamedTuple{α::Float64, τ::Float64, J::Float64, E0::Float64, τD::Float64, U0::Float64, τF::Float64, τS::Float64}
Body::Any
1 ─ %1 = Base.getproperty(pb, :VF)::BifFunction{BifurcationKit.var"#2350#2367"{typeof(TMvf!)}, typeof(TMvf!), BifurcationKit.var"#2354#2371", Nothing, BifurcationKit.var"#2352#2369"{DifferentiationInterfaceForwardDiffExt.ForwardDiffOneArgJacobianPrep{ForwardDiff.JacobianConfig{ForwardDiff.Tag{DifferentiationInterface.FixTail{BifurcationKit.var"#2350#2367"{typeof(TMvf!)}, Tuple{@NamedTuple{α::Float64, τ::Float64, J::Float64, E0::Float64, τD::Float64, U0::Float64, τF::Float64, τS::Float64}}}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DifferentiationInterface.FixTail{BifurcationKit.var"#2350#2367"{typeof(TMvf!)}, Tuple{@NamedTuple{α::Float64, τ::Float64, J::Float64, E0::Float64, τD::Float64, U0::Float64, τF::Float64, τS::Float64}}}, Float64}, Float64, 3}}}}, ADTypes.AutoForwardDiff{nothing, Nothing}}, Nothing, BifurcationKit.var"#2353#2370"{BifurcationKit.var"#2352#2369"{DifferentiationInterfaceForwardDiffExt.ForwardDiffOneArgJacobianPrep{ForwardDiff.JacobianConfig{ForwardDiff.Tag{DifferentiationInterface.FixTail{BifurcationKit.var"#2350#2367"{typeof(TMvf!)}, Tuple{@NamedTuple{α::Float64, τ::Float64, J::Float64, E0::Float64, τD::Float64, U0::Float64, τF::Float64, τS::Float64}}}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DifferentiationInterface.FixTail{BifurcationKit.var"#2350#2367"{typeof(TMvf!)}, Tuple{@NamedTuple{α::Float64, τ::Float64, J::Float64, E0::Float64, τD::Float64, U0::Float64, τF::Float64, τS::Float64}}}, Float64}, Float64, 3}}}}, ADTypes.AutoForwardDiff{nothing, Nothing}}}, BifurcationKit.var"#2357#2375"{BifurcationKit.var"#d1Fad#2373"}, BifurcationKit.var"#2359#2377", BifurcationKit.var"#2361#2379", BifurcationKit.var"#2363#2381", Bool, Float64, BifurcationKit.Jet{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}}
│   %2 = BifurcationKit.jacobian(%1, x, p)::Any
└──      return %2
...

Member

rveltz commented Mar 13, 2025

and the same goes for the out-of-place version

TMvf(x,p) = TMvf!(similar(x),x,p)
prob = BifurcationProblem(TMvf, z0, par_tm, (@optic _.E0); record_from_solution = (x, p; k...) -> (E = x[1], x = x[2], u = x[3]))

@code_warntype BK.jacobian(prob, z0, par_tm)

Author

gdalle commented Mar 13, 2025

I added the in-place version, and I also fixed the remaining type instabilities even though they were not my fault, because I'm a very nice person.
It was due to the infamous "variable captured by closure" inefficiency, see performance tips or Discourse for details.
I solved it by restructuring your definition of F and Finp so that there is only one assignment to these variables and not two.
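
As a minimal sketch of the closure issue described above (the function names `make_unstable` and `make_stable` are hypothetical and not taken from BifurcationKit), the following contrasts a captured variable that is assigned twice with one that is assigned exactly once:

```julia
# Hypothetical illustration, not BifurcationKit code.

# `f` is assigned twice and then captured by a closure, so Julia stores it
# in a `Core.Box` and calls through it are inferred as `Any`.
function make_unstable(inplace::Bool, g)
    f = g
    if inplace
        f = x -> g(x)      # second assignment to the captured variable
    end
    return x -> f(x)
end

# `f` is assigned once, as the value of an `if` expression, so the returned
# closure stores it in a concretely typed field.
function make_stable(inplace::Bool, g)
    f = if inplace
        x -> g(x)
    else
        g
    end
    return x -> f(x)
end

h_unstable = make_unstable(true, sin)
h_stable = make_stable(true, sin)

# Inspect what inference concludes about a call with a Float64 argument.
rt_unstable = only(Base.return_types(h_unstable, (Float64,)))
rt_stable = only(Base.return_types(h_stable, (Float64,)))
```

Both closures compute the same values; only the inferred return type differs, which is what `@code_warntype` reports.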

Member

rveltz commented Mar 13, 2025

> because I'm a very nice person.

You are indeed :D

Thanks a lot for this. I'll merge when the tests pass, but this is my work now ;)

Author

gdalle commented Mar 13, 2025

And you can also use DI for first derivatives, second derivatives and JVPs/pushforwards. But not for higher-order derivatives, or directional second-order stuff with different directions.

Member

rveltz commented Mar 13, 2025

Ah that's why this fails in my tests:

ForwardDiff.derivative(z -> apply(jacobian(prob, x0, set(parbif, lens, z)), ζ), p)

Member

rveltz commented Mar 13, 2025

And do you also mean that I should keep this in ForwardDiff: d2F = (x, p, dx1, dx2) -> ForwardDiff.derivative(t -> d1Fad(x .+ t .* dx2, p, dx1), zero(eltype(dx1))) ?

Author

gdalle commented Mar 13, 2025

> Ah that's why this fails in my tests:

I don't understand what the code with apply does.

> And do you also mean that I should keep this in ForwardDiff: d2F = (x, p, dx1, dx2) -> ForwardDiff.derivative(t -> d1Fad(x .+ t .* dx2, p, dx1), zero(eltype(dx1))) ?

You can always replace every ForwardDiff.derivative with a DI.derivative, and you could even replace one of these derivatives with a DI.pushforward. What I meant is that there is no operator built into DI which does exactly what you want in one shot, without nesting.

Member

rveltz commented Mar 13, 2025

apply applies a map to a vector, hence apply(f, x) = f(x), but for matrices apply(J, x) = J*x. It can probably be avoided now by using a JVP.
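
A minimal sketch of that dispatch pattern, under the assumption that it is expressed as two methods (the name `apply_sketch` is hypothetical; this is not BifurcationKit's actual definition):

```julia
# Hypothetical illustration of the dispatch described above.

# Generic fallback: treat the first argument as a callable map.
apply_sketch(f, x) = f(x)

# Matrices act on vectors by multiplication.
apply_sketch(J::AbstractMatrix, x) = J * x

y_map = apply_sketch(v -> 2 .* v, [1.0, 2.0])
y_mat = apply_sketch([1.0 2.0; 3.0 4.0], [1.0, 1.0])
```

The same call site then works whether the Jacobian is returned as a matrix or as a matrix-free operator.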

Member

rveltz commented Mar 13, 2025

Huh, that's a massive change to the library.
It's going to take some time to incorporate DI.

Author

gdalle commented Mar 13, 2025

You don't have to put it everywhere at once, though.

Author

gdalle commented Mar 13, 2025

For instance, this PR alone seems like a net improvement?

Member

rveltz commented Mar 14, 2025

Yes, but the tests do not pass.
I have been able to correct most errors, but I am stuck on the following:

At some point (codim 2), I need to AD something like jacobian(...) \ rhs.
See https://github.com/bifurcationkit/BifurcationKit.jl/blob/master/test/testJacobianFoldDeflation.jl#L58

Author

gdalle commented Mar 14, 2025

That test file is an MWE? If so, I'll take a look

Member

rveltz commented Mar 14, 2025

using BifurcationKit
const BK = BifurcationKit
function Lor(u, p, t = 0)
    (;α,β,γ,δ,G,F,T) = p
    X,Y,Z,U = u
    [
        -Y^2 - Z^2 - α*X + α*F - γ*U^2,
        X*Y - β*X*Z - Y + G,
        β*X*Y + X*Z - Z,
        -δ*U + γ*U*X + T
    ]
end

parlor = (α = 1//4, β = 1, G = .25, δ = 1.04, γ = 0.987, F = 1.7620532879639, T = .0001265)

opts_br = ContinuationPar(p_min = -1.5, p_max = 3.0, ds = 0.001, dsmax = 0.025,
    # options to detect codim 1 bifurcations using bisection
    detect_bifurcation = 3,
    # Optional: bisection options for locating bifurcations
    n_inversion = 6, max_bisection_steps = 25,
    # number of eigenvalues
    nev = 4, max_steps = 252)

@reset opts_br.newton_options.max_iterations = 25

z0 =  [2.9787004394953343, -0.03868302503393752,  0.058232737694740085, -0.02105288273117459]

prob = BK.BifurcationProblem(Lor, z0, parlor, (@optic _.F))

br = @time continuation(re_make(prob, params = setproperties(parlor;T=0.04,F=3.)),
     PALC(tangent = Bordered()),
    opts_br;
    normC = norminf,
    bothside = true)

sn_codim2_test = continuation((@set br.alg.tangent = Secant()), 5, (@optic _.T), ContinuationPar(opts_br, p_max = 3.2, p_min = -0.1, detect_bifurcation = 1, dsmin=1e-5, ds = -0.001, dsmax = 0.005, n_inversion = 8, save_sol_every_step = 1, max_steps = 60) ;
    normC = norminf,
    detect_codim2_bifurcation = 0,
    update_minaug_every_step = 1,
    start_with_eigen = true,
    bdlinsolver = MatrixBLS(),
    )

Member

rveltz commented Mar 14, 2025

I have tried using

jacobian(foldpb::FoldMAProblem{Tprob, AutoDiff, Tu0, Tp, Tl, Tplot, Trecord}, x, p) where {Tprob, Tu0, Tp, Tl <: Union{AllOpticTypes, Nothing}, Tplot, Trecord} = ForwardDiff.jacobian(z -> foldpb.prob(z, p), x)

# becomes =>
jacobian(foldpb::FoldMAProblem{Tprob, AutoDiff, Tu0, Tp, Tl, Tplot, Trecord}, x, p) where {Tprob, Tu0, Tp, Tl <: Union{AllOpticTypes, Nothing}, Tplot, Trecord} = DI.jacobian(foldpb.prob, DI.AutoForwardDiff(), x, DI.Constant(p))

Author

gdalle commented Mar 14, 2025

I understand what happens here. Preparation, whether in DI or in ForwardDiff directly, is only valid when you reuse its result with inputs that have the same type. In this case, you call prepare_jacobian with x::Vector{Float64} but then try to call jacobian with x::Vector{Dual{...}}. I don't have a great solution built into DI at the moment (it's a difficult problem), so you'd probably have to revamp your interface to accommodate this.
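
For illustration, here is a minimal sketch of the preparation workflow being discussed, assuming DifferentiationInterface 0.6 or later with ForwardDiff loaded. The test function `f` is borrowed from a later MWE in this thread; the variable names are placeholders:

```julia
import DifferentiationInterface as DI
import ForwardDiff  # loading it enables the ForwardDiff backend for DI

# Small test function with parameters passed as a constant context.
f(x, p) = [x[1]^2 + p.a, p.a + sum(x)]

backend = DI.AutoForwardDiff()
x = [1.0, 2.0]
p = (a = 1.0, b = 0.0)

# Preparation is done once, for an input of type Vector{Float64}.
prep = DI.prepare_jacobian(f, backend, x, DI.Constant(p))

# Reusing `prep` is only valid for inputs of that same type.
J_prepared = DI.jacobian(f, prep, backend, x, DI.Constant(p))

# The unprepared call makes no assumption about the input type, at the cost
# of redoing the setup work on every call.
J_unprepared = DI.jacobian(f, backend, x, DI.Constant(p))
```

Calling the prepared version with a `Vector` of `Dual` numbers, as happens when the Jacobian itself is differentiated, falls outside what `prep` was built for.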

Member

rveltz commented Mar 14, 2025

Exactly. I notice that if I remove the prep, pretty much everything goes well, except for a minor point I will discuss.
So either I ditch prep, and we lose the type stability benefit but gain the niceties from DI, or we keep prep, but then I have no idea how to tackle what you mention.

Member

rveltz commented Mar 14, 2025

Do you understand why this does not work?

par = (a=1., b=0.)
prob = BifurcationProblem((x,p)->[x[1]^2+p[1],p.a + sum(x)], rand(2), par, (@optic _.a))
DI.derivative(z -> BK.dF(prob, prob.u0, set(prob.params, prob.lens, z), _dx), prob.VF.ad_backend, 0.)

where in Problems.jl I defined

dF = (x, p, dx) -> DI.pushforward(F, ad_backend, x, (dx,), DI.Constant(p))[1]

The reported error is:

ERROR: Cannot determine ordering of Dual tags ForwardDiff.Tag{var"#529#530", Float64} and ForwardDiff.Tag{var"#535#536", Float64}
Stacktrace:
  [1] partials
    @ ~/.julia/packages/ForwardDiff/UBbGT/src/dual.jl:115 [inlined]
  [2] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [3] _broadcast_getindex
    @ ./broadcast.jl:692 [inlined]
  [4] getindex

Author

gdalle commented Mar 14, 2025

> where in Problems.jl I defined

Where is the definition of dF in Problems.jl?

Member

rveltz commented Mar 14, 2025

In your commit:

F = if inplace || _isinplace(_F)
    (x, p) -> _F(similar(x), x, p)
else
    _F
end

Author

gdalle commented Mar 14, 2025

That is F, not dF, right?

Member

rveltz commented Mar 14, 2025

Author

gdalle commented Mar 14, 2025

Honestly, it would be easier if you pointed me to a branch where I can run your MWEs, instead of me figuring out what I should change in this branch to imitate you.

Member

rveltz commented Mar 14, 2025

I just added a branch called DI.

Member

rveltz commented Mar 14, 2025

OK, so what is surprising to me is that the tests for testNF_maps.jl do not pass. I don't understand the error.

Member

rveltz commented Mar 16, 2025

Hi,

I have pushed my whole stack, this should correspond to the state of my computer. The error remains in test/testNF_maps.jl.

Also, if you have a suggestion about

> I don't have a great solution built into DI at the moment (it's a difficult problem), so you'd probably have to revamp your interface to accommodate this.

I am interested.

I could, for example, add options in the constructor like use_ad_preparation = false
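
One way such an option could look, as a rough sketch only: the `JacobianSketch` type and `jac` function below are hypothetical and are not part of BifurcationKit or DI. The idea is to store either a preparation object or `nothing`, and dispatch on which one is present.

```julia
import DifferentiationInterface as DI
import ForwardDiff  # loading it enables the ForwardDiff backend for DI

# Hypothetical wrapper: `prep` is either a DI preparation object or `nothing`.
struct JacobianSketch{F,B,P}
    f::F
    backend::B
    prep::P
end

# Hypothetical constructor exposing the suggested keyword. The wrapper type
# depends on the runtime flag, so construction itself is not type-stable,
# but each resulting object is concretely typed.
function JacobianSketch(f, backend, x, p; use_ad_preparation::Bool = true)
    prep = use_ad_preparation ? DI.prepare_jacobian(f, backend, x, DI.Constant(p)) : nothing
    return JacobianSketch(f, backend, prep)
end

# Without preparation: works for any input type, with repeated setup cost.
jac(j::JacobianSketch{F,B,Nothing}, x, p) where {F,B} =
    DI.jacobian(j.f, j.backend, x, DI.Constant(p))

# With preparation: `x` must have the type that was used at construction.
jac(j::JacobianSketch, x, p) =
    DI.jacobian(j.f, j.prep, j.backend, x, DI.Constant(p))

f(x, p) = [x[1]^2 + p.a, p.a + sum(x)]
x0 = [1.0, 2.0]
par = (a = 1.0, b = 0.0)

j_prep = JacobianSketch(f, DI.AutoForwardDiff(), x0, par)
j_noprep = JacobianSketch(f, DI.AutoForwardDiff(), x0, par; use_ad_preparation = false)

J1 = jac(j_prep, x0, par)
J2 = jac(j_noprep, x0, par)
```

With this layout, code paths that differentiate through the Jacobian could construct the unprepared variant, while the common path keeps the prepared one.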

Author

gdalle commented Mar 17, 2025

I don't know where the tag ordering issue comes from. If your function contains parallelism, it can be due to JuliaDiff/ForwardDiff.jl#320 but otherwise I'm rather clueless.

> I could for example add options in the constructor like use_ad_preparation = false

That's definitely an option. And if you do want to prepare, you need to be able to anticipate which types will be provided to your function, including during differentiation. I have started to add some utilities to DI for this, like here, but it's very ForwardDiff-specific and not yet part of the public API.

Member

rveltz commented Mar 17, 2025

> I don't know where the tag ordering issue comes from. If your function contains parallelism, it can be due to JuliaDiff/ForwardDiff.jl#320 but otherwise I'm rather clueless.

Should I open an issue? The MWE is quite simple and does not rely on BK.

Author

gdalle commented Mar 17, 2025

If you have a simple MWE that doesn't rely on BK, can you post it here first? There's a chance it might be a bug in DI.
