@@ -61,9 +61,8 @@ def register_linker(name, linker):
6161# If a string is passed as the optimizer argument in the constructor
6262# for Mode, it will be used as the key to retrieve the real optimizer
6363# in this dictionary
64- exclude = []
65- if not config .cxx :
66- exclude = ["cxx_only" ]
64+
65+ exclude = ["cxx_only" , "BlasOpt" ]
6766OPT_NONE = RewriteDatabaseQuery (include = [], exclude = exclude )
6867# Even if multiple merge optimizer call will be there, this shouldn't
6968# impact performance.
@@ -340,6 +339,11 @@ def __setstate__(self, state):
340339 optimizer = predefined_optimizers [optimizer ]
341340 if isinstance (optimizer , RewriteDatabaseQuery ):
342341 self .provided_optimizer = optimizer
342+
343+ # Force numba-required rewrites if using NumbaLinker
344+ if isinstance (linker , NumbaLinker ):
345+ optimizer = optimizer .including ("numba" )
346+
343347 self ._optimizer = optimizer
344348 self .call_time = 0
345349 self .fn_time = 0
@@ -437,16 +441,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
437441# string as the key
438442# Use VM_linker to allow lazy evaluation by default.
439443FAST_COMPILE = Mode (
440- VMLinker (use_cloop = False , c_thunks = False ),
441- RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
444+ NumbaLinker (),
445+ # TODO: Fast_compile should just use python code, CHANGE ME!
446+ RewriteDatabaseQuery (
447+ include = ["fast_compile" , "numba" ],
448+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
449+ ),
450+ )
451+ FAST_RUN = Mode (
452+ NumbaLinker (),
453+ RewriteDatabaseQuery (
454+ include = ["fast_run" , "numba" ],
455+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
456+ ),
442457)
443- if config .cxx :
444- FAST_RUN = Mode ("cvm" , "fast_run" )
445- else :
446- FAST_RUN = Mode (
447- "vm" ,
448- RewriteDatabaseQuery (include = ["fast_run" , "py_only" ]),
449- )
450458
451459JAX = Mode (
452460 JAXLinker (),
@@ -512,7 +520,7 @@ def get_mode(orig_string):
512520 # NanGuardMode use its own linker.
513521 ret = NanGuardMode (True , True , True , optimizer = config .optimizer )
514522 else :
515- # TODO: Can't we look up the name and invoke it rather than using eval here ?
523+ # TODO: Get rid of this? Or refactor ?
516524 ret = eval (string + "(linker=config.linker, optimizer=config.optimizer)" )
517525 elif string in predefined_modes :
518526 ret = predefined_modes [string ]
@@ -541,6 +549,7 @@ def register_mode(name, mode):
541549 Add a `Mode` which can be referred to by `name` in `function`.
542550
543551 """
552+ # TODO: Remove me
544553 if name in predefined_modes :
545554 raise ValueError (f"Mode name already taken: { name } " )
546555 predefined_modes [name ] = mode
0 commit comments