improve readability #85

@genji970


In `flow_matching/tests/solver/test_ode_solver.py`, the class at line 23 is:

```python
import torch

from flow_matching.utils import ModelWrapper


class ConstantVelocityModel(ModelWrapper):
    def __init__(self):
        super().__init__(None)
        self.a = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        return x * 0.0 + self.a
```

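To improve readability, I suggest replacing `return x * 0.0 + self.a` with `return torch.ones_like(x) * self.a`. For finite `x`, both return a tensor with the shape of `x` filled with the value of `self.a`, but the second form states that intent directly instead of relying on multiplication by zero.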
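One way to confirm the change keeps the test's behavior is to evaluate both expressions side by side. This is a minimal sketch under stated assumptions: `ConstantVelocitySketch`, `forward_original`, and `forward_proposed` are hypothetical names, and the class subclasses `torch.nn.Module` directly instead of `ModelWrapper` so it runs without `flow_matching` installed.

```python
import torch


class ConstantVelocitySketch(torch.nn.Module):
    """Hypothetical stand-in for ConstantVelocityModel (not using ModelWrapper)."""

    def __init__(self) -> None:
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor(1.0))

    def forward_original(self, x: torch.Tensor) -> torch.Tensor:
        # Current test code: zero out x, then broadcast-add the scalar parameter.
        return x * 0.0 + self.a

    def forward_proposed(self, x: torch.Tensor) -> torch.Tensor:
        # Proposed code: ones shaped like x, scaled by the scalar parameter.
        return torch.ones_like(x) * self.a


model = ConstantVelocitySketch()
x = torch.randn(4, 3)

out_original = model.forward_original(x)
out_proposed = model.forward_proposed(x)

# Same shape and values for finite inputs.
print(out_original.shape == out_proposed.shape)  # True
print(torch.equal(out_original, out_proposed))   # True

# The gradient reaching self.a also matches: every output element depends on a
# with slope 1, so d(sum)/da equals the number of elements (4 * 3 = 12).
(grad_original,) = torch.autograd.grad(out_original.sum(), model.a)
(grad_proposed,) = torch.autograd.grad(out_proposed.sum(), model.a)
print(grad_original.item(), grad_proposed.item())  # 12.0 12.0

# One difference: x * 0.0 propagates NaN/inf from x, while ones_like ignores x's values.
x_bad = torch.tensor([float("nan"), float("inf"), 2.0])
nan_mask_original = torch.isnan(model.forward_original(x_bad)).tolist()
values_proposed = model.forward_proposed(x_bad).detach().tolist()
print(nan_mask_original)  # [True, True, False]
print(values_proposed)    # [1.0, 1.0, 1.0]
```

For finite inputs the two forms agree in shape, values, and the gradient with respect to `self.a`. The only behavioral difference the sketch surfaces is that `x * 0.0` carries NaN or inf entries of `x` into the output, which the `torch.ones_like` form does not.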
