Commit 6756d7e

Refactor test and change expected counts of Alloc that were due to BlasOpt
1 parent: 1b1dc59

File tree

1 file changed: +29 -27 lines

tests/tensor/test_basic.py

Lines changed: 29 additions & 27 deletions
@@ -738,41 +738,43 @@ def check_allocs_in_fgraph(fgraph, n):
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())

-    def test_alloc_constant_folding(self):
+    @pytest.mark.parametrize(
+        "subtensor_fn, expected_grad_n_alloc",
+        [
+            # IncSubtensor1
+            (lambda x: x[:60], 1),
+            # AdvancedIncSubtensor1
+            (lambda x: x[np.arange(60)], 1),
+            # AdvancedIncSubtensor
+            (lambda x: x[np.arange(50), np.arange(50)], 1),
+        ],
+    )
+    def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
         test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)

         some_vector = vector("some_vector", dtype=self.dtype)
         some_matrix = some_vector.reshape((60, 50))
         variables = self.shared(np.ones((50,), dtype=self.dtype))
-        idx = constant(np.arange(50))
-
-        for alloc_, (subtensor, n_alloc) in zip(
-            self.allocs,
-            [
-                # IncSubtensor1
-                (some_matrix[:60], 2),
-                # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
-                # AdvancedIncSubtensor
-                (some_matrix[idx, idx], 1),
-            ],
-            strict=True,
-        ):
-            derp = pt_sum(dense_dot(subtensor, variables))

-            fobj = pytensor.function([some_vector], derp, mode=self.mode)
-            grad_derp = pytensor.grad(derp, some_vector)
-            fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        subtensor = subtensor_fn(some_matrix)

-            topo_obj = fobj.maker.fgraph.toposort()
-            assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
+        derp = pt_sum(dense_dot(subtensor, variables))
+        fobj = pytensor.function([some_vector], derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
+            == 0
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fobj(test_params)

-            topo_grad = fgrad.maker.fgraph.toposort()
-            assert (
-                sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
-            ), (alloc_, subtensor, n_alloc, topo_grad)
-            fobj(test_params)
-            fgrad(test_params)
+        grad_derp = pytensor.grad(derp, some_vector)
+        fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
+        assert (
+            sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
+            == expected_grad_n_alloc
+        )
+        # TODO: Assert something about the value if we bothered to call it?
+        fgrad(test_params)

     def test_alloc_output(self):
         val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)
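For context, here is a minimal standalone sketch (not part of the commit) of the Alloc-counting pattern the refactored test relies on: build a graph, compile the forward and gradient functions, then count how many Alloc nodes survive graph rewrites in each compiled graph. It mirrors the IncSubtensor1 case from the diff; the variable names and the use of the default compilation mode are illustrative assumptions, and the exact counts depend on which rewrites (e.g. BlasOpt) actually run.

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.basic import Alloc

    # Graph analogous to the IncSubtensor1 case: a sliced reshaped vector
    # dotted with a ones vector, reduced to a scalar.
    x = pt.vector("x")
    m = x.reshape((60, 50))
    out = pt.sum(pt.dot(m[:60], pt.ones((50,))))

    # Compile the forward and gradient functions.
    fobj = pytensor.function([x], out)
    fgrad = pytensor.function([x], pytensor.grad(out, x))

    # Count the Alloc nodes remaining after rewrites in each compiled graph.
    n_alloc_fwd = sum(
        isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes
    )  # the test expects 0 here under its compilation mode
    n_alloc_grad = sum(
        isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes
    )  # and exactly 1 in the gradient graph (plausibly the zeros buffer
       # that the subtensor gradient increments into)

Per the commit message, the previously expected counts of 2 in the gradient graphs were an artifact of BlasOpt; after the refactor each parametrized case expects a single Alloc.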
