@@ -63,9 +63,8 @@ def register_linker(name, linker):
6363# If a string is passed as the optimizer argument in the constructor
6464# for Mode, it will be used as the key to retrieve the real optimizer
6565# in this dictionary
66- exclude = []
67- if not config .cxx :
68- exclude = ["cxx_only" ]
66+
67+ exclude = ["cxx_only" , "BlasOpt" ]
6968OPT_NONE = RewriteDatabaseQuery (include = [], exclude = exclude )
7069# Even if multiple merge optimizer call will be there, this shouldn't
7170# impact performance.
@@ -342,6 +341,11 @@ def __setstate__(self, state):
342341 optimizer = predefined_optimizers [optimizer ]
343342 if isinstance (optimizer , RewriteDatabaseQuery ):
344343 self .provided_optimizer = optimizer
344+
345+ # Force numba-required rewrites if using NumbaLinker
346+ if isinstance (linker , NumbaLinker ):
347+ optimizer = optimizer .including ("numba" )
348+
345349 self ._optimizer = optimizer
346350 self .call_time = 0
347351 self .fn_time = 0
@@ -439,16 +443,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
439443# string as the key
440444# Use VM_linker to allow lazy evaluation by default.
441445FAST_COMPILE = Mode (
442- VMLinker (use_cloop = False , c_thunks = False ),
443- RewriteDatabaseQuery (include = ["fast_compile" , "py_only" ]),
446+ NumbaLinker (),
447+ # TODO: Fast_compile should just use python code, CHANGE ME!
448+ RewriteDatabaseQuery (
449+ include = ["fast_compile" , "numba" ],
450+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
451+ ),
452+ )
453+ FAST_RUN = Mode (
454+ NumbaLinker (),
455+ RewriteDatabaseQuery (
456+ include = ["fast_run" , "numba" ],
457+ exclude = ["cxx_only" , "BlasOpt" , "local_careduce_fusion" ],
458+ ),
444459)
445- if config .cxx :
446- FAST_RUN = Mode ("cvm" , "fast_run" )
447- else :
448- FAST_RUN = Mode (
449- "vm" ,
450- RewriteDatabaseQuery (include = ["fast_run" , "py_only" ]),
451- )
452460
453461JAX = Mode (
454462 JAXLinker (),
@@ -528,7 +536,7 @@ def get_mode(orig_string):
528536 # NanGuardMode use its own linker.
529537 ret = NanGuardMode (True , True , True , optimizer = config .optimizer )
530538 else :
531- # TODO: Can't we look up the name and invoke it rather than using eval here ?
539+ # TODO: Get rid of this? Or refactor ?
532540 ret = eval (string + "(linker=config.linker, optimizer=config.optimizer)" )
533541 elif string in predefined_modes :
534542 ret = predefined_modes [string ]
@@ -557,6 +565,7 @@ def register_mode(name, mode):
557565 Add a `Mode` which can be referred to by `name` in `function`.
558566
559567 """
568+ # TODO: Remove me
560569 if name in predefined_modes :
561570 raise ValueError (f"Mode name already taken: { name } " )
562571 predefined_modes [name ] = mode
0 commit comments