99from numba .cpython .unsafe .tuple import tuple_setitem # noqa: F401
1010
1111from pytensor import config
12- from pytensor .graph .basic import Apply , Constant
12+ from pytensor .graph .basic import Apply , Constant , Variable
1313from pytensor .graph .fg import FunctionGraph
1414from pytensor .graph .type import Type
1515from pytensor .link .numba .cache import compile_numba_function_src , hash_from_pickle_dump
@@ -498,36 +498,46 @@ def numba_funcify_FunctionGraph(
498498):
499499 # Collect cache keys of every Op/Constant in the FunctionGraph
500500 # so we can create a global cache key for the whole FunctionGraph
501+ fgraph_can_be_cached = [True ]
501502 cache_keys = []
502503 toposort = fgraph .toposort ()
503- clients = fgraph .clients
504- toposort_indices = {node : i for i , node in enumerate (toposort )}
505- # Add dummy output clients which are not included of the toposort
506- toposort_indices |= {
507- clients [out ][0 ][0 ]: i
508- for i , out in enumerate (fgraph .outputs , start = len (toposort ))
504+ toposort_coords : dict [Variable , tuple [int , int ]] = {
505+ inp : (0 , i ) for i , inp in enumerate (fgraph .inputs )
506+ }
507+ toposort_coords |= {
508+ out : (i , j )
509+ for i , node in enumerate (toposort , start = 1 )
510+ for j , out in enumerate (node .outputs )
509511 }
510512
511- def op_conversion_and_key_collection (* args , ** kwargs ):
513+ def op_conversion_and_key_collection (op , * args , node , ** kwargs ):
512514 # Convert an Op to a funcified function and store the cache_key
513515
514516 # We also Cache each Op so Numba can do less work next time it sees it
515- func , key = numba_funcify_ensure_cache (* args , ** kwargs )
516- cache_keys .append (key )
517+ func , key = numba_funcify_ensure_cache (op , node = node , * args , ** kwargs )
518+ if key is None :
519+ fgraph_can_be_cached [0 ] = False
520+ else :
521+ # Add graph coordinate information (input edges and node location)
522+ cache_keys .append (
523+ (
524+ tuple (toposort_coords [inp ] for inp in node .inputs ),
525+ key ,
526+ )
527+ )
517528 return func
518529
519530 def type_conversion_and_key_collection (value , variable , ** kwargs ):
520531 # Convert a constant type to a numba compatible one and compute a cache key for it
521532
522- # We need to know where in the graph the constants are used
523- # Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
524533 # FIXME: It doesn't make sense to call type_conversion on non-constants,
525- # but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
534+ # but that's what fgraph_to_python currently does.
535+ # We appease it, but don't consider for caching
526536 if isinstance (variable , Constant ):
527- client_indices = tuple (
528- ( toposort_indices [ node ], inp_idx ) for node , inp_idx in clients [ variable ]
529- )
530- cache_keys . append (( client_indices , cache_key_for_constant ( value )) )
537+ # Store unique key in toposort_coords. It will be included by whichever nodes make use of the constant
538+ constant_cache_key = cache_key_for_constant ( value )
539+ assert constant_cache_key is not None
540+ toposort_coords [ variable ] = ( - 1 , constant_cache_key )
531541 return numba_typify (value , variable = variable , ** kwargs )
532542
533543 py_func = fgraph_to_python (
@@ -537,12 +547,15 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
537547 fgraph_name = fgraph_name ,
538548 ** kwargs ,
539549 )
540- if any ( key is None for key in cache_keys ) :
550+ if not fgraph_can_be_cached [ 0 ] :
541551 # If a single element couldn't be cached, we can't cache the whole FunctionGraph either
542552 fgraph_key = None
543553 else :
554+ # Add graph coordinate information for fgraph outputs
555+ fgraph_output_ancestors = tuple (toposort_coords [out ] for out in fgraph .outputs )
556+
544557 # Compose individual cache_keys into a global key for the FunctionGraph
545558 fgraph_key = sha256 (
546- f"({ type (fgraph )} , { tuple (cache_keys )} , { len (fgraph .inputs )} , { len ( fgraph . outputs ) } )" .encode ()
559+ f"({ type (fgraph )} , { tuple (cache_keys )} , { len (fgraph .inputs )} , { fgraph_output_ancestors } )" .encode ()
547560 ).hexdigest ()
548561 return numba_njit (py_func ), fgraph_key
0 commit comments