|
1 | | -from pytensor.configdefaults import config |
2 | | -from pytensor.graph.rewriting.basic import in2out |
3 | 1 | from pytensor.link.c.op import COp |
4 | 2 | from pytensor.link.c.params_type import ParamsType |
5 | 3 | from pytensor.scalar import bool as bool_t |
6 | | -from pytensor.tensor import basic as at |
7 | 4 | from pytensor.tensor.blas import ( |
8 | 5 | Gemv, |
9 | 6 | Ger, |
10 | 7 | blas_header_text, |
11 | 8 | blas_header_version, |
12 | | - blas_optdb, |
13 | | - gemv_inplace, |
14 | | - gemv_no_inplace, |
15 | | - ger, |
16 | | - ger_destructive, |
17 | 9 | ldflags, |
18 | | - node_rewriter, |
19 | | - optdb, |
20 | 10 | ) |
21 | 11 |
|
22 | 12 |
|
@@ -344,23 +334,6 @@ def c_code_cache_version(self): |
344 | 334 | cger_no_inplace = CGer(False) |
345 | 335 |
|
346 | 336 |
|
347 | | -@node_rewriter([ger, ger_destructive]) |
348 | | -def use_c_ger(fgraph, node): |
349 | | - if not config.blas__ldflags: |
350 | | - return |
351 | | - # Only float32 and float64 are supported for now. |
352 | | - if node.op == ger and node.outputs[0].dtype in ("float32", "float64"): |
353 | | - return [CGer(False)(*node.inputs)] |
354 | | - if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"): |
355 | | - return [CGer(True)(*node.inputs)] |
356 | | - |
357 | | - |
358 | | -@node_rewriter([CGer(False)]) |
359 | | -def make_c_ger_destructive(fgraph, node): |
360 | | - if isinstance(node.op, CGer) and not node.op.destructive: |
361 | | - return [cger_inplace(*node.inputs)] |
362 | | - |
363 | | - |
364 | 337 | # ##### ####### ####### |
365 | 338 | # GEMV |
366 | 339 | # ##### ####### ####### |
@@ -697,48 +670,3 @@ def check_force_gemv_init(): |
697 | 670 |
|
698 | 671 |
|
699 | 672 | check_force_gemv_init._force_init_beta = None |
700 | | - |
701 | | - |
702 | | -@node_rewriter([gemv_inplace, gemv_no_inplace]) |
703 | | -def use_c_gemv(fgraph, node): |
704 | | - if not config.blas__ldflags: |
705 | | - return |
706 | | - # Only float32 and float64 are supported for now. |
707 | | - if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"): |
708 | | - return [cgemv_no_inplace(*node.inputs)] |
709 | | - if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"): |
710 | | - return [cgemv_inplace(*node.inputs)] |
711 | | - |
712 | | - |
713 | | -@node_rewriter([CGemv(inplace=False)]) |
714 | | -def make_c_gemv_destructive(fgraph, node): |
715 | | - if isinstance(node.op, CGemv) and not node.op.inplace: |
716 | | - inputs = list(node.inputs) |
717 | | - dest = inputs[0] |
718 | | - if ( |
719 | | - dest.owner |
720 | | - and isinstance(dest.owner.op, at.AllocEmpty) |
721 | | - and len(fgraph.clients[dest]) > 1 |
722 | | - ): |
723 | | - inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs) |
724 | | - |
725 | | - return [cgemv_inplace(*inputs)] |
726 | | - |
727 | | - |
728 | | -# ##### ####### ####### |
729 | | -# Optimizers |
730 | | -# ##### ####### ####### |
731 | | - |
732 | | -blas_optdb.register( |
733 | | - "use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20 |
734 | | -) |
735 | | - |
736 | | -# this matches the InplaceBlasOpt defined in blas.py |
737 | | -optdb.register( |
738 | | - "c_blas_destructive", |
739 | | - in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"), |
740 | | - "fast_run", |
741 | | - "inplace", |
742 | | - "c_blas", |
743 | | - position=70.0, |
744 | | -) |
0 commit comments