44#include < htool/distributed_operator/implementations/local_operators/local_operator.hpp>
55#include < pybind11/pybind11.h>
66
7- template <typename CoefficientPrecision, typename CoordinatePrecision = CoefficientPrecision >
8- class LocalOperatorPython : public htool ::LocalOperator<CoefficientPrecision, CoordinatePrecision > {
7+ template <typename CoefficientPrecision>
8+ class LocalOperatorPython : public htool ::LocalOperator<CoefficientPrecision> {
99 public:
10- using htool::LocalOperator<CoefficientPrecision, CoordinatePrecision >::LocalOperator;
10+ using htool::LocalOperator<CoefficientPrecision>::LocalOperator;
1111
12- LocalOperatorPython (const Cluster<CoordinatePrecision> &cluster_tree_target, const Cluster<CoordinatePrecision> &cluster_tree_source , bool target_use_permutation_to_mvprod = false , bool source_use_permutation_to_mvprod = false ) : LocalOperator<CoefficientPrecision, CoordinatePrecision>(cluster_tree_target, cluster_tree_source , target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
12+ LocalOperatorPython (LocalRenumbering target_local_renumbering, LocalRenumbering source_local_renumbering , bool target_use_permutation_to_mvprod = false , bool source_use_permutation_to_mvprod = false ) : LocalOperator<CoefficientPrecision>(target_local_renumbering, source_local_renumbering , target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
1313
1414 void local_add_vector_product (char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out) const override {
1515
16- py::array_t <CoefficientPrecision> input (std::array<long int , 1 >{this ->m_source_cluster .get_size ()}, in, py::capsule (in));
17- py::array_t <CoefficientPrecision> output (std::array<long int , 1 >{this ->m_target_cluster .get_size ()}, out, py::capsule (out));
16+ py::array_t <CoefficientPrecision> input (std::array<long int , 1 >{this ->m_local_source_renumbering .get_size ()}, in, py::capsule (in));
17+ py::array_t <CoefficientPrecision> output (std::array<long int , 1 >{this ->m_local_target_renumbering .get_size ()}, out, py::capsule (out));
1818
1919 add_vector_product (trans, alpha, input, beta, output);
2020 }
2121
2222 void local_add_matrix_product_row_major (char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out, int mu) const override {
2323
24- py::array_t <CoefficientPrecision, py::array::c_style> input (std::array<long int , 2 >{this ->m_source_cluster .get_size (), mu}, in, py::capsule (in));
25- py::array_t <CoefficientPrecision, py::array::c_style> output (std::array<long int , 2 >{this ->m_target_cluster .get_size (), mu}, out, py::capsule (out));
24+ py::array_t <CoefficientPrecision, py::array::c_style> input (std::array<long int , 2 >{this ->m_local_source_renumbering .get_size (), mu}, in, py::capsule (in));
25+ py::array_t <CoefficientPrecision, py::array::c_style> output (std::array<long int , 2 >{this ->m_local_target_renumbering .get_size (), mu}, out, py::capsule (out));
2626
2727 add_matrix_product_row_major (trans, alpha, input, beta, output);
2828 }
@@ -34,10 +34,10 @@ class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision, Co
3434 virtual void add_matrix_product_row_major (char trans, CoefficientPrecision alpha, const py::array_t <CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t <CoefficientPrecision, py::array::c_style> &out) const = 0; // LCOV_EXCL_LINE
3535};
3636
37- template <typename CoefficientPrecision, typename CoordinatePrecision >
38- class PyLocalOperator : public LocalOperatorPython <CoefficientPrecision, CoordinatePrecision > {
37+ template <typename CoefficientPrecision>
38+ class PyLocalOperator : public LocalOperatorPython <CoefficientPrecision> {
3939 public:
40- using LocalOperatorPython<CoefficientPrecision, CoordinatePrecision >::LocalOperatorPython;
40+ using LocalOperatorPython<CoefficientPrecision>::LocalOperatorPython;
4141
4242 /* Trampoline (need one for each virtual function) */
4343 virtual void add_vector_product (char trans, CoefficientPrecision alpha, const py::array_t <CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t <CoefficientPrecision> &out) const override {
@@ -66,17 +66,17 @@ class PyLocalOperator : public LocalOperatorPython<CoefficientPrecision, Coordin
6666 }
6767};
6868
69- template <typename CoefficientPrecision, typename CoordinatePrecision >
69+ template <typename CoefficientPrecision>
7070void declare_local_operator (py::module &m, const std::string &class_name) {
7171 using VirtualClass = htool::VirtualLocalOperator<CoefficientPrecision>;
7272 py::class_<VirtualClass>(m, (" Virtual" + class_name).c_str ());
7373
74- using BaseClass = LocalOperator<CoefficientPrecision, CoordinatePrecision >;
74+ using BaseClass = LocalOperator<CoefficientPrecision>;
7575 py::class_<BaseClass, VirtualClass> py_base_class (m, (" Base" + class_name).c_str ());
7676
77- using Class = LocalOperatorPython<CoefficientPrecision, CoordinatePrecision >;
78- py::class_<Class, PyLocalOperator<CoefficientPrecision, CoordinatePrecision >, BaseClass> py_class (m, class_name.c_str ());
79- py_class.def (py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> & , bool , bool >());
77+ using Class = LocalOperatorPython<CoefficientPrecision>;
78+ py::class_<Class, PyLocalOperator<CoefficientPrecision>, BaseClass> py_class (m, class_name.c_str ());
79+ py_class.def (py::init<LocalRenumbering, LocalRenumbering , bool , bool >());
8080 py_class.def (" add_vector_product" , &Class::add_vector_product, py::arg (" trans" ), py::arg (" alpha" ), py::arg (" in" ).noconvert (true ), py::arg (" beta" ), py::arg (" out" ).noconvert (true ));
8181 py_class.def (" add_matrix_product_row_major" , &Class::add_matrix_product_row_major);
8282}
0 commit comments