[New Feature] torchdeq implementation for *Reversible Deep Equilibrium Models* #9

@sablin39

I came across this paper, which claims to achieve $O(1)$ memory complexity for both the forward and backward pass because it does not need to store the computation graph for backpropagation, which saves a lot of VRAM during training. Since the official codebase is in JAX, I wrote a simple torchdeq implementation of it and tested it on an image classification task.

Paper: https://arxiv.org/abs/2509.12917
Official JAX implementation: https://github.com/sammccallum/reversible-deq ; https://github.com/sammccallum/revdeq
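
For readers who haven't opened the paper, here is a dependency-free sketch of why the memory can stay constant. It reflects my reading of the coupled update (two states `y` and `z`, relaxed with `beta`), so treat the exact form as an assumption and check it against the paper. The toy layer `f`, the weight `w=0.5`, and the helper names are hypothetical and are not taken from torchdeq or the JAX code.

```python
import math

def f(z, x, w=0.5):
    # Toy contractive layer (hypothetical): elementwise tanh(w*z + x).
    return [math.tanh(w * zi + xi) for zi, xi in zip(z, x)]

def forward_step(y, z, x, beta):
    # Assumed coupled update: y is relaxed using f(z), then z using f(new y).
    fz = f(z, x)
    y_next = [(1 - beta) * yi + beta * fi for yi, fi in zip(y, fz)]
    fy = f(y_next, x)
    z_next = [(1 - beta) * zi + beta * fi for zi, fi in zip(z, fy)]
    return y_next, z_next

def reverse_step(y_next, z_next, x, beta):
    # Algebraic inverse of forward_step: solve for z first, then for y.
    fy = f(y_next, x)
    z = [(zi - beta * fi) / (1 - beta) for zi, fi in zip(z_next, fy)]
    fz = f(z, x)
    y = [(yi - beta * fi) / (1 - beta) for yi, fi in zip(y_next, fz)]
    return y, z

def run_forward(x, beta=0.5, steps=8):
    # Start both states at zero and keep only the latest pair.
    y, z = [0.0] * len(x), [0.0] * len(x)
    for _ in range(steps):
        y, z = forward_step(y, z, x, beta)
    return y, z

def run_reverse(y, z, x, beta=0.5, steps=8):
    # Rebuild earlier iterates from the final pair, with no stored history.
    for _ in range(steps):
        y, z = reverse_step(y, z, x, beta)
    return y, z

x = [0.3, -0.7, 1.2]
y_final, z_final = run_forward(x)
y_start, z_start = run_reverse(y_final, z_final, x)
# Distance from the true zero initial state, up to floating-point round-off.
max_err = max(abs(v) for v in y_start + z_start)
```

Only the final pair `(y, z)` is kept after the forward loop, and every intermediate iterate is recomputed on the way back. Note that each reverse step divides by `1 - beta`, so round-off grows with the number of steps; the sketch uses few steps and double precision for that reason.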

My implementation adds the following to the existing codebase: main...sablin39:revdeq

I added a RevDEQ class inheriting from DEQBase in core.py, and a new grad function covering its forward and backward pass in revdeq_grad.py.

To replicate the experiments in the paper, I added a revdeq folder to deq-zoo, which contains implicit ResNet implementations for testing on fixed-scale CIFAR10/CIFAR100 image classification. It can also be configured to use DEQSliced or DEQIndexing by changing the "core" key in the config. In my experiments I set f_tol=b_tol=tol and f_max_iter=b_max_iter=steps, and got similar accuracy with all three cores. DEQSliced and DEQIndexing occupy ~11GB, while RevDEQ occupies only ~4.6GB, at batch_size=1024 and precision=bf16-mixed.
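
To make the matched settings concrete, this is roughly how the three configurations line up. It is only a sketch with plain dicts: the core names (`"sliced"`, `"indexing"`, `"revdeq"`) and the values of `tol`, `steps`, and `beta` are placeholders I picked for illustration, not the values used in the runs above, and the helper `matched_config` is hypothetical.

```python
def matched_config(core, tol=1e-3, steps=8, beta=0.5):
    # Same tolerance and iteration budget for forward and backward,
    # so the cores are compared under identical solver settings.
    cfg = {
        "core": core,
        "f_tol": tol,
        "b_tol": tol,
        "f_max_iter": steps,
        "b_max_iter": steps,
    }
    if core == "revdeq":
        # Extra relaxation parameter that only the reversible core needs.
        cfg["beta"] = beta
    return cfg

configs = {core: matched_config(core) for core in ("sliced", "indexing", "revdeq")}
```

Switching between cores then only changes the "core" key (plus `beta` for RevDEQ), which keeps the memory comparison like-for-like.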

I haven't finished the language modeling experiment because I don't have enough GPU resources :( .

Open Questions

  • RevDEQ uses a different forward and backward pass from DEQSliced and DEQIndexing, and I chose to implement both passes together in a single grad function. I'm not sure whether I should decouple them and follow your pattern in DEQSliced and DEQIndexing.

  • I reused the parameter keys in DEQConfig for compatibility, but there's an extra relaxation parameter beta (which I once considered obtaining by reusing the tau key). I'm not sure whether this is appropriate.

  • Obviously the docstrings need further revision. I've added simple descriptions for the required parameters of RevDEQ, but I'm not sure what else to include for a better showcase on readthedocs.
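
To illustrate why I ended up coupling the two passes (first question above), here is a scalar, pure-Python sketch of a reversible backward pass. It is not the code in revdeq_grad.py. It assumes the coupled update as I read it from the paper, uses a hypothetical toy layer `f(u) = tanh(w*u + x)`, and takes the loss to be the final `z`. In PyTorch the same logic would naturally sit in the `forward` and `backward` of one `torch.autograd.Function`, since the backward pass has to re-run the layer to rebuild each iterate.

```python
import math

def f(u, x, w):
    # Toy layer (hypothetical): tanh(w*u + x) with scalar weight w.
    return math.tanh(w * u + x)

def forward(x, w, beta, steps):
    # Assumed coupled relaxation; only the final (y, z) is returned.
    y = z = 0.0
    for _ in range(steps):
        y = (1 - beta) * y + beta * f(z, x, w)
        z = (1 - beta) * z + beta * f(y, x, w)
    return y, z

def backward(y, z, x, w, beta, steps, grad_z=1.0):
    # Returns dL/dw for L = z_N, rebuilding earlier states on the fly.
    gy, gz, gw = 0.0, grad_z, 0.0
    for _ in range(steps):
        # Undo z_{n+1} = (1-beta)*z_n + beta*f(y_{n+1}).
        t = f(y, x, w)
        d = 1 - t * t                 # derivative of tanh at its argument
        gw += beta * d * y * gz       # path through w inside f(y_{n+1})
        gy += beta * d * w * gz       # path through y_{n+1}
        z = (z - beta * t) / (1 - beta)
        gz *= (1 - beta)
        # Undo y_{n+1} = (1-beta)*y_n + beta*f(z_n).
        t = f(z, x, w)
        d = 1 - t * t
        gw += beta * d * z * gy       # path through w inside f(z_n)
        gz += beta * d * w * gy       # path through z_n
        y = (y - beta * t) / (1 - beta)
        gy *= (1 - beta)
    return gw

x, w, beta, steps = 0.3, 0.8, 0.5, 6
y_N, z_N = forward(x, w, beta, steps)
grad_w = backward(y_N, z_N, x, w, beta, steps)

# Central finite difference on the forward pass as a sanity check.
eps = 1e-6
fd = (forward(x, w + eps, beta, steps)[1] - forward(x, w - eps, beta, steps)[1]) / (2 * eps)
```

State reconstruction and gradient accumulation are interleaved step by step, which is why splitting them into separate forward and backward components felt unnatural to me. If you'd still prefer the DEQSliced/DEQIndexing pattern, I'm happy to refactor.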
