11#ifndef HTOOL_LOCAL_OPERATOR_CPP
22#define HTOOL_LOCAL_OPERATOR_CPP
33
4- #include < htool/distributed_operator/implementations/local_operators/local_operator .hpp>
4+ #include < htool/distributed_operator/implementations/global_to_local_operators/restricted_operator .hpp>
55#include < pybind11/pybind11.h>
66
77template <typename CoefficientPrecision>
8- class LocalOperatorPython : public htool ::LocalOperator <CoefficientPrecision> {
8+ class RestrictedGlobalToLocalOperatorPython : public htool ::RestrictedGlobalToLocalOperator <CoefficientPrecision> {
99 public:
10- using htool::LocalOperator <CoefficientPrecision>::LocalOperator ;
10+ using htool::RestrictedGlobalToLocalOperator <CoefficientPrecision>::RestrictedGlobalToLocalOperator ;
1111
12- LocalOperatorPython (LocalRenumbering target_local_renumbering, LocalRenumbering source_local_renumbering, bool target_use_permutation_to_mvprod = false , bool source_use_permutation_to_mvprod = false ) : LocalOperator <CoefficientPrecision>(target_local_renumbering, source_local_renumbering, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
12+ RestrictedGlobalToLocalOperatorPython (LocalRenumbering target_local_renumbering, LocalRenumbering source_local_renumbering, bool target_use_permutation_to_mvprod = false , bool source_use_permutation_to_mvprod = false ) : RestrictedGlobalToLocalOperator <CoefficientPrecision>(target_local_renumbering, source_local_renumbering, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
1313
1414 void local_add_vector_product (char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out) const override {
1515
@@ -35,16 +35,16 @@ class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision> {
3535};
3636
3737template <typename CoefficientPrecision>
38- class PyLocalOperator : public LocalOperatorPython <CoefficientPrecision> {
38+ class PyRestrictedGlobalToLocalOperator : public RestrictedGlobalToLocalOperatorPython <CoefficientPrecision> {
3939 public:
40- using LocalOperatorPython <CoefficientPrecision>::LocalOperatorPython ;
40+ using RestrictedGlobalToLocalOperatorPython <CoefficientPrecision>::RestrictedGlobalToLocalOperatorPython ;
4141
4242 /* Trampoline (need one for each virtual function) */
4343 virtual void add_vector_product (char trans, CoefficientPrecision alpha, const py::array_t <CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t <CoefficientPrecision> &out) const override {
4444 PYBIND11_OVERRIDE_PURE (
45- void , /* Return type */
46- LocalOperatorPython <CoefficientPrecision>, /* Parent class */
47- add_vector_product, /* Name of function in C++ (must match Python name) */
45+ void , /* Return type */
46+ RestrictedGlobalToLocalOperatorPython <CoefficientPrecision>, /* Parent class */
47+ add_vector_product, /* Name of function in C++ (must match Python name) */
4848 trans,
4949 alpha,
5050 in,
@@ -54,9 +54,9 @@ class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision> {
5454 }
5555 virtual void add_matrix_product_row_major (char trans, CoefficientPrecision alpha, const py::array_t <CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t <CoefficientPrecision, py::array::c_style> &out) const override {
5656 PYBIND11_OVERRIDE_PURE (
57- void , /* Return type */
58- LocalOperatorPython <CoefficientPrecision>, /* Parent class */
59- add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */
57+ void , /* Return type */
58+ RestrictedGlobalToLocalOperatorPython <CoefficientPrecision>, /* Parent class */
59+ add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */
6060 trans,
6161 alpha,
6262 in,
@@ -67,15 +67,15 @@ class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision> {
6767};
6868
6969template <typename CoefficientPrecision>
70- void declare_local_operator (py::module &m, const std::string &class_name) {
71- using VirtualClass = htool::VirtualLocalOperator <CoefficientPrecision>;
72- py::class_<VirtualClass>(m, ( " Virtual " + class_name). c_str () );
70+ void declare_global_to_local_operator (py::module &m, const std::string &class_name) {
71+ using VirtualClass = htool::VirtualGlobalToLocalOperator <CoefficientPrecision>;
72+ py::class_<VirtualClass>(m, " IGlobalToLocalOperator " );
7373
74- using BaseClass = LocalOperator <CoefficientPrecision>;
75- py::class_<BaseClass, VirtualClass> py_base_class (m, ( " Base " + class_name). c_str () );
74+ using BaseClass = RestrictedGlobalToLocalOperator <CoefficientPrecision>;
75+ py::class_<BaseClass, VirtualClass> py_base_class (m, " IRestrictedGlobalToLocalOperator " );
7676
77- using Class = LocalOperatorPython <CoefficientPrecision>;
78- py::class_<Class, PyLocalOperator <CoefficientPrecision>, BaseClass> py_class (m, class_name.c_str ());
77+ using Class = RestrictedGlobalToLocalOperatorPython <CoefficientPrecision>;
78+ py::class_<Class, PyRestrictedGlobalToLocalOperator <CoefficientPrecision>, BaseClass> py_class (m, class_name.c_str ());
7979 py_class.def (py::init<LocalRenumbering, LocalRenumbering, bool , bool >());
8080 py_class.def (" add_vector_product" , &Class::add_vector_product, py::arg (" trans" ), py::arg (" alpha" ), py::arg (" in" ).noconvert (true ), py::arg (" beta" ), py::arg (" out" ).noconvert (true ));
8181 py_class.def (" add_matrix_product_row_major" , &Class::add_matrix_product_row_major);
0 commit comments