Invalid t shape for data with ndim > 1 in GeodesicProbPath.sample() #73

@LukasSchweizer

Hi, and thanks a lot for your amazing work on this library.

Describe the bug
After upgrading from version 1.0.9 to 1.0.10, calling GeodesicProbPath.sample() raises a RuntimeError when the data has more than one non-batch dimension (i.e., inputs with more than two dimensions including the batch dimension, such as (128, 16, 4)):

```
RuntimeError: einsum(): the number of subscripts in the equation (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given
```
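For context, this message comes from PyTorch's generic einsum shape check. It fires when an operand has more dimensions than the subscripts assigned to it and the equation contains no `...`. Below is a minimal standalone sketch of that check. The equation string `"i,k->ik"` is illustrative only and is not claimed to be the one used inside `flow_matching`.

```python
import torch

# Illustrative equation: operand 0 gets a single subscript ("i"), no ellipsis.
equation = "i,k->ik"
v = torch.randn(4)

# A 1-D first operand matches its single subscript, so this is a plain outer product.
t_1d = torch.rand(16)
out = torch.einsum(equation, t_1d, v)
print(out.shape)  # torch.Size([16, 4])

# A 2-D first operand (e.g. t expanded to shape (16, 1)) has two dimensions
# but only one subscript, so einsum rejects it with a RuntimeError.
t_2d = t_1d.unsqueeze(-1)
error_message = None
try:
    torch.einsum(equation, t_2d, v)
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```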

To Reproduce

```python
from flow_matching.path import GeodesicProbPath, PathSample
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.utils.manifolds import Euclidean
import torch

batch_size = 128
data_dim = (16, 4)

x_0 = torch.randn((batch_size, *data_dim))  # (128, 16, 4)
x_1 = torch.randn((batch_size, *data_dim))  # (128, 16, 4)
t = torch.linspace(0, 1, batch_size)        # (128,)

manifold = Euclidean()
scheduler = CondOTScheduler()
path = GeodesicProbPath(scheduler, manifold)

sample: PathSample = path.sample(x_0=x_0, x_1=x_1, t=t)
# RuntimeError: einsum(): the number of subscripts in the equation
#   (1) does not match the number of dimensions (2) for operand 0 and no ellipsis was given
```

Expected behavior
GeodesicProbPath.sample() should accept data with any number of trailing dimensions (e.g., (batch_size, D1, D2)), not only inputs of shape (batch_size, D).
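As a point of reference, here is a minimal pure-PyTorch sketch of the output one would expect for this particular combination (Euclidean manifold, `CondOTScheduler`). It assumes that the Euclidean geodesic from x_0 to x_1 is the straight line and that the CondOT schedule uses alpha_t = t, so x_t = x_0 + t (x_1 - x_0) and dx_t = x_1 - x_0. The helper `reference_euclidean_condot` is hypothetical and not part of `flow_matching`. The point it illustrates is that a `t` of shape (batch_size,) can be reshaped to (batch_size, 1, ..., 1) so it broadcasts over any number of trailing dimensions.

```python
import torch


def reference_euclidean_condot(x_0, x_1, t):
    """Hypothetical reference, not the library's implementation.

    Assumes a Euclidean manifold and a CondOT schedule with alpha_t = t,
    giving the straight-line path x_t = x_0 + t * (x_1 - x_0).
    x_0, x_1: shape (batch_size, *data_dim); t: shape (batch_size,).
    """
    # Reshape t from (batch_size,) to (batch_size, 1, ..., 1) so it
    # broadcasts over every trailing data dimension, however many there are.
    t_b = t.view(-1, *([1] * (x_1.dim() - 1)))
    x_t = x_0 + t_b * (x_1 - x_0)
    # The time derivative of x_0 + t * (x_1 - x_0) is x_1 - x_0.
    dx_t = x_1 - x_0
    return x_t, dx_t


batch_size = 128
data_dim = (16, 4)
x_0 = torch.randn((batch_size, *data_dim))
x_1 = torch.randn((batch_size, *data_dim))
t = torch.linspace(0, 1, batch_size)

x_t, dx_t = reference_euclidean_condot(x_0, x_1, t)
print(x_t.shape, dx_t.shape)  # torch.Size([128, 16, 4]) torch.Size([128, 16, 4])
```

Once the shape handling is fixed, a regression test could compare these tensors with `sample.x_t` and `sample.dx_t` from the reproduction above, for example via `torch.testing.assert_close`.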

Thank you in advance,
Lukas