Description
I came across this paper, which introduces RevDEQ. Its official implementation is in JAX, so I wrote a simple torchdeq implementation of it and tested it on an image classification task.
Paper: https://arxiv.org/abs/2509.12917
Official JAX implementation: https://github.com/sammccallum/reversible-deq ; https://github.com/sammccallum/revdeq
My implementation adds the following changes to the existing codebase: main...sablin39:revdeq
I added a RevDEQ class inheriting from DEQBase in core.py, plus a new grad function in revdeq_grad.py that implements its forward and backward passes.
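To illustrate why forward and backward are so tightly coupled in RevDEQ, here is a minimal NumPy sketch of a reversible forward pass and its exact backward pass. It assumes the coupled relaxed update as I read it from the paper, y ← (1-β)y + βf(z, x) followed by z ← (1-β)z + βf(y, x), and I have not checked it against the official JAX code. The toy layer f(z, x) = tanh(Wz + x) and the names `f`, `f_vjp`, `rev_forward` and `rev_backward` are hypothetical and are not the functions in revdeq_grad.py. The backward pass never stores the trajectory: it inverts each step to recover the previous state and backpropagates through that step.

```python
import numpy as np

def f(z, x, W):
    # Toy layer f(z, x) = tanh(W z + x); a hypothetical stand-in for the DEQ cell.
    return np.tanh(W @ z + x)

def f_vjp(z, x, W, g):
    # Vector-Jacobian products of f w.r.t. its first argument and w.r.t. x.
    s = g * (1.0 - np.tanh(W @ z + x) ** 2)
    return W.T @ s, s

def rev_forward(x, W, beta, steps):
    # Coupled relaxed iteration; only the final pair (y, z) is kept, not the trajectory.
    y, z = np.zeros_like(x), np.zeros_like(x)
    for _ in range(steps):
        y = (1 - beta) * y + beta * f(z, x, W)
        z = (1 - beta) * z + beta * f(y, x, W)
    return y, z

def rev_backward(y, z, x, W, beta, steps, gy, gz):
    # gy, gz are dL/dy_N and dL/dz_N. Walk the iteration backwards: invert each
    # step to recover the previous state, then backpropagate through that step.
    gx = np.zeros_like(x)
    for _ in range(steps):
        z_prev = (z - beta * f(y, x, W)) / (1 - beta)
        y_prev = (y - beta * f(z_prev, x, W)) / (1 - beta)
        # Backprop through z = (1 - beta) * z_prev + beta * f(y, x).
        d_y, d_x = f_vjp(y, x, W, beta * gz)
        gy, gx, gz_prev = gy + d_y, gx + d_x, (1 - beta) * gz
        # Backprop through y = (1 - beta) * y_prev + beta * f(z_prev, x).
        d_z, d_x = f_vjp(z_prev, x, W, beta * gy)
        gz_prev, gx, gy = gz_prev + d_z, gx + d_x, (1 - beta) * gy
        y, z, gz = y_prev, z_prev, gz_prev
    return gx, (y, z)  # (y, z) should be back at the zero initial state
```

Because every step is undone by dividing by (1 - beta), rounding errors grow during the reverse walk, so in this sketch beta must stay strictly below 1 and the step count moderate. Gradients for W would accumulate the same way as `gx`. In torchdeq this logic would presumably live in a `torch.autograd.Function` whose `forward` saves only the final (y, z).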
To replicate the paper's experiments, I added a revdeq folder to deq-zoo containing implicit ResNet implementations for fixed-scale CIFAR10/CIFAR100 image classification. It can also run DEQSliced or DEQIndexing by changing the "core" key in the config. In my experiments I set f_tol=b_tol=tol and f_max_iter=b_max_iter=steps and obtained similar accuracy. With batch_size=1024 and precision=bf16-mixed, DEQSliced and DEQIndexing occupy ~11 GB of GPU memory, while RevDEQ occupies only ~4.6 GB.
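For reference, the tied settings can be expressed as a small config builder. This is a hypothetical helper, not part of torchdeq or the branch: the key names `core`, `f_tol`, `b_tol`, `f_max_iter`, `b_max_iter` and `beta` come from this post, while the `"revdeq"` core value and the default numbers are assumptions for illustration.

```python
def deq_kwargs(core="sliced", tol=1e-3, steps=8, beta=0.5):
    """Hypothetical helper: build DEQ config keys from a single tol and steps value."""
    kwargs = {
        "core": core,          # "sliced", "indexing", or (assumed) "revdeq"
        "f_tol": tol,          # forward solver tolerance
        "b_tol": tol,          # backward tolerance tied to the forward one
        "f_max_iter": steps,   # forward iteration budget
        "b_max_iter": steps,   # backward budget tied to the forward one
    }
    if core == "revdeq":
        kwargs["beta"] = beta  # relaxation parameter, only meaningful for RevDEQ
    return kwargs
```

Tying b_max_iter to f_max_iter is presumably the natural choice for RevDEQ, since an exact reversible backward pass undoes the forward steps one by one rather than solving a separate backward fixed point.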
I haven't finished the language modeling experiment yet due to insufficient GPU resources :( .
Open Questions
- RevDEQ uses a different forward and backward pass than DEQSliced and DEQIndexing, so I implemented both passes together in a single grad function. I'm not sure whether I should decouple them and follow your pattern in DEQSliced and DEQIndexing.
- I reused the parameter keys in DEQConfig for compatibility, but RevDEQ needs an extra relaxation parameter, beta (I originally wanted to obtain it by reusing the tau key). I'm not sure whether this is appropriate.
- The docstring obviously needs further revision. I've added short descriptions of the required parameters of RevDEQ, but I'm not sure what else to include for a better showcase on Read the Docs.
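On the beta vs. tau question, it may help to put the two updates side by side. The RevDEQ step and its inverse below are my reading of the paper, and the damped form for tau is an assumption about how torchdeq's fixed-point solver uses that key:

```math
\begin{aligned}
&\text{RevDEQ step:} && y_{n+1} = (1-\beta)\,y_n + \beta\, f(z_n, x), \qquad z_{n+1} = (1-\beta)\,z_n + \beta\, f(y_{n+1}, x) \\
&\text{RevDEQ inverse:} && z_n = \frac{z_{n+1} - \beta\, f(y_{n+1}, x)}{1-\beta}, \qquad y_n = \frac{y_{n+1} - \beta\, f(z_n, x)}{1-\beta} \\
&\text{Damped iteration:} && z_{n+1} = (1-\tau)\,z_n + \tau\, f(z_n, x)
\end{aligned}
```

Both share the same convex-combination shape, which is why reusing tau is tempting. The difference is that the inverse divides by 1 - beta, so beta must stay strictly below 1, whereas tau = 1 (no damping) is a valid setting and, if I recall correctly, torchdeq's default. A separate beta key would keep RevDEQ from silently inheriting a tau value that makes the step non-invertible.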