jgot solves two-endpoint dynamic optimal transport on sparse reversible graphs
using JAX, following the time-discrete formulation of
Erbar et al. (2020).
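In the discrete transport setting of Erbar et al., the density carried by an edge is interpolated from the densities at its two endpoints through a mean theta(s, t). A common choice is the logarithmic mean theta(s, t) = (s - t) / (log s - log t). The quick-start example further down passes `LogMeanOps()`. We assume, without having checked the implementation, that it refers to this mean; the API Reference has the exact definition. Below is a minimal pure-Python sketch of the mean itself. `log_mean` is a hypothetical helper, not part of jgot.

```python
import math


def log_mean(s: float, t: float) -> float:
    """Logarithmic mean theta(s, t) = (s - t) / (log s - log t) for s, t >= 0.

    Hypothetical helper for illustration; not part of jgot.
    """
    if s < 0.0 or t < 0.0:
        raise ValueError("densities must be non-negative")
    if s == 0.0 or t == 0.0:
        # The logarithmic mean vanishes as soon as one argument is zero.
        return 0.0
    if s == t:
        # Continuous extension on the diagonal: theta(s, s) = s.
        return s
    # Rewrite (s - t) / log(s / t) as t * u / log1p(u) with u = s / t - 1,
    # which stays accurate when s and t are close.
    u = s / t - 1.0
    if u == 0.0:
        # s and t differ, but their ratio rounds to exactly 1.
        return t
    return t * u / math.log1p(u)


if __name__ == "__main__":
    s, t = math.e, 1.0
    print(log_mean(s, t))  # approximately e - 1 = 1.71828...
    # The logarithmic mean lies between the geometric and arithmetic means.
    print(math.sqrt(s * t) <= log_mean(s, t) <= (s + t) / 2)  # True
```

Because theta(s, 0) = 0, an edge touching an empty node carries zero interpolated density. The quick-start endpoints `[1.0, 0.0]` and `[0.0, 1.0]` sit exactly in that regime.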
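PyPI package: `jgot`

Core library:

```bash
pip install jgot
```

Examples (plotting dependencies included):

```bash
pip install "jgot[examples]"
```

Development environment:

```bash
uv sync --group dev
```

Minimal example: transport all mass from node 0 to node 1 on a two-node graph.

```python
import jax.numpy as jnp

from jgot import (
    GraphSpec,
    LogMeanOps,
    OTConfig,
    OTProblem,
    TimeDiscretization,
    solve_ot,
)

# Two nodes joined by a single undirected edge of weight 1.
graph = GraphSpec.from_undirected_weights(
    num_nodes=2,
    edge_u=[0],
    edge_v=[1],
    weight=[1.0],
)

# Endpoint masses: everything on node 0 at the start, on node 1 at the end.
mass_a = jnp.array([1.0, 0.0])
mass_b = jnp.array([0.0, 1.0])

# OTProblem takes densities relative to the reference measure graph.pi.
rho_a = mass_a / graph.pi
rho_b = mass_b / graph.pi

problem = OTProblem(
    graph=graph,
    time=TimeDiscretization(num_steps=64),  # 64 time steps between endpoints
    rho_a=rho_a,
    rho_b=rho_b,
    mean_ops=LogMeanOps(),
)

# Solve with the default solver configuration, then report the distance,
# the convergence flag, and the number of iterations used.
sol = solve_ot(problem, OTConfig())
print(float(sol.distance), sol.converged, sol.iterations_used)
```

Important:

- Densities are represented with respect to `pi` (`graph.pi`): a mass vector `m` corresponds to the density `rho = m / pi`.
- Endpoints must satisfy `sum(pi * rho) == 1`, i.e. each endpoint is a probability measure.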
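The two rules above amount to `rho = mass / pi` with `mass` summing to 1, so `sum(pi * rho)` is simply the total mass. The following is a minimal pure-Python sketch of that bookkeeping. It assumes `pi` is the normalized weighted degree of the undirected graph, i.e. the stationary distribution of the random walk with transition probabilities w(x, y) / deg(x). jgot's actual `graph.pi` may be normalized differently. `stationary_from_undirected` and `mass_to_density` are hypothetical helpers, not jgot functions.

```python
def stationary_from_undirected(num_nodes, edge_u, edge_v, weight):
    """Normalized weighted degree pi(x) = deg(x) / sum(deg).

    Assumption: this is the stationary distribution of the random walk
    Q(x, y) = w(x, y) / deg(x). Since pi(x) * Q(x, y) = w(x, y) / sum(deg)
    is symmetric in x and y, that walk is reversible. jgot's graph.pi may
    use a different normalization; this helper is hypothetical.
    """
    deg = [0.0] * num_nodes
    for u, v, w in zip(edge_u, edge_v, weight):
        deg[u] += w
        deg[v] += w
    total = sum(deg)
    return [d / total for d in deg]


def mass_to_density(mass, pi, tol=1e-12):
    """Convert a probability mass vector into a density with respect to pi."""
    if abs(sum(mass) - 1.0) > tol:
        raise ValueError("endpoint mass must sum to 1")
    if any(p <= 0.0 for p in pi):
        raise ValueError("pi must be strictly positive")
    return [m / p for m, p in zip(mass, pi)]


# Same two-node graph and endpoints as the minimal example.
pi = stationary_from_undirected(2, edge_u=[0], edge_v=[1], weight=[1.0])
rho_a = mass_to_density([1.0, 0.0], pi)
rho_b = mass_to_density([0.0, 1.0], pi)

# sum(pi * rho) recovers the total mass, so the endpoint constraint holds.
for rho in (rho_a, rho_b):
    total = sum(p * r for p, r in zip(pi, rho))
    assert abs(total - 1.0) < 1e-12

print(pi, rho_a, rho_b)  # [0.5, 0.5] [2.0, 0.0] [0.0, 2.0]
```

The identity `sum(pi * rho) == sum(mass)` does not depend on how `pi` is normalized, as long as every entry of `pi` is positive.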
Detailed documentation lives under `docs/`.
Recommended starting points:
- Getting Started
- Graph Model
- API Reference
- Examples Guide
- Debugging and Diagnostics
- Numerical Limitations
Runnable scripts:

- `examples/two_node_benchmark/run.py`
- `examples/cycle_neighbor_transport/run.py`
- `examples/line_chain_transport/run.py`
- `examples/directed_reversible_transport/run.py`
- `examples/large_grid_transport/run.py`

See `examples/README.md` for commands and outputs.
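The following small convenience sketch runs each listed script in turn with the current Python interpreter. It assumes the scripts take no required command-line arguments and expect to be run from the repository root. `examples/README.md` remains the authoritative source for the exact commands. `build_commands` and `run_all` are hypothetical names, not part of jgot.

```python
import subprocess
import sys
from pathlib import Path

# Example scripts, as listed in this README.
EXAMPLE_SCRIPTS = [
    "examples/two_node_benchmark/run.py",
    "examples/cycle_neighbor_transport/run.py",
    "examples/line_chain_transport/run.py",
    "examples/directed_reversible_transport/run.py",
    "examples/large_grid_transport/run.py",
]


def build_commands(repo_root="."):
    """Return one [python, absolute_script_path] argv list per example."""
    root = Path(repo_root).resolve()
    return [[sys.executable, str(root / script)] for script in EXAMPLE_SCRIPTS]


def run_all(repo_root=".", dry_run=False):
    """Run every example from the repository root, stopping at the first failure."""
    root = Path(repo_root).resolve()
    for cmd in build_commands(root):
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError if a script exits non-zero.
            subprocess.run(cmd, check=True, cwd=root)
```

From the repository root, `run_all(dry_run=True)` prints the five commands without executing them, and `run_all()` runs them one after another.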