Currently, MPAX uses `isinstance` checks to distinguish between dense and sparse matrices. This could be improved by using `jax.experimental.sparse.sparsify`, which offers a more general and composable way to handle both: a function is written once against `jax.numpy`, and `sparsify` makes it accept sparse inputs as well.
Whether this is practical depends on how mature `sparsify`'s support is (it only handles operations that have sparse rules), but it's worth exploring.
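A minimal sketch of the contrast, assuming the sparse format is `jax.experimental.sparse.BCOO`. The function names (`primal_residual_isinstance`, `primal_residual`) are hypothetical and not taken from MPAX; the first function only illustrates the `isinstance` style of dispatch, and the second shows the same computation written once and wrapped with `sparsify`:

```python
import jax.numpy as jnp
from jax.experimental import sparse


# Hypothetical isinstance-style dispatch (illustrative only, not MPAX's code):
# branch explicitly on whether the matrix is a BCOO sparse array.
def primal_residual_isinstance(A, x, b):
    if isinstance(A, sparse.BCOO):
        return A @ x - b  # BCOO implements @ itself
    return jnp.dot(A, x) - b


# The same computation written once against plain jax.numpy.
def primal_residual(A, x, b):
    return jnp.dot(A, x) - b


# sparsify reinterprets the traced jax.numpy operations so the wrapped
# function accepts BCOO inputs; dense inputs are evaluated as usual.
primal_residual_any = sparse.sparsify(primal_residual)

A_dense = jnp.array([[1.0, 0.0, 2.0],
                     [0.0, 3.0, 0.0]])
A_sparse = sparse.BCOO.fromdense(A_dense)
x = jnp.array([1.0, 1.0, 1.0])
b = jnp.array([1.0, 1.0])

r_dense = primal_residual_any(A_dense, x, b)    # dense path
r_sparse = primal_residual_any(A_sparse, x, b)  # sparse path, same code
```

Both calls go through one code path, so adding another matrix format would not require another `isinstance` branch. The open question is coverage: any operation inside the wrapped function without a sparse rule will fail on sparse inputs.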
Related JAX issue: jax-ml/jax#28749