
Question on Riemannian probability path with geodesic interpolation #75

@johannespitz


Based on my understanding, assuming a linear time schedule, the conditional target dx_t = u_t(x | x_1) should simply be manifold.logmap(x_0, x_1).
However, this repository implements it as follows:

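        def cond_u(x_0, x_1, t):
            # geodesic on self.manifold from x_0 (at parameter 0) to x_1 (at parameter 1)
            path = geodesic(self.manifold, x_0, x_1)
            # forward-mode derivative of t -> path(alpha_t) along the tangent 1:
            # returns the point x_t and its time derivative dx_t
            x_t, dx_t = jvp(
                lambda t: path(self.scheduler(t).alpha_t),
                (t,),
                (torch.ones_like(t).to(t),),
            )
            return x_t, dx_t

        # apply cond_u independently to every (x_0, x_1, t) sample in the batch
        x_t, dx_t = vmap(cond_u)(x_0, x_1, t)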
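To see what the `jvp` call returns, here is a minimal sketch on the unit sphere S² in R³, assuming a linear schedule (alpha_t = t). The functions `sphere_expmap`, `sphere_logmap`, and `geodesic` below are hypothetical stand-ins written for this illustration, not the repository's API; only `jvp` and `vmap` from `torch.func` are real library calls. The sketch compares the jvp output with the constant `logmap(x_0, x_1)` target.

```python
import torch
from torch.func import jvp, vmap


def sphere_expmap(x, u):
    # Exponential map on the unit sphere: walk along the great circle
    # starting at x with initial velocity u (u tangent at x).
    norm_u = u.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return torch.cos(norm_u) * x + torch.sin(norm_u) * u / norm_u


def sphere_logmap(x, y):
    # Logarithm map on the unit sphere: the tangent vector at x pointing
    # towards y, with length equal to the great-circle distance.
    cos_theta = (x * y).sum(dim=-1, keepdim=True).clamp(-1.0, 1.0)
    theta = torch.arccos(cos_theta)
    direction = y - cos_theta * x  # component of y orthogonal to x
    return theta * direction / direction.norm(dim=-1, keepdim=True).clamp_min(1e-12)


def geodesic(x_0, x_1):
    # Hypothetical stand-in for the repository's geodesic helper:
    # gamma(s) = exp_{x_0}(s * log_{x_0}(x_1)), so gamma(0) = x_0, gamma(1) = x_1.
    v = sphere_logmap(x_0, x_1)
    return lambda s: sphere_expmap(x_0, s * v)


def cond_u(x_0, x_1, t):
    # Linear schedule alpha_t = t. jvp pushes the tangent dt = 1 through
    # t -> gamma(t) and returns (gamma(t), d/dt gamma(t)) in one forward pass.
    path = geodesic(x_0, x_1)
    return jvp(path, (t,), (torch.ones_like(t),))


x_0 = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)
x_1 = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float64)
t = torch.tensor(0.5, dtype=torch.float64)

x_t, dx_t = cond_u(x_0, x_1, t)
naive = sphere_logmap(x_0, x_1)  # the constant logmap target from the question

print(dx_t)   # approx. (-1.1107, 1.1107, 0): tangent to the sphere at x_t
print(naive)  # approx. (0, 1.5708, 0): tangent to the sphere at x_0

# Batched version, mirroring vmap(cond_u) in the repository snippet.
xs_t, dxs_t = vmap(cond_u)(
    torch.stack([x_0, x_0]),
    torch.stack([x_1, x_1]),
    torch.tensor([0.25, 0.5], dtype=torch.float64),
)
```

At t = 0.5 both vectors have length pi/2, but dx_t is orthogonal to x_t (a tangent vector at x_t), while logmap(x_0, x_1) is a tangent vector at x_0, so the two targets differ whenever t > 0 on a curved manifold.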

I checked, and the two targets give different results; your (correct) target works better. Interestingly, manifold.logmap(x_0, x_1) also works in my basic test setup, just not quite as well.

Could someone help me understand why computing dx_t is more involved, and what exactly the jvp computes?
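For reference, a sketch of the quantity the `jvp` differentiates, assuming `geodesic(self.manifold, x_0, x_1)` returns the curve $\gamma(s) = \exp_{x_0}(s \log_{x_0}(x_1))$ and that $x_1$ lies within the injectivity radius of $x_0$ (both are assumptions about the helper, not taken from its documentation):

```math
\begin{aligned}
x_t &= \gamma(\alpha_t), \qquad \gamma(s) = \exp_{x_0}\big(s\,\log_{x_0}(x_1)\big),\\
\dot{x}_t &= \dot{\alpha}_t\,\gamma'(\alpha_t), \qquad \gamma'(s) = \frac{\log_{\gamma(s)}(x_1)}{1-s} \in T_{\gamma(s)}\mathcal{M} \quad (0 \le s < 1).
\end{aligned}
```

Passing the tangent `torch.ones_like(t)` makes `jvp` return $\dot{x}_t$ with the chain-rule factor $\dot{\alpha}_t$ already included. Under a linear schedule ($\alpha_t = t$) this reduces to $\gamma'(t)$, which equals $\log_{x_0}(x_1)$ at $t = 0$; for $t > 0$ it is the parallel transport of that vector to $T_{x_t}\mathcal{M}$. Only in Euclidean space do the two coincide for every $t$ (both equal $x_1 - x_0$).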
