Skip to content

Commit cf25c79

Browse files
Incompatible-with-solver errors no longer masking other underlying ValueErrors
1 parent 6b20ef5 commit cf25c79

2 files changed

Lines changed: 36 additions & 7 deletions

File tree

diffrax/_integrate.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def _is_none(x: Any) -> bool:
114114
return x is None
115115

116116

117+
class TermAndSolverIncompatible(ValueError):
118+
pass
119+
120+
117121
def _assert_term_compatible(
118122
t: FloatScalarLike,
119123
y: PyTree[ArrayLike],
@@ -137,7 +141,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
137141
):
138142
_assert_term_compatible(t, yi, args, term, arg, term_contr_kwarg)
139143
else:
140-
raise ValueError(
144+
raise TermAndSolverIncompatible(
141145
f"Term {term} is not a MultiTerm but is expected to be."
142146
)
143147
else:
@@ -147,7 +151,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
147151
if origin_cls is None:
148152
origin_cls = term_cls
149153
if not isinstance(term, origin_cls):
150-
raise ValueError(f"Term {term} is not an instance of {origin_cls}.")
154+
raise TermAndSolverIncompatible(
155+
f"Term {term} is not an instance of {origin_cls}."
156+
)
151157

152158
# Now check the generic parametrization of `term_cls`; can be one of:
153159
# -----------------------------------------
@@ -167,7 +173,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
167173
better_isinstance, vf_type, vf_type_expected
168174
)
169175
if not vf_type_compatible:
170-
raise ValueError(f"Vector field term {term} is incompatible.")
176+
raise TermAndSolverIncompatible(
177+
f"Vector field term {term} is incompatible."
178+
)
171179

172180
contr = ft.partial(term.contr, **term_contr_kwargs)
173181
# Work around https://github.com/google/jax/issues/21825
@@ -176,7 +184,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
176184
better_isinstance, control_type, control_type_expected
177185
)
178186
if not control_type_compatible:
179-
raise ValueError(
187+
raise TermAndSolverIncompatible(
180188
"Control term is incompatible: the returned control (e.g. "
181189
f"Brownian motion for an SDE) was {control_type}, but this "
182190
f"solver expected {control_type_expected}."
@@ -185,14 +193,23 @@ def _check(term_cls, term, term_contr_kwargs, yi):
185193
assert False, "Malformed term structure"
186194
# If we've got to this point then the term is compatible
187195

196+
try: # check for JAX pytree mismatches first
197+
jtu.tree_map(lambda *a: None, term_structure, terms, contr_kwargs, y)
198+
except ValueError as e:
199+
pretty_term = wl.pformat(terms)
200+
pretty_expected = wl.pformat(term_structure)
201+
raise TermAndSolverIncompatible(
202+
f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
203+
f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
204+
"scroll up you may find a root-cause error that is more specific."
205+
) from e
188206
try:
189207
with jax.numpy_dtype_promotion("standard"):
190208
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
191-
except ValueError as e:
192-
# ValueError may also arise from mismatched tree structures
209+
except TermAndSolverIncompatible as e:
193210
pretty_term = wl.pformat(terms)
194211
pretty_expected = wl.pformat(term_structure)
195-
raise ValueError(
212+
raise TermAndSolverIncompatible(
196213
f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
197214
f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
198215
"scroll up you may find a root-cause error that is more specific."

test/test_integrate.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,15 @@ def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
864864

865865
assert not jnp.isnan(grad).any(), "Gradient should not be NaN."
866866
assert not jnp.isinf(grad).any(), "Gradient should not be infinite."
867+
868+
869+
def test_nice_errors():
870+
def vf1(t, y, args):
871+
raise ValueError("Oh no!")
872+
873+
def vf2(t, y, args):
874+
raise TypeError("Oh no!")
875+
876+
for vf, etype in ((vf1, ValueError), (vf2, TypeError)):
877+
with pytest.raises(etype, match="Oh no!"):
878+
diffrax.diffeqsolve(diffrax.ODETerm(vf), diffrax.Euler(), 0, 1, 0.1, 0)

0 commit comments

Comments
 (0)