This repository accompanies the reversible solver method introduced here.
The reversible method is implemented in diffrax. This is a work in progress - see the fork here. To install the fork and check out the arxiv branch, run
git clone https://github.com/sammccallum/diffrax.git
pip install -e diffrax
cd diffrax/
git checkout arxiv

The reversible solvers can be used by passing adjoint=diffrax.ReversibleAdjoint() to diffrax.diffeqsolve:
import jax.numpy as jnp
import diffrax
vf = lambda t, y, args: y
y0 = jnp.array([1.0])
term = diffrax.ODETerm(vf)
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0=0,
    t1=5,
    dt0=0.01,
    y0=y0,
    adjoint=diffrax.ReversibleAdjoint(),
)

The base solver diffrax.Tsit5() will be automatically wrapped into a reversible version and gradient calculation will follow the reversible backpropagation algorithm.
The experiments presented in the paper can be found in the experiments directory. The experiments require an installation of the reversible and diffrax libraries. To install, run
git clone https://github.com/sammccallum/reversible-solvers.git
pip install -e reversible
git clone https://github.com/sammccallum/diffrax.git
pip install -e diffrax
cd diffrax
git checkout arxiv

Note that the arxiv branch in diffrax contains the archived code used to run the experiments.
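After installing, it can be useful to confirm that Python is importing the fork rather than another diffrax installation. The check below is only a sketch: the helper name has_reversible_adjoint is hypothetical, and it assumes that the presence of a ReversibleAdjoint attribute is enough to tell the installations apart.

```python
import importlib.util


def has_reversible_adjoint() -> bool:
    """Return True if an importable diffrax exposes ReversibleAdjoint."""
    # Avoid an ImportError when diffrax is not installed at all.
    if importlib.util.find_spec("diffrax") is None:
        return False
    import diffrax

    return hasattr(diffrax, "ReversibleAdjoint")


print("ReversibleAdjoint available:", has_reversible_adjoint())
```

If this prints False, the editable install or the branch checkout likely did not take effect in the active environment.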