66
77template <typename CoefficientPrecision, typename CoordinatePrecision = htool::underlying_type<CoefficientPrecision>>
88class VirtualLocalToLocalOperatorPython : public htool ::VirtualLocalToLocalOperator<CoefficientPrecision> {
9- const Cluster<CoordinatePrecision> &m_target_cluster ;
10- const Cluster<CoordinatePrecision> &m_source_cluster ;
9+ LocalRenumbering m_local_target_renumbering ;
10+ LocalRenumbering m_local_source_renumbering ;
1111
1212 public:
13- VirtualLocalToLocalOperatorPython (const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &source_cluster ) : m_target_cluster(target_cluster ), m_source_cluster(source_cluster ) {}
13+ VirtualLocalToLocalOperatorPython (LocalRenumbering local_target_renumbering, LocalRenumbering local_source_renumbering ) : m_local_target_renumbering(local_target_renumbering ), m_local_source_renumbering(local_source_renumbering ) {}
1414
1515 void add_vector_product (char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out) const override {
16- py::array_t <CoefficientPrecision> input (std::array<long int , 1 >{trans == ' N' ? m_source_cluster .get_size () : m_target_cluster .get_size ()}, in, py::capsule (in));
17- py::array_t <CoefficientPrecision> output (std::array<long int , 1 >{trans == ' N' ? m_target_cluster .get_size () : m_source_cluster .get_size ()}, out, py::capsule (out));
16+ py::array_t <CoefficientPrecision> input (std::array<long int , 1 >{trans == ' N' ? m_local_source_renumbering .get_size () : m_local_target_renumbering .get_size ()}, in, py::capsule (in));
17+ py::array_t <CoefficientPrecision> output (std::array<long int , 1 >{trans == ' N' ? m_local_target_renumbering .get_size () : m_local_source_renumbering .get_size ()}, out, py::capsule (out));
1818
1919 local_add_vector_product (trans, alpha, input, beta, output);
2020 }
2121 void add_matrix_product_row_major (char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out, int mu) const override {
22- py::array_t <CoefficientPrecision, py::array::c_style> input (std::array<long int , 2 >{trans == ' N' ? m_source_cluster .get_size () : m_target_cluster .get_size (), mu}, in, py::capsule (in));
23- py::array_t <CoefficientPrecision, py::array::c_style> output (std::array<long int , 2 >{trans == ' N' ? m_target_cluster .get_size () : m_source_cluster .get_size (), mu}, out, py::capsule (out));
22+ py::array_t <CoefficientPrecision, py::array::c_style> input (std::array<long int , 2 >{trans == ' N' ? m_local_source_renumbering .get_size () : m_local_target_renumbering .get_size (), mu}, in, py::capsule (in));
23+ py::array_t <CoefficientPrecision, py::array::c_style> output (std::array<long int , 2 >{trans == ' N' ? m_local_target_renumbering .get_size () : m_local_source_renumbering .get_size (), mu}, out, py::capsule (out));
2424
2525 local_add_matrix_product_row_major (trans, alpha, input, beta, output);
2626 }
2727
2828 void add_sub_matrix_product_to_local (const CoefficientPrecision *const in, CoefficientPrecision *const out, int mu, int offset, int size) const override {
29- int source_offset = m_source_cluster.get_offset ();
30- int source_size = m_source_cluster.get_size ();
31- bool is_output_null = ((offset + size) < source_offset) || (source_offset + source_size < offset);
32- if (!is_output_null) {
33- int temp_offset = std::max (offset, source_offset);
34- const CoefficientPrecision *const temp_in = (offset < source_offset) ? in + source_offset - offset : in;
35- int temp_size = (size + offset <= source_size + source_offset) ? size - std::max (source_offset - offset, 0 ) : size - std::max (source_offset - offset, 0 ) - (size + offset - source_offset - source_size);
36-
37- if (temp_offset == source_offset && temp_size == source_size)
38- add_matrix_product_row_major (' N' , 1 , temp_in, 1 , out, mu);
39- else {
40- std::vector<CoefficientPrecision> extension_by_zero (source_size * mu);
29+ int source_offset = m_local_source_renumbering.get_offset ();
30+ int source_size = m_local_source_renumbering.get_size ();
31+
32+ int source_end = source_size+source_offset;
33+ int end = size+offset;
34+
35+ int temp_offset = std::max (offset,source_offset);
36+ int temp_end = std::min (source_end,end);
37+
38+ bool is_output_null = temp_end-temp_offset<=0 ? true :false ;
39+ if (offset == source_offset && temp_end == source_end){
40+ add_matrix_product_row_major (' N' , 1 , in, 1 , out, mu);
41+ }
42+ else {
43+ std::vector<CoefficientPrecision> extension_by_zero (source_size * mu);
44+ if (!is_output_null){
45+ const CoefficientPrecision *const temp_in = in + temp_offset-offset;
46+ int temp_size = temp_end-temp_offset;
4147 std::copy_n (temp_in, temp_size * mu, extension_by_zero.data ());
42- add_matrix_product_row_major (' N' , 1 , extension_by_zero.data (), 1 , out, mu);
4348 }
49+ add_matrix_product_row_major (' N' , 1 , extension_by_zero.data (), 1 , out, mu);
4450 }
51+
4552 }
4653
4754 virtual void local_add_vector_product (char trans, CoefficientPrecision alpha, const py::array_t <CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t <CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE
@@ -59,7 +66,7 @@ class PyVirtualLocalToLocalOperator : public VirtualLocalToLocalOperatorPython<C
5966 PYBIND11_OVERRIDE_PURE (
6067 void , /* Return type */
6168 VirtualLocalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
62- add_vector_product , /* Name of function in C++ (must match Python name) */
69+ local_add_vector_product , /* Name of function in C++ (must match Python name) */
6370 trans,
6471 alpha,
6572 in,
@@ -71,7 +78,7 @@ class PyVirtualLocalToLocalOperator : public VirtualLocalToLocalOperatorPython<C
7178 PYBIND11_OVERRIDE_PURE (
7279 void , /* Return type */
7380 VirtualLocalToLocalOperatorPython<CoefficientPrecision>, /* Parent class */
74- add_matrix_product_row_major , /* Name of function in C++ (must match Python name) */
81+ local_add_matrix_product_row_major , /* Name of function in C++ (must match Python name) */
7582 trans,
7683 alpha,
7784 in,
@@ -87,10 +94,10 @@ void declare_virtual_local_to_local_operator(py::module &m, const std::string &c
8794 py::class_<BaseClass>(m, (base_class_name).c_str ());
8895
8996 using Class = VirtualLocalToLocalOperatorPython<CoefficientPrecision>;
90- py::class_<Class, PyVirtualLocalToLocalOperator<CoefficientPrecision>, BaseClass > py_class (m, className.c_str ());
91- py_class.def (py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> & >());
92- py_class.def (" local_add_vector_product" , &Class::add_vector_product , py::arg (" trans" ), py::arg (" alpha" ), py::arg (" in" ).noconvert (true ), py::arg (" beta" ), py::arg (" out" ).noconvert (true ));
93- py_class.def (" local_add_matrix_product_row_major" , &Class::add_matrix_product_row_major );
97+ py::class_<Class, BaseClass, PyVirtualLocalToLocalOperator<CoefficientPrecision>> py_class (m, className.c_str ());
98+ py_class.def (py::init<LocalRenumbering, LocalRenumbering >());
99+ py_class.def (" local_add_vector_product" , &Class::local_add_vector_product , py::arg (" trans" ), py::arg (" alpha" ), py::arg (" in" ).noconvert (true ), py::arg (" beta" ), py::arg (" out" ).noconvert (true ));
100+ py_class.def (" local_add_matrix_product_row_major" , &Class::local_add_matrix_product_row_major );
94101}
95102
96103#endif
0 commit comments