55
66import logging
77import warnings
8- from typing import Any , Literal
8+ from typing import Literal
99
1010from pytensor .compile .function .types import Supervisor
1111from pytensor .configdefaults import config
@@ -62,17 +62,20 @@ def register_linker(name, linker):
6262 predefined_linkers [name ] = linker
6363
6464
65- OPT_NONE = RewriteDatabaseQuery (include = [])
65+ exclude = []
66+ if not config .cxx :
67+ exclude = ["cxx_only" ]
68+ OPT_NONE = RewriteDatabaseQuery (include = [], exclude = exclude )
6669# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
67- OPT_MINIMUM = RewriteDatabaseQuery (include = ["minimum_compile" ])
70+ OPT_MINIMUM = RewriteDatabaseQuery (include = ["minimum_compile" ], exclude = exclude )
6871# Even if multiple merge optimizer call will be there, this shouldn't
6972# impact performance.
70- OPT_MERGE = RewriteDatabaseQuery (include = ["merge" ])
71- OPT_FAST_RUN = RewriteDatabaseQuery (include = ["fast_run" ])
73+ OPT_MERGE = RewriteDatabaseQuery (include = ["merge" ], exclude = exclude )
74+ OPT_FAST_RUN = RewriteDatabaseQuery (include = ["fast_run" ], exclude = exclude )
7275OPT_FAST_RUN_STABLE = OPT_FAST_RUN .requiring ("stable" )
7376
74- OPT_FAST_COMPILE = RewriteDatabaseQuery (include = ["fast_compile" ])
75- OPT_STABILIZE = RewriteDatabaseQuery (include = ["fast_run" ])
77+ OPT_FAST_COMPILE = RewriteDatabaseQuery (include = ["fast_compile" ], exclude = exclude )
78+ OPT_STABILIZE = RewriteDatabaseQuery (include = ["fast_run" ], exclude = exclude )
7679OPT_STABILIZE .position_cutoff = 1.5000001
7780OPT_NONE .name = "OPT_NONE"
7881OPT_MINIMUM .name = "OPT_MINIMUM"
@@ -310,8 +313,6 @@ def __init__(
310313 ):
311314 if linker is None :
312315 linker = config .linker
313- if isinstance (linker , str ) and linker == "auto" :
314- linker = "cvm" if config .cxx else "vm"
315316 if isinstance (optimizer , str ) and optimizer == "default" :
316317 optimizer = config .optimizer
317318
@@ -447,15 +448,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
447448 return new_mode
448449
449450
450- C = Mode ("c" , "fast_run" )
451- CVM = Mode ("cvm" , "fast_run" )
452- VM = (Mode ("vm" , "fast_run" ),)
453-
454451NUMBA = Mode (
455452 NumbaLinker (),
456453 RewriteDatabaseQuery (include = ["fast_run" , "numba" ]),
457454)
458455
456+ FAST_COMPILE = Mode (
457+ NumbaLinker (),
458+ RewriteDatabaseQuery (include = ["fast_compile" ]),
459+ )
460+ FAST_RUN = NUMBA
461+
462+ C = Mode ("c" , "fast_run" )
463+ CVM = Mode ("cvm" , "fast_run" )
464+
459465JAX = Mode (
460466 JAXLinker (),
461467 RewriteDatabaseQuery (include = ["fast_run" , "jax" ]),
@@ -470,19 +476,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
470476 RewriteDatabaseQuery (include = ["fast_run" ]),
471477)
472478
473- FAST_COMPILE = Mode (
474- VMLinker (use_cloop = False , c_thunks = False ),
475- RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
476- )
477-
478- fast_run_linkers_to_mode = {
479- "cvm" : CVM ,
480- "vm" : VM ,
481- "numba" : NUMBA ,
482- }
483479
484480predefined_modes = {
485481 "FAST_COMPILE" : FAST_COMPILE ,
482+ "FAST_RUN" : FAST_RUN ,
486483 "C" : C ,
487484 "CVM" : CVM ,
488485 "JAX" : JAX ,
@@ -491,7 +488,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
491488 "MLX" : MLX ,
492489}
493490
494- _CACHED_RUNTIME_MODES : dict [Any , Mode ] = {}
491+ _CACHED_RUNTIME_MODES : dict [str , Mode ] = {}
495492
496493
497494def get_mode (orig_string ):
@@ -509,20 +506,10 @@ def get_mode(orig_string):
509506 if upper_string in predefined_modes :
510507 return predefined_modes [upper_string ]
511508
512- if upper_string == "FAST_RUN" :
513- linker = config .linker
514- if linker == "auto" :
515- return CVM if config .cxx else VM
516- return fast_run_linkers_to_mode [linker ]
517-
518509 global _CACHED_RUNTIME_MODES
519510
520- cache_key = ("MODE" , config .linker ) if upper_string == "MODE" else upper_string
521-
522- try :
523- return _CACHED_RUNTIME_MODES [cache_key ]
524- except KeyError :
525- pass
511+ if upper_string in _CACHED_RUNTIME_MODES :
512+ return _CACHED_RUNTIME_MODES [upper_string ]
526513
527514 # Need to define the mode for the first time
528515 if upper_string == "MODE" :
@@ -548,7 +535,7 @@ def get_mode(orig_string):
548535 if config .optimizer_requiring :
549536 ret = ret .requiring (* config .optimizer_requiring .split (":" ))
550537 # Cache the mode for next time
551- _CACHED_RUNTIME_MODES [cache_key ] = ret
538+ _CACHED_RUNTIME_MODES [upper_string ] = ret
552539
553540 return ret
554541
0 commit comments