@@ -374,7 +374,13 @@ def psd_inv(matrix: Array) -> Array:
374374 identity = jnp .eye (matrix .shape [0 ], dtype = matrix .dtype )
375375 return linalg .solve (matrix , identity , assume_a = "pos" )
376376 else :
377- return linalg .inv (matrix )
377+ # Cuda's LU solver will go into an infinite loop if the matrix has NaNs or
378+ # possibly Infs, so we need to check for that before calling it.
379+ return lax .cond (
380+ jnp .logical_or (jnp .any (jnp .isnan (matrix )), jnp .any (jnp .isinf (matrix ))),
381+ lambda : jnp .full (matrix .shape , jnp .nan , dtype = matrix .dtype ),
382+ lambda : linalg .inv (matrix ),
383+ )
378384
379385
380386def psd_solve (matrix : Array , vector : Array ) -> Array :
@@ -385,9 +391,14 @@ def psd_solve(matrix: Array, vector: Array) -> Array:
385391
386392 if get_use_cholesky_inversion ():
387393 return linalg .solve (matrix , vector , assume_a = "pos" )
388-
389394 else :
390- return linalg .solve (matrix , vector )
395+ # Cuda's LU solver will go into an infinite loop if the matrix has NaNs or
396+ # possibly Infs, so we need to check for that before calling it.
397+ return lax .cond (
398+ jnp .logical_or (jnp .any (jnp .isnan (matrix )), jnp .any (jnp .isinf (matrix ))),
399+ lambda : jnp .full (vector .shape , jnp .nan , dtype = vector .dtype ),
400+ lambda : linalg .solve (matrix , vector ),
401+ )
391402
392403
393404def psd_solve_without_last_idx (a : Array , b : Array ) -> Array :
0 commit comments