
LvDAO/JaxGraphOT


JGOT (JAX Graph Optimal Transport)


jgot solves two-endpoint dynamic optimal transport on sparse reversible graphs using JAX, following the time-discrete formulation of Erbar et al. (2020).
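"Reversible" here means the graph carries a Markov kernel Q and a reference probability pi that satisfy detailed balance, pi(x) Q(x, y) = pi(y) Q(y, x). One standard way to get such a pair from symmetric edge weights is the weighted random walk, with Q(x, y) = w(x, y) / deg(x) and pi(x) = deg(x) / sum(deg). The sketch below builds that pair densely and checks detailed balance. It assumes a connected graph with no isolated nodes. The helper names `random_walk_from_weights` and `is_reversible` are hypothetical and copy only the argument names of `GraphSpec.from_undirected_weights`. This is not a claim about how jgot constructs `graph.pi` internally.

```python
import jax.numpy as jnp


def random_walk_from_weights(num_nodes, edge_u, edge_v, weight):
    """Hypothetical helper (not jgot API): weighted random walk (Q, pi).

    Dense matrices are used for clarity; a real solver would keep the
    graph sparse. Assumes every node has at least one incident edge.
    """
    u = jnp.asarray(edge_u)
    v = jnp.asarray(edge_v)
    w = jnp.asarray(weight, dtype=jnp.float32)
    # Symmetric weight matrix: each undirected edge contributes W[u, v] and W[v, u].
    W = jnp.zeros((num_nodes, num_nodes), dtype=jnp.float32)
    W = W.at[u, v].add(w).at[v, u].add(w)
    degree = W.sum(axis=1)
    Q = W / degree[:, None]     # row-stochastic Markov kernel
    pi = degree / degree.sum()  # reversible (stationary) probability
    return Q, pi


def is_reversible(Q, pi, atol=1e-6):
    """Check detailed balance pi(x) Q(x, y) == pi(y) Q(y, x)."""
    flux = pi[:, None] * Q
    return bool(jnp.allclose(flux, flux.T, atol=atol))


# Path graph 0 - 1 - 2 with edge weights 1 and 2: degrees are [1, 3, 2].
Q, pi = random_walk_from_weights(3, [0, 1], [1, 2], [1.0, 2.0])
print(pi, is_reversible(Q, pi))
```

Detailed balance holds by construction here, because pi(x) Q(x, y) = w(x, y) / sum(deg) is symmetric in x and y.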

Install

Core library from PyPI:

pip install jgot

Examples (plotting dependencies included):

pip install "jgot[examples]"

Development environment:

uv sync --group dev

Minimal Example

import jax.numpy as jnp
from jgot import (
    GraphSpec,
    LogMeanOps,
    OTConfig,
    OTProblem,
    TimeDiscretization,
    solve_ot,
)

graph = GraphSpec.from_undirected_weights(
    num_nodes=2,
    edge_u=[0],
    edge_v=[1],
    weight=[1.0],
)

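# Unit point masses: all mass starts on node 0 and ends on node 1.
mass_a = jnp.array([1.0, 0.0])
mass_b = jnp.array([0.0, 1.0])
# jgot works with densities relative to pi, so divide the masses by graph.pi.
rho_a = mass_a / graph.pi
rho_b = mass_b / graph.pi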

problem = OTProblem(
    graph=graph,
    time=TimeDiscretization(num_steps=64),
    rho_a=rho_a,
    rho_b=rho_b,
    mean_ops=LogMeanOps(),
)

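# Solve with the default configuration, then report the distance,
# the convergence flag and the number of iterations used.
sol = solve_ot(problem, OTConfig())
print(float(sol.distance), sol.converged, sol.iterations_used)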
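The problem above passes `mean_ops=LogMeanOps()`. The name suggests this selects the logarithmic mean theta(s, t) = (s - t) / (log s - log t), which discrete transport metrics on graphs use to average the densities at the two ends of an edge. That reading is an assumption based on the name. The standalone JAX sketch below is not jgot's implementation. It shows how the mean behaves, including the limiting cases theta(s, s) = s and theta(s, 0) = 0 for nonnegative inputs. The function name `log_mean` is hypothetical.

```python
import jax.numpy as jnp


def log_mean(s, t):
    """Logarithmic mean of nonnegative arrays s and t (illustrative sketch).

    theta(s, t) = (s - t) / (log s - log t) for s != t, both positive;
    theta(s, s) = s; theta(s, 0) = theta(0, t) = 0.
    """
    s, t = jnp.broadcast_arrays(
        jnp.asarray(s, dtype=jnp.float32), jnp.asarray(t, dtype=jnp.float32)
    )
    equal = s == t
    positive = (s > 0) & (t > 0)
    general = positive & ~equal
    # Substitute harmless dummy values where the closed form is undefined,
    # so no 0/0 or log(0) values are ever computed.
    s_safe = jnp.where(general, s, 2.0)
    t_safe = jnp.where(general, t, 1.0)
    value = (s_safe - t_safe) / (jnp.log(s_safe) - jnp.log(t_safe))
    return jnp.where(equal, s, jnp.where(positive, value, 0.0))


# The logarithmic mean lies between the geometric and arithmetic means.
print(log_mean(1.0, 4.0))
```

The substitution trick keeps the function free of NaNs, so it can be used inside `jax.jit` or vectorised over all edges at once.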

Important:

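  • Densities are represented with respect to the reference measure pi (available as graph.pi), not as raw node masses.
  • Each endpoint density must satisfy sum(pi * rho) == 1, i.e. it must describe a probability measure.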
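The two requirements above can be met mechanically. Normalise the node masses to total 1, then divide by pi. The sketch below does this with plain `jax.numpy` on an explicit `pi` array, and it assumes pi is strictly positive on every node. The helper names `masses_to_density` and `check_normalised` are hypothetical and are not part of jgot.

```python
import jax.numpy as jnp


def masses_to_density(mass, pi):
    """Hypothetical helper: convert nonnegative node masses into a density
    rho with respect to pi, so that sum(pi * rho) == 1. Assumes pi > 0."""
    mass = jnp.asarray(mass, dtype=jnp.float32)
    pi = jnp.asarray(pi, dtype=jnp.float32)
    mass = mass / mass.sum()  # total mass 1 (probability vector)
    return mass / pi          # density relative to the reference measure


def check_normalised(rho, pi, atol=1e-6):
    """True if rho integrates to 1 against pi."""
    return bool(jnp.abs(jnp.sum(jnp.asarray(pi) * jnp.asarray(rho)) - 1.0) <= atol)


pi = jnp.array([0.25, 0.75])
rho = masses_to_density([2.0, 2.0], pi)  # equal masses on both nodes
print(rho, check_normalised(rho, pi))
```

Normalising before dividing means unnormalised inputs, such as raw counts, also produce a valid endpoint.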

Documentation

Detailed documentation lives in the docs directory.

Recommended starting points:

Examples

Runnable scripts:

  • examples/two_node_benchmark/run.py
  • examples/cycle_neighbor_transport/run.py
  • examples/line_chain_transport/run.py
  • examples/directed_reversible_transport/run.py
  • examples/large_grid_transport/run.py

See examples/README.md for commands and outputs.

About

JAX Graph Optimal Transport Package
