diff --git a/tests/test_graph_matcher.py b/tests/test_graph_matcher.py index e59c014..783b34d 100644 --- a/tests/test_graph_matcher.py +++ b/tests/test_graph_matcher.py @@ -35,7 +35,8 @@ def check_equation_match(self, eqn1, vars_to_vars, vars_to_eqn): """Checks that equation is matched in the other graph.""" eqn1_out_vars = [v for v in eqn1.outvars - if not isinstance(v, jax.core.DropVar)] + if not isinstance(v, jax.core.DropVar) and + v in vars_to_vars] eqn2_out_vars = [vars_to_vars[v] for v in eqn1_out_vars] eqns = [vars_to_eqn[v] for v in eqn2_out_vars] self.assertTrue(all(e == eqns[0] for e in eqns[1:]))