Commit 88eb2b4

Refactor test and change expected counts of Alloc that were due to BlasOpt
1 parent 05c9035

File tree

1 file changed: +29 -26 lines changed


tests/tensor/test_basic.py

Lines changed: 29 additions & 26 deletions
@@ -759,40 +759,43 @@ def check_allocs_in_fgraph(fgraph, n):
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())
 
-    def test_alloc_constant_folding(self):
+    @pytest.mark.parametrize(
+        "subtensor_fn, expected_grad_n_alloc",
+        [
+            # IncSubtensor1
+            (lambda x: x[:60], 1),
+            # AdvancedIncSubtensor1
+            (lambda x: x[np.arange(60)], 1),
+            # AdvancedIncSubtensor
+            (lambda x: x[np.arange(50), np.arange(50)], 1),
+        ],
+    )
+    def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
         test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)
 
         some_vector = vector("some_vector", dtype=self.dtype)
         some_matrix = some_vector.reshape((60, 50))
         variables = self.shared(np.ones((50,), dtype=self.dtype))
-        idx = constant(np.arange(50))
 
-        for alloc_, (subtensor, n_alloc) in zip(
-            self.allocs,
-            [
-                # IncSubtensor1
-                (some_matrix[:60], 2),
-                # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
-                # AdvancedIncSubtensor
-                (some_matrix[idx, idx], 1),
-            ],
-        ):
-            derp = pt_sum(dense_dot(subtensor, variables))
+        subtensor = subtensor_fn(some_matrix)
 
-            fobj = pytensor.function([some_vector], derp, mode=self.mode)
-            grad_derp = pytensor.grad(derp, some_vector)
-            fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
-
-            topo_obj = fobj.maker.fgraph.toposort()
-            assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
+        derp = pt_sum(dense_dot(subtensor, variables))
+        fobj = pytensor.function([some_vector], derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
+            == 0
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fobj(test_params)
 
-            topo_grad = fgrad.maker.fgraph.toposort()
-            assert (
-                sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
-            ), (alloc_, subtensor, n_alloc, topo_grad)
-            fobj(test_params)
-            fgrad(test_params)
+        grad_derp = pytensor.grad(derp, some_vector)
+        fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
+            == expected_grad_n_alloc
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fgrad(test_params)
 
     def test_alloc_output(self):
         val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)
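For readers replaying this change outside the test suite, here is a minimal standalone sketch of the pattern the refactored test relies on: compile a PyTensor function, then count Alloc nodes in its optimized graph via fgraph.apply_nodes. The snippet is illustrative and not part of the commit; the graph, names, and expected count are assumptions, and the exact number of Alloc nodes depends on the compilation mode and which rewrites (such as BlasOpt) run.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc

# Gradient of a sliced reduction: the backward pass scatters into a
# zeros buffer, which typically appears as an Alloc node in the graph.
x = pt.vector("x")
cost = (x[:3] ** 2).sum()
grad = pytensor.grad(cost, x)

fgrad = pytensor.function([x], grad)

# Same shape of assertion as the test: count Alloc ops in the compiled graph.
n_alloc = sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
print(n_alloc)  # count varies with the optimization mode; not a guaranteed value

print(fgrad(np.arange(5, dtype=x.dtype)))  # gradient is 2*x on the slice, 0 elsewhere

Counting over fgraph.apply_nodes rather than a toposort, as the commit does, checks the same property without imposing an ordering on the nodes.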
