Variable-sized batches #1154

@pmrv

Description

We are trying to solve the following problem in grain. The data for our JAX models requires postprocessing before it can be fed to the models, and this postprocessing produces variable-sized outputs. To avoid recompilation, we therefore batch and pad multiple input items together into batches of a uniform size. The postprocessed data is much larger than the original input items.
Our problem is that the Batch transformation for DataLoader in grain only allows specifying a fixed number of input items per batch, whereas we would rather consume a variable number of input items until we fill a given output batch size with minimal padding. Batch doesn't seem to support this. Is this something that could be done with the Dataset API, or do you see another way of mapping the problem onto the DataLoader?
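
To make it concrete, here is roughly the behavior we are after, written as a standalone Python sketch rather than anything grain provides (the names `pack_to_fixed_size` and `batch_tokens` are illustrative, and it assumes 1-D items that each fit within the budget):

```python
import numpy as np

def pack_to_fixed_size(examples, batch_tokens, pad_value=0):
    """Greedily concatenate variable-length 1-D examples until the next one
    would overflow the fixed `batch_tokens` budget, then pad the remainder.

    Every yielded batch has shape (batch_tokens,), so the model compiles once.
    Assumes each individual example fits within `batch_tokens`.
    """
    buf, seg, used = [], [], 0
    for ex in examples:
        ex = np.asarray(ex)
        if buf and used + len(ex) > batch_tokens:
            yield _finalize(buf, seg, batch_tokens, pad_value)
            buf, seg, used = [], [], 0
        buf.append(ex)
        # Segment ids let the model mask across example boundaries;
        # 0 is reserved for padding.
        seg.append(np.full(len(ex), len(seg) + 1, dtype=np.int32))
        used += len(ex)
    if buf:
        yield _finalize(buf, seg, batch_tokens, pad_value)

def _finalize(buf, seg, batch_tokens, pad_value):
    tokens = np.concatenate(buf)
    pad = batch_tokens - len(tokens)
    return (
        np.pad(tokens, (0, pad), constant_values=pad_value),
        np.pad(np.concatenate(seg), (0, pad), constant_values=0),
    )
```

Hooking a generator like this into the record iterator that either the DataLoader or the Dataset API produces is the part we are unsure about.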

Metadata

Labels: type:support (further information is requested)
