Skip to content

Commit aebc3dd

Browse files
Fix pyright across multiple JAX versions
1 parent 4e87b5d commit aebc3dd

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

diffrax/_solver/foster_langevin_srk.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def _recompute_coeffs(
190190
r"""When h changes, the SRK coefficients (which depend on h) are recomputed
191191
using this function."""
192192
# Inner will record the tree structure of the coefficients
193+
inner: Any
193194
inner = sentinel = object()
194195

195196
def recompute_coeffs_leaf(c: UnderdampedLangevinLeaf, _tay_coeffs: _Coeffs):
@@ -239,7 +240,7 @@ def _choose(tay_leaf, direct_leaf):
239240
tree_with_coeffs = jtu.tree_map(recompute_coeffs_leaf, gamma, tay_coeffs)
240241
outer = jtu.tree_structure(gamma)
241242
assert inner is not sentinel, "inner tree structure not set"
242-
coeffs_with_tree = jtu.tree_transpose(outer, inner, tree_with_coeffs) # pyright: ignore
243+
coeffs_with_tree = jtu.tree_transpose(outer, inner, tree_with_coeffs)
243244
return coeffs_with_tree
244245

245246
def init(

0 commit comments

Comments
 (0)