Skip to content

Commit 193c3d7

Browse files
committed
Fix numba FunctionGraph cache key
It's necessary to encode the edge information, not only the nodes and their ordering
1 parent c1b2011 commit 193c3d7

File tree

3 files changed

+122
-31
lines changed

3 files changed

+122
-31
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
1010

1111
from pytensor import config
12-
from pytensor.graph.basic import Apply, Constant
12+
from pytensor.graph.basic import Apply, Constant, Variable
1313
from pytensor.graph.fg import FunctionGraph
1414
from pytensor.graph.type import Type
1515
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
@@ -498,36 +498,46 @@ def numba_funcify_FunctionGraph(
498498
):
499499
# Collect cache keys of every Op/Constant in the FunctionGraph
500500
# so we can create a global cache key for the whole FunctionGraph
501+
fgraph_can_be_cached = [True]
501502
cache_keys = []
502503
toposort = fgraph.toposort()
503-
clients = fgraph.clients
504-
toposort_indices = {node: i for i, node in enumerate(toposort)}
505-
# Add dummy output clients which are not included of the toposort
506-
toposort_indices |= {
507-
clients[out][0][0]: i
508-
for i, out in enumerate(fgraph.outputs, start=len(toposort))
504+
toposort_coords: dict[Variable, tuple[int, int]] = {
505+
inp: (0, i) for i, inp in enumerate(fgraph.inputs)
506+
}
507+
toposort_coords |= {
508+
out: (i, j)
509+
for i, node in enumerate(toposort, start=1)
510+
for j, out in enumerate(node.outputs)
509511
}
510512

511-
def op_conversion_and_key_collection(*args, **kwargs):
513+
def op_conversion_and_key_collection(op, *args, node, **kwargs):
512514
# Convert an Op to a funcified function and store the cache_key
513515

514516
# We also Cache each Op so Numba can do less work next time it sees it
515-
func, key = numba_funcify_ensure_cache(*args, **kwargs)
516-
cache_keys.append(key)
517+
func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs)
518+
if key is None:
519+
fgraph_can_be_cached[0] = False
520+
else:
521+
# Add graph coordinate information (input edges and node location)
522+
cache_keys.append(
523+
(
524+
tuple(toposort_coords[inp] for inp in node.inputs),
525+
key,
526+
)
527+
)
517528
return func
518529

519530
def type_conversion_and_key_collection(value, variable, **kwargs):
520531
# Convert a constant type to a numba compatible one and compute a cache key for it
521532

522-
# We need to know where in the graph the constants are used
523-
# Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
524533
# FIXME: It doesn't make sense to call type_conversion on non-constants,
525-
# but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
534+
# but that's what fgraph_to_python currently does.
535+
# We appease it, but don't consider for caching
526536
if isinstance(variable, Constant):
527-
client_indices = tuple(
528-
(toposort_indices[node], inp_idx) for node, inp_idx in clients[variable]
529-
)
530-
cache_keys.append((client_indices, cache_key_for_constant(value)))
537+
# Store unique key in toposort_coords. It will be included by whichever nodes make use of the constant
538+
constant_cache_key = cache_key_for_constant(value)
539+
assert constant_cache_key is not None
540+
toposort_coords[variable] = (-1, constant_cache_key)
531541
return numba_typify(value, variable=variable, **kwargs)
532542

533543
py_func = fgraph_to_python(
@@ -537,12 +547,15 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
537547
fgraph_name=fgraph_name,
538548
**kwargs,
539549
)
540-
if any(key is None for key in cache_keys):
550+
if not fgraph_can_be_cached[0]:
541551
# If a single element couldn't be cached, we can't cache the whole FunctionGraph either
542552
fgraph_key = None
543553
else:
554+
# Add graph coordinate information for fgraph outputs
555+
fgraph_output_ancestors = tuple(toposort_coords[out] for out in fgraph.outputs)
556+
544557
# Compose individual cache_keys into a global key for the FunctionGraph
545558
fgraph_key = sha256(
546-
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode()
559+
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {fgraph_output_ancestors})".encode()
547560
).hexdigest()
548561
return numba_njit(py_func), fgraph_key

pytensor/link/utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -735,14 +735,6 @@ def fgraph_to_python(
735735

736736
body_assigns = []
737737
for node in order:
738-
compiled_func = op_conversion_fn(
739-
node.op, node=node, storage_map=storage_map, **kwargs
740-
)
741-
742-
# Create a local alias with a unique name
743-
local_compiled_func_name = unique_name(compiled_func)
744-
global_env[local_compiled_func_name] = compiled_func
745-
746738
node_input_names = []
747739
for inp in node.inputs:
748740
local_input_name = unique_name(inp)
@@ -772,6 +764,13 @@ def fgraph_to_python(
772764

773765
node_output_names = [unique_name(v) for v in node.outputs]
774766

767+
compiled_func = op_conversion_fn(
768+
node.op, node=node, storage_map=storage_map, **kwargs
769+
)
770+
# Create a local alias with a unique name
771+
local_compiled_func_name = unique_name(compiled_func)
772+
global_env[local_compiled_func_name] = compiled_func
773+
775774
assign_str = f"{', '.join(node_output_names)} = {local_compiled_func_name}({', '.join(node_input_names)})"
776775
assign_comment_str = f"{indent(str(node), '# ')}"
777776
assign_block_str = f"{assign_comment_str}\n{assign_str}"

tests/link/numba/test_basic.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,31 @@
77
import pytest
88
import scipy
99

10-
from pytensor.compile import SymbolicInput
11-
from pytensor.tensor.utils import hash_from_ndarray
10+
from pytensor.tensor import scalar_from_tensor
1211

1312

1413
numba = pytest.importorskip("numba")
1514

1615
import pytensor.scalar as ps
1716
import pytensor.tensor as pt
1817
from pytensor import config, shared
18+
from pytensor.compile import SymbolicInput
1919
from pytensor.compile.function import function
2020
from pytensor.compile.mode import Mode
2121
from pytensor.graph.basic import Apply, Variable
22+
from pytensor.graph.fg import FunctionGraph
2223
from pytensor.graph.op import Op
2324
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
2425
from pytensor.graph.type import Type
2526
from pytensor.link.numba.dispatch import basic as numba_basic
26-
from pytensor.link.numba.dispatch.basic import cache_key_for_constant
27+
from pytensor.link.numba.dispatch.basic import (
28+
cache_key_for_constant,
29+
numba_funcify_and_cache_key,
30+
)
2731
from pytensor.link.numba.linker import NumbaLinker
28-
from pytensor.scalar.basic import ScalarOp, as_scalar
32+
from pytensor.scalar.basic import Composite, ScalarOp, as_scalar
2933
from pytensor.tensor.elemwise import Elemwise
34+
from pytensor.tensor.utils import hash_from_ndarray
3035

3136

3237
if TYPE_CHECKING:
@@ -652,3 +657,77 @@ def impl(x):
652657
outs[2].owner.op, outs[2].owner
653658
)
654659
assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
660+
661+
662+
class TestFgraphCacheKey:
    """Checks that the numba FunctionGraph cache key encodes graph structure.

    The key must distinguish graphs that differ in edge wiring, input/output
    ordering, output selection of multi-output nodes, and constant values —
    not merely in the set of Ops present.
    """

    @staticmethod
    def generate_and_validate_key(fg):
        """Return the cache key for ``fg``, asserting it exists and is stable."""
        _, key = numba_funcify_and_cache_key(fg)
        assert key is not None
        _, repeat_key = numba_funcify_and_cache_key(fg)
        # The key must be reproducible across repeated funcifications
        assert key == repeat_key
        return key

    def test_node_order(self):
        x = pt.scalar("x")
        log_x = pt.log(x)
        graphs = [
            pt.exp(x) / log_x,
            log_x / pt.exp(x),
            pt.exp(log_x) / x,
            x / pt.exp(log_x),
            pt.exp(log_x) / log_x,
            log_x / pt.exp(log_x),
        ]

        keys = [
            self.generate_and_validate_key(FunctionGraph([x], [g], clone=False))
            for g in graphs
        ]
        # Every structurally distinct graph must hash to a distinct key
        assert len(set(keys)) == len(graphs)

        # An extra (even unused) input changes the function signature,
        # so it must also change the key
        y = pt.scalar("y")
        keys += [
            self.generate_and_validate_key(
                FunctionGraph(ins, [graphs[0]], clone=False)
            )
            for ins in ([x, y], [y, x])
        ]
        assert len(set(keys)) == len(graphs) + 2

        # Echoing an input among the outputs (in any position/multiplicity)
        # must likewise change the key
        output_variants = (
            [graphs[0], x],
            [x, graphs[0]],
            [x, x, graphs[0]],
            [x, graphs[0], x],
            [graphs[0], x, x],
        )
        keys += [
            self.generate_and_validate_key(FunctionGraph([x], outs, clone=False))
            for outs in output_variants
        ]
        assert len(set(keys)) == len(graphs) + 2 + 5

    def test_multi_output(self):
        x = pt.scalar("x")
        xs = scalar_from_tensor(x)
        out0, out1 = Elemwise(Composite([xs], [xs * 2, xs - 2]))(x)

        # Any selection or ordering of the two outputs should yield
        # a distinct cache key
        output_choices = [
            [out0],
            [out1],
            [out0, out1],
            [out1, out0],
        ]
        keys = [
            self.generate_and_validate_key(FunctionGraph([x], sel, clone=False))
            for sel in output_choices
        ]
        assert len(set(keys)) == len(output_choices)

    def test_constant_output(self):
        # Graphs returning different constants must not collide
        fg_pi = FunctionGraph([], [pt.constant(np.pi)])
        fg_e = FunctionGraph([], [pt.constant(np.e)])

        key_pi = self.generate_and_validate_key(fg_pi)
        key_e = self.generate_and_validate_key(fg_e)
        assert key_pi != key_e

0 commit comments

Comments
 (0)