We are currently stuck on `cuda11.8`. The main reason is that, to future-proof the project, we want this repo to be compatible with stable `torch` and `jax` at the same time. However, `torch`'s CUDA version is pinned to 12.1, while `jax` requires at least 12.2, so no single CUDA 12 release satisfies both when installing from `pip`. It is possible to make the two installations coexist by building from source, but that makes this package cumbersome to use.
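The version conflict described above can be sketched as a simple range intersection. This is an illustrative model, not part of the repo. The helper names are hypothetical. The supported-version rules are assumptions: `torch` wheels exist for CUDA 11.8 and exactly 12.1, and `jax` accepts CUDA 11 builds but needs 12.2 or newer for CUDA 12. Actual wheel availability changes between releases, so check each project's install docs.

```python
def parse(version: str) -> tuple[int, int]:
    """Turn a CUDA version string like '12.1' into a comparable (major, minor) tuple."""
    major, minor = version.split(".")
    return int(major), int(minor)


def torch_supports(version: str) -> bool:
    # Assumption (hypothetical rule): stable torch wheels target CUDA 11.8 or exactly 12.1.
    return version in {"11.8", "12.1"}


def jax_supports(version: str) -> bool:
    # Assumption (hypothetical rule): jax has CUDA 11 builds, but its CUDA 12 builds need >= 12.2.
    v = parse(version)
    return v[0] == 11 or v >= (12, 2)


def shared_cuda_versions(candidates: list[str]) -> list[str]:
    """Return the candidate CUDA versions that both frameworks accept."""
    return [v for v in candidates if torch_supports(v) and jax_supports(v)]


print(shared_cuda_versions(["11.8", "12.1", "12.2", "12.3"]))  # → ['11.8']
```

Under these assumptions, 12.1 fails the `jax` check and 12.2+ fails the `torch` check. That leaves 11.8 as the only common choice when both are installed from `pip`.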