Add alternative to batching #200

@MaxBlesch

The package so far uses a custom batching algorithm. The steps of the algorithm are displayed in the terminal when the model is set up.

Illustrative example: imagine a 5-period model. The number of state_choices in each period is given by:

Period 0: 2
Period 1: 4
Period 2: 6
Period 3: 7
Period 4: 8

Periods 3 and 4 are solved manually. The other three are solved by backward induction. For fast JAX code, we want equally sized bins to compute at every step. The batch algorithm finds these bins such that the child states of each state are solved in a previously computed bin.

The algorithm would start with a bin size of 6, i.e. two steps. This fails because the second step would solve Periods 1 and 0 at the same time, which is not possible: the states in Period 0 need their child states in Period 1 to be solved in an earlier bin. So the algorithm reduces the size to 5 and tries again, and so on. In this setup the algorithm almost certainly converges at 2.

It is not a priori clear that this is faster than simply solving 6, then 4, and then 2 state_choices, i.e. one period per step. In the past we ran into RAM problems doing this for larger models, because the computational graph has to be completely rolled out and optimized for each of these operations; that's why we have batching. However, we might want to give the user the option to do that.
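To make the trade-off concrete, the following sketch counts the steps and the number of distinct step sizes for both strategies. The helper names are hypothetical and nothing here measures runtime or memory; it only shows that solving period by period needs fewer steps but a differently shaped operation in each one, while fixed bins reuse one shape.

```python
def step_sizes_per_period(counts_backward):
    # One step per period, each with that period's number of state_choices.
    return list(counts_backward)


def step_sizes_fixed_bins(counts_backward, bin_size):
    # Equally sized bins over all state_choices (ceiling division).
    total = sum(counts_backward)
    n_steps = -(-total // bin_size)
    return [bin_size] * n_steps


def summarize(step_sizes):
    # (number of steps, number of distinct shapes among them)
    return len(step_sizes), len(set(step_sizes))


counts = [6, 4, 2]  # Periods 2, 1, 0 from the example
print(summarize(step_sizes_per_period(counts)))     # -> (3, 3)
print(summarize(step_sizes_fixed_bins(counts, 2)))  # -> (6, 1)
```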

An alternative could be to let the user set the bin size to the maximum number of state_choices in a period, here 6, and fill the rest with dummy states. In combination with "min_period_batch_segments", where one tells the batch algorithm to find different equally sized bins for parts of the period, this might work even faster!
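A minimal sketch of the padding idea, assuming state_choices are addressed by consecutive integer indices in solution order and that `-1` marks a dummy state. Both the index layout and the function name `pad_period_batches` are assumptions for illustration, not the package's API.

```python
def pad_period_batches(counts_backward, dummy=-1):
    """Build one batch per period, padded to the largest period size.

    ASSUMPTION: indices run consecutively over periods in solution
    order; `dummy` marks padded slots whose results are discarded.
    """
    bin_size = max(counts_backward)
    batches = []
    start = 0
    for n in counts_backward:
        real = list(range(start, start + n))
        batches.append(real + [dummy] * (bin_size - n))
        start += n
    return bin_size, batches


bin_size, batches = pad_period_batches([6, 4, 2])
n_dummy = sum(b.count(-1) for b in batches)
print(bin_size, len(batches), n_dummy)  # -> 6 3 6
```

In the example this gives three steps of size 6 instead of six steps of size 2, at the cost of six dummy evaluations on top of the twelve real state_choices. Whether that is faster in practice would have to be measured.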
