From f28558a4c345bf0b6b36d94c35534f38fbf8f335 Mon Sep 17 00:00:00 2001 From: timonpalm Date: Thu, 31 Jul 2025 15:36:31 +0200 Subject: [PATCH 1/2] t on right device --- flow_matching/solver/ode_solver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py index 89975064..1c2bbd8b 100644 --- a/flow_matching/solver/ode_solver.py +++ b/flow_matching/solver/ode_solver.py @@ -181,7 +181,7 @@ def dynamics_func(t, states): sol, log_det = odeint( dynamics_func, y_init, - time_grid, + time_grid.to(x_1.device), method=method, options=ode_opts, atol=atol, From 4404acb9e640c098534cd1d227494c67e1e7ebb2 Mon Sep 17 00:00:00 2001 From: timonpalm Date: Fri, 8 Aug 2025 16:36:32 +0200 Subject: [PATCH 2/2] compute likelihood in sampling --- flow_matching/solver/ode_solver.py | 54 +++++++++++++++++--- tests/solver/test_ode_solver.py | 81 ++++++++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 6 deletions(-) diff --git a/flow_matching/solver/ode_solver.py b/flow_matching/solver/ode_solver.py index c924c053..9e6df8cd 100644 --- a/flow_matching/solver/ode_solver.py +++ b/flow_matching/solver/ode_solver.py @@ -37,6 +37,8 @@ def sample( time_grid: Tensor = torch.tensor([0.0, 1.0]), return_intermediates: bool = False, enable_grad: bool = False, + log_p0: Optional[Callable[[Tensor], Tensor]] = None, + exact_divergence: bool = False, **model_extras, ) -> Union[Tensor, Sequence[Tensor]]: r"""Solve the ODE with the velocity field. @@ -73,6 +75,8 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: time_grid (Tensor): The process is solved in the interval [min(time_grid, max(time_grid)] and if step_size is None then time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]). return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. 
Defaults to False. enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False. + log_p0 (Optional[Callable[[Tensor], Tensor]]): Log density of the source distribution at :math:`t=0`, evaluated on x_init. If provided, the log likelihood of the generated samples is computed along with the ODE solution and returned together with it as a tuple; the velocity model must then be differentiable with respect to x. Defaults to None. + exact_divergence (bool): Whether to compute the exact divergence or use the Hutchinson estimator. Only affects the result when log_p0 is provided. Defaults to False. **model_extras: Additional input for the model. Returns: @@ -81,16 +85,49 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: time_grid = time_grid.to(x_init.device) + # Fix the random projection for the Hutchinson divergence estimator + if not exact_divergence: + z = (torch.randn_like(x_init).to(x_init.device) < 0) * 2.0 - 1.0 + def ode_func(t, x): return self.velocity_model(x=x, t=t, **model_extras) + def dynamics_func(t, states): + xt = states[0] + with torch.set_grad_enabled(True): + xt.requires_grad_() + ut = ode_func(t, xt) + + # Compute exact divergence + if exact_divergence: + div = 0 + for i in range(ut.flatten(1).shape[1]): + div += gradient(ut[:, i], xt, create_graph=True)[:, i].detach() + else: + # Compute Hutchinson divergence estimator E[z^T D_x(ut) z] + ut_dot_z = torch.einsum( + "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1) + ) + grad_ut_dot_z = gradient(ut_dot_z, xt) + div = torch.einsum( + "ij,ij->i", + grad_ut_dot_z.flatten(start_dim=1), + z.flatten(start_dim=1), + ) + + return ut.detach(), div.detach() + ode_opts = {"step_size": step_size} if step_size is not None else {} with torch.set_grad_enabled(enable_grad): # Approximate ODE solution with numerical ODE solver sol = odeint( - ode_func, - x_init, + ode_func if log_p0 is None else dynamics_func, + ( + x_init + if log_p0 is None + else (x_init, torch.zeros(x_init.shape[0], device=x_init.device)) + ), time_grid, method=method, options=ode_opts, @@ -98,10 +135,15 @@ def ode_func(t, x): rtol=rtol, ) - if return_intermediates: - return sol - 
else: - return sol[-1] + if log_p0 is not None: + sol, log_det = sol + log_likelihood = log_p0(x_init) - log_det[-1] + return ( + (sol, log_likelihood) + if return_intermediates + else (sol[-1], log_likelihood) + ) + return sol if return_intermediates else sol[-1] def compute_likelihood( self, diff --git a/tests/solver/test_ode_solver.py b/tests/solver/test_ode_solver.py index 85259fdc..a770dc64 100644 --- a/tests/solver/test_ode_solver.py +++ b/tests/solver/test_ode_solver.py @@ -185,6 +185,87 @@ def dummy_log_p(x: Tensor) -> Tensor: torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2), ) + def test_sample_with_likelihoods(self): + x_0 = torch.tensor([[1.0, 0.0]]) + step_size = 0.001 + + # Define a dummy log probability function + def dummy_log_p(x: Tensor) -> Tensor: + return -0.5 * torch.sum(x**2, dim=1) + + result, log_likelihood = self.dummy_solver.sample( + x_init=x_0, + step_size=step_size, + log_p0=dummy_log_p, + exact_divergence=True, + ) + self.assertIsInstance(log_likelihood, Tensor) + self.assertEqual(x_0.shape[0], log_likelihood.shape[0]) + + self.assertTrue( + torch.allclose(torch.tensor([2.0, 1.0]), result, atol=1e-2), + ) + + def test_forward_backward_likelihoods_exact(self): + x_0 = torch.tensor([[1.0, 0.0]]) + step_size = 0.001 + + # Define a dummy log probability function + def dummy_log_p(x: Tensor) -> Tensor: + return -0.5 * torch.sum(x**2, dim=1) + + x1, forward_log_likelihood = self.dummy_solver.sample( + x_init=x_0, + step_size=step_size, + log_p0=dummy_log_p, + exact_divergence=True, + ) + + self.assertIsInstance(forward_log_likelihood, Tensor) + self.assertEqual(x_0.shape[0], forward_log_likelihood.shape[0]) + + # Check if the post-hoc likelihoods match the original log likelihoods + _, backward_log_likelihood = self.dummy_solver.compute_likelihood( + x_1=x1, + log_p0=dummy_log_p, + step_size=step_size, + exact_divergence=True, + ) + + self.assertTrue( + torch.allclose(forward_log_likelihood, backward_log_likelihood, atol=1e-2), + 
) + + def test_forward_backward_likelihoods(self): + x_0 = torch.tensor([[1.0, 0.0]]) + step_size = 0.001 + + # Define a dummy log probability function + def dummy_log_p(x: Tensor) -> Tensor: + return -0.5 * torch.sum(x**2, dim=1) + + x1, forward_log_likelihood = self.dummy_solver.sample( + x_init=x_0, + step_size=step_size, + log_p0=dummy_log_p, + exact_divergence=False, + ) + + self.assertIsInstance(forward_log_likelihood, Tensor) + self.assertEqual(x_0.shape[0], forward_log_likelihood.shape[0]) + + # Check if the post-hoc likelihoods match the original log likelihoods + _, backward_log_likelihood = self.dummy_solver.compute_likelihood( + x_1=x1, + log_p0=dummy_log_p, + step_size=step_size, + exact_divergence=False, + ) + + self.assertTrue( + torch.allclose(forward_log_likelihood, backward_log_likelihood, atol=1e-2), + ) + if __name__ == "__main__": unittest.main()