Description
We are trying to solve the following problem in Grain. The data for our JAX models requires postprocessing before it can be fed to the models, and this postprocessing produces variable-sized data. Because of this, we batch and pad multiple input items together to produce batches of a fixed size and avoid recompilation. The postprocessed data is much larger than the original input items.
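For concreteness, the padding step looks roughly like the sketch below. This is illustrative only; `max_len` is a placeholder for whatever fixed target length we pad to, and we assume every item already fits within it:

```python
import numpy as np


def pad_and_stack(arrays: list[np.ndarray], max_len: int) -> np.ndarray:
    """Pads each 1-D array with zeros up to `max_len` and stacks them.

    Keeping the stacked shape constant across steps avoids XLA
    recompilation. Assumes len(a) <= max_len for every array.
    """
    padded = [np.pad(a, (0, max_len - len(a))) for a in arrays]
    return np.stack(padded)
```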
Our problem is that the Batched transformation for DataLoader in Grain only allows specifying a fixed number of input items per batch, whereas we would rather ingest a variable number of input items until we reach a given output batch size with minimal padding. Batched doesn't seem to support this. Is this something that could be done with the Dataset API, or do you see another way of mapping the problem onto the DataLoader?
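To make the desired behavior concrete, here is a minimal sketch of the dynamic batching we are after, written as a plain Python generator rather than against any specific Grain extension point. The `size_fn` parameter and the greedy packing policy are our own illustrative assumptions, not Grain API:

```python
from typing import Any, Callable, Iterator


def dynamic_batches(
    items: Iterator[Any],
    target_size: int,
    size_fn: Callable[[Any], int],
) -> Iterator[list[Any]]:
    """Greedily packs items until adding one more would exceed `target_size`.

    Each yielded batch is a variable-length list of items whose combined
    size is as close to `target_size` as possible, so the subsequent
    padding to the fixed output batch size is minimal.
    """
    batch: list[Any] = []
    used = 0
    for item in items:
        item_size = size_fn(item)
        if batch and used + item_size > target_size:
            yield batch  # Emit the current batch; it still gets padded downstream.
            batch, used = [], 0
        batch.append(item)
        used += item_size
    if batch:
        yield batch  # Final partial batch.
```

Something like this would be easy to run over a raw iterator, but we don't see how to express it through Batched, which consumes a fixed count of input records at a time.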