@@ -114,6 +114,10 @@ def _is_none(x: Any) -> bool:
114114 return x is None
115115
116116
117+ class TermAndSolverIncompatible (ValueError ):
118+ pass
119+
120+
117121def _assert_term_compatible (
118122 t : FloatScalarLike ,
119123 y : PyTree [ArrayLike ],
@@ -137,7 +141,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
137141 ):
138142 _assert_term_compatible (t , yi , args , term , arg , term_contr_kwarg )
139143 else :
140- raise ValueError (
144+ raise TermAndSolverIncompatible (
141145 f"Term { term } is not a MultiTerm but is expected to be."
142146 )
143147 else :
@@ -147,7 +151,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
147151 if origin_cls is None :
148152 origin_cls = term_cls
149153 if not isinstance (term , origin_cls ):
150- raise ValueError (f"Term { term } is not an instance of { origin_cls } ." )
154+ raise TermAndSolverIncompatible (
155+ f"Term { term } is not an instance of { origin_cls } ."
156+ )
151157
152158 # Now check the generic parametrization of `term_cls`; can be one of:
153159 # -----------------------------------------
@@ -167,7 +173,9 @@ def _check(term_cls, term, term_contr_kwargs, yi):
167173 better_isinstance , vf_type , vf_type_expected
168174 )
169175 if not vf_type_compatible :
170- raise ValueError (f"Vector field term { term } is incompatible." )
176+ raise TermAndSolverIncompatible (
177+ f"Vector field term { term } is incompatible."
178+ )
171179
172180 contr = ft .partial (term .contr , ** term_contr_kwargs )
173181 # Work around https://github.com/google/jax/issues/21825
@@ -176,7 +184,7 @@ def _check(term_cls, term, term_contr_kwargs, yi):
176184 better_isinstance , control_type , control_type_expected
177185 )
178186 if not control_type_compatible :
179- raise ValueError (
187+ raise TermAndSolverIncompatible (
180188 "Control term is incompatible: the returned control (e.g. "
181189 f"Brownian motion for an SDE) was { control_type } , but this "
182190 f"solver expected { control_type_expected } ."
@@ -185,14 +193,23 @@ def _check(term_cls, term, term_contr_kwargs, yi):
185193 assert False , "Malformed term structure"
186194 # If we've got to this point then the term is compatible
187195
196+ try : # check for JAX pytree mismatches first
197+ jtu .tree_map (lambda * a : None , term_structure , terms , contr_kwargs , y )
198+ except ValueError as e :
199+ pretty_term = wl .pformat (terms )
200+ pretty_expected = wl .pformat (term_structure )
201+ raise TermAndSolverIncompatible (
202+ f"Terms are not compatible with solver! Got:\n { pretty_term } \n but expected:"
203+ f"\n { pretty_expected } \n Note that terms are checked recursively: if you "
204+ "scroll up you may find a root-cause error that is more specific."
205+ ) from e
188206 try :
189207 with jax .numpy_dtype_promotion ("standard" ):
190208 jtu .tree_map (_check , term_structure , terms , contr_kwargs , y )
191- except ValueError as e :
192- # ValueError may also arise from mismatched tree structures
209+ except TermAndSolverIncompatible as e :
193210 pretty_term = wl .pformat (terms )
194211 pretty_expected = wl .pformat (term_structure )
195- raise ValueError (
212+ raise TermAndSolverIncompatible (
196213 f"Terms are not compatible with solver! Got:\n { pretty_term } \n but expected:"
197214 f"\n { pretty_expected } \n Note that terms are checked recursively: if you "
198215 "scroll up you may find a root-cause error that is more specific."
0 commit comments