Commit bc000c6

james-martens authored and KfacJaxDev committed
Adding NaN/Inf guard on call to matrix inverses/solves since LU decomp on GPU can cause an infinite loop when the matrix has these values.
PiperOrigin-RevId: 701556748
1 parent face046 commit bc000c6

1 file changed, +14 -3 lines changed

kfac_jax/_src/utils/math.py

Lines changed: 14 additions & 3 deletions
@@ -374,7 +374,13 @@ def psd_inv(matrix: Array) -> Array:
     identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)
     return linalg.solve(matrix, identity, assume_a="pos")
   else:
-    return linalg.inv(matrix)
+    # Cuda's LU solver will go into an infinite loop if the matrix has NaNs or
+    # possibly Infs, so we need to check for that before calling it.
+    return lax.cond(
+        jnp.logical_or(jnp.any(jnp.isnan(matrix)), jnp.any(jnp.isinf(matrix))),
+        lambda: jnp.full(matrix.shape, jnp.nan, dtype=matrix.dtype),
+        lambda: linalg.inv(matrix),
+    )
 
 
 def psd_solve(matrix: Array, vector: Array) -> Array:
@@ -385,9 +391,14 @@ def psd_solve(matrix: Array, vector: Array) -> Array:
 
   if get_use_cholesky_inversion():
     return linalg.solve(matrix, vector, assume_a="pos")
-
   else:
-    return linalg.solve(matrix, vector)
+    # Cuda's LU solver will go into an infinite loop if the matrix has NaNs or
+    # possibly Infs, so we need to check for that before calling it.
+    return lax.cond(
+        jnp.logical_or(jnp.any(jnp.isnan(matrix)), jnp.any(jnp.isinf(matrix))),
+        lambda: jnp.full(vector.shape, jnp.nan, dtype=vector.dtype),
+        lambda: linalg.solve(matrix, vector),
+    )
 
 
 def psd_solve_without_last_idx(a: Array, b: Array) -> Array:
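
The guard pattern in the two hunks above can be tried in isolation. What follows is a minimal sketch, not the kfac_jax code itself: the names `guarded_inv` and `guarded_solve` are hypothetical, the Cholesky branch is left out, and `linalg` is assumed to be `jax.scipy.linalg` (suggested by the `assume_a="pos"` argument in the diff, but not shown in it).

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy import linalg  # Assumption: the module the diff calls `linalg`.


def _has_nan_or_inf(matrix):
  # Same predicate as in the diff: true if any entry is NaN or +/-Inf.
  return jnp.logical_or(jnp.any(jnp.isnan(matrix)), jnp.any(jnp.isinf(matrix)))


def guarded_inv(matrix):
  """Hypothetical stand-in for the non-Cholesky branch of psd_inv."""
  return lax.cond(
      _has_nan_or_inf(matrix),
      # Bad input: skip the LU-based inverse and return an all-NaN matrix.
      lambda: jnp.full(matrix.shape, jnp.nan, dtype=matrix.dtype),
      lambda: linalg.inv(matrix),
  )


def guarded_solve(matrix, vector):
  """Hypothetical stand-in for the non-Cholesky branch of psd_solve."""
  return lax.cond(
      _has_nan_or_inf(matrix),
      # Bad input: the result has the shape and dtype of the right-hand side.
      lambda: jnp.full(vector.shape, jnp.nan, dtype=vector.dtype),
      lambda: linalg.solve(matrix, vector),
  )


good = jnp.array([[2.0, 0.0], [0.0, 4.0]])
bad_nan = good.at[0, 0].set(jnp.nan)
bad_inf = good.at[1, 1].set(jnp.inf)
rhs = jnp.array([1.0, 1.0])

# The guard also works under jit, since lax.cond accepts a traced predicate.
inv_good = jax.jit(guarded_inv)(good)
inv_bad = jax.jit(guarded_inv)(bad_nan)
x_good = jax.jit(guarded_solve)(good, rhs)
x_bad = jax.jit(guarded_solve)(bad_inf, rhs)
```

Both branches of `lax.cond` must return arrays of the same shape and dtype, which is why the NaN branch is built with `jnp.full(..., dtype=...)` rather than a bare scalar. For floating-point inputs the predicate is equivalent to `jnp.logical_not(jnp.all(jnp.isfinite(matrix)))`. As in the diff, only `matrix` is checked; non-finite values in `vector` are passed through to the solver.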
