Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 49 additions & 7 deletions flow_matching/solver/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def sample(
time_grid: Tensor = torch.tensor([0.0, 1.0]),
return_intermediates: bool = False,
enable_grad: bool = False,
log_p0: Optional[Callable[[Tensor], Tensor]] = None,
exact_divergence: bool = False,
**model_extras,
) -> Union[Tensor, Sequence[Tensor]]:
r"""Solve the ODE with the velocity field.
Expand Down Expand Up @@ -73,6 +75,8 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
time_grid (Tensor): The process is solved in the interval [min(time_grid, max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
log_p0 (Optional[Callable[[Tensor], Tensor]]): If provided, the function computes the log likelihood of the source distribution at :math:`t=0`. The velocity model must be differentiable with respect to x.
exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator.
**model_extras: Additional input for the model.

Returns:
Expand All @@ -81,27 +85,65 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:

time_grid = time_grid.to(x_init.device)

# Fix the random projection for the Hutchinson divergence estimator
if not exact_divergence:
z = (torch.randn_like(x_init).to(x_init.device) < 0) * 2.0 - 1.0

def ode_func(t, x):
return self.velocity_model(x=x, t=t, **model_extras)

def dynamics_func(t, states):
xt = states[0]
with torch.set_grad_enabled(True):
xt.requires_grad_()
ut = ode_func(t, xt)

# Compute exact divergence
if exact_divergence:
div = 0
for i in range(ut.flatten(1).shape[1]):
div += gradient(ut[:, i], xt, create_graph=True)[:, i].detach()
else:
# Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
ut_dot_z = torch.einsum(
"ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1)
)
grad_ut_dot_z = gradient(ut_dot_z, xt)
div = torch.einsum(
"ij,ij->i",
grad_ut_dot_z.flatten(start_dim=1),
z.flatten(start_dim=1),
)

return ut.detach(), div.detach()

ode_opts = {"step_size": step_size} if step_size is not None else {}

with torch.set_grad_enabled(enable_grad):
# Approximate ODE solution with numerical ODE solver
sol = odeint(
ode_func,
x_init,
ode_func if log_p0 is None else dynamics_func,
(
x_init
if log_p0 is None
else (x_init, torch.zeros(x_init.shape[0], device=x_init.device))
),
time_grid,
method=method,
options=ode_opts,
atol=atol,
rtol=rtol,
)

if return_intermediates:
return sol
else:
return sol[-1]
if log_p0 is not None:
sol, log_det = sol
log_likelihood = log_p0(x_init) - log_det[-1]
return (
(sol, log_likelihood)
if return_intermediates
else (sol[-1], log_likelihood)
)
return sol if return_intermediates else sol[-1]

def compute_likelihood(
self,
Expand Down Expand Up @@ -181,7 +223,7 @@ def dynamics_func(t, states):
sol, log_det = odeint(
dynamics_func,
y_init,
time_grid,
time_grid.to(x_1.device),
method=method,
options=ode_opts,
atol=atol,
Expand Down
81 changes: 81 additions & 0 deletions tests/solver/test_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,87 @@ def dummy_log_p(x: Tensor) -> Tensor:
torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
)

def test_sample_with_likelihoods(self):
x_0 = torch.tensor([[1.0, 0.0]])
step_size = 0.001

# Define a dummy log probability function
def dummy_log_p(x: Tensor) -> Tensor:
return -0.5 * torch.sum(x**2, dim=1)

result, log_likelihood = self.dummy_solver.sample(
x_init=x_0,
step_size=step_size,
log_p0=dummy_log_p,
exact_divergence=True,
)
self.assertIsInstance(log_likelihood, Tensor)
self.assertEqual(x_0.shape[0], log_likelihood.shape[0])

self.assertTrue(
torch.allclose(torch.tensor([2.0, 1.0]), result, atol=1e-2),
)

def test_forward_backward_likelihoods_exact(self):
x_0 = torch.tensor([[1.0, 0.0]])
step_size = 0.001

# Define a dummy log probability function
def dummy_log_p(x: Tensor) -> Tensor:
return -0.5 * torch.sum(x**2, dim=1)

x1, forward_log_likelihood = self.dummy_solver.sample(
x_init=x_0,
step_size=step_size,
log_p0=dummy_log_p,
exact_divergence=True,
)

self.assertIsInstance(forward_log_likelihood, Tensor)
self.assertEqual(x_0.shape[0], forward_log_likelihood.shape[0])

# Check if the post-hoc likelihoods match the original log likelihoods
_, backward_log_likelihood = self.dummy_solver.compute_likelihood(
x_1=x1,
log_p0=dummy_log_p,
step_size=step_size,
exact_divergence=True,
)

self.assertTrue(
torch.allclose(forward_log_likelihood, backward_log_likelihood, atol=1e-2),
)

def test_forward_backward_likelihoods(self):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, thanks for the PR. Can you also add a test that compares the output of the forward compute_likelihood and the backward compute_likelihood to make sure they match?

Is this sufficient? It computes the forward likelihood, while generating samples. Then using those samples to compute the backward likelihood and compare them.

The odesolver.compute_likelihood does not allow a forward computation, as it also would't make sense without x1 samples?!

Maybe I am misunderstanding sth. here ...

x_0 = torch.tensor([[1.0, 0.0]])
step_size = 0.001

# Define a dummy log probability function
def dummy_log_p(x: Tensor) -> Tensor:
return -0.5 * torch.sum(x**2, dim=1)

x1, forward_log_likelihood = self.dummy_solver.sample(
x_init=x_0,
step_size=step_size,
log_p0=dummy_log_p,
exact_divergence=False,
)

self.assertIsInstance(forward_log_likelihood, Tensor)
self.assertEqual(x_0.shape[0], forward_log_likelihood.shape[0])

# Check if the post-hoc likelihoods match the original log likelihoods
_, backward_log_likelihood = self.dummy_solver.compute_likelihood(
x_1=x1,
log_p0=dummy_log_p,
step_size=step_size,
exact_divergence=False,
)

self.assertTrue(
torch.allclose(forward_log_likelihood, backward_log_likelihood, atol=1e-2),
)


if __name__ == "__main__":
unittest.main()
Loading