|
4 | 4 | #include <htool/distributed_operator/interfaces/virtual_local_operator.hpp> |
5 | 5 | #include <pybind11/pybind11.h> |
6 | 6 |
|
7 | | -template <typename CoefficientPrecision, typename CoordinatePrecision = CoefficientPrecision> |
| 7 | +template <typename CoefficientPrecision, typename CoordinatePrecision = htool::underlying_type<CoefficientPrecision>> |
8 | 8 | class VirtualLocalOperatorPython : public htool::VirtualLocalOperator<CoefficientPrecision> { |
9 | | - int m_target_offset; |
10 | | - int m_target_size; |
11 | | - int m_source_offset; |
12 | | - int m_source_size; |
13 | 9 |
|
14 | 10 | public: |
15 | | - VirtualLocalOperatorPython(int target_offset, int target_size, int source_offset, int source_size) : m_target_offset(target_offset), m_target_size(target_size), m_source_offset(source_offset), m_source_size(source_size) {} |
| 11 | + int m_target_offset; |
| 12 | + int m_target_size; |
| 13 | + int m_global_source_offset; |
| 14 | + int m_global_source_size; |
| 15 | + int m_local_source_offset; |
| 16 | + int m_local_source_size; |
| 17 | + VirtualLocalOperatorPython(const Cluster<CoordinatePrecision> &target_cluster, const Cluster<CoordinatePrecision> &global_source_cluster, int local_source_offset, int local_source_size) : m_target_offset(target_cluster.get_offset()), m_target_size(target_cluster.get_size()), m_global_source_offset(global_source_cluster.get_offset()), m_global_source_size(global_source_cluster.get_size()), m_local_source_offset(local_source_offset), m_local_source_size(local_source_size) {} |
16 | 18 |
|
17 | 19 | int get_target_offset() const override { return m_target_offset; } |
18 | | - int get_source_offset() const override { return m_source_offset; } |
| 20 | + int get_source_offset() const override { return m_local_source_offset; } |
19 | 21 | int get_target_size() const override { return m_target_size; } |
20 | | - int get_source_size() const override { return m_source_size; } |
| 22 | + int get_source_size() const override { return m_local_source_size; } |
21 | 23 |
|
22 | 24 | void add_vector_product(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out) const override { |
23 | | - py::array_t<CoefficientPrecision> input(std::array<long int, 1>{m_source_size}, in, py::capsule(in)); |
| 25 | + py::array_t<CoefficientPrecision> input(std::array<long int, 1>{m_global_source_size}, in, py::capsule(in)); |
24 | 26 | py::array_t<CoefficientPrecision> output(std::array<long int, 1>{m_target_size}, out, py::capsule(out)); |
25 | 27 |
|
26 | 28 | local_add_vector_product(trans, alpha, input, beta, output); |
27 | 29 | } |
28 | 30 | void add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out, int mu) const override { |
29 | | - py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{m_source_size, mu}, in, py::capsule(in)); |
| 31 | + py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{m_global_source_size, mu}, in, py::capsule(in)); |
30 | 32 | py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{m_target_size, mu}, out, py::capsule(out)); |
31 | 33 |
|
32 | 34 | local_add_matrix_product_row_major(trans, alpha, input, beta, output); |
33 | 35 | } |
34 | 36 |
|
35 | 37 | virtual void sub_matrix_product_to_local(const CoefficientPrecision *const in, CoefficientPrecision *const out, int mu, int offset, int size) const override { |
36 | | - std::vector<CoefficientPrecision> temp(m_source_size * mu, 0); |
| 38 | + std::vector<CoefficientPrecision> temp(m_global_source_size * mu, 0); |
37 | 39 | std::copy_n(in, size * mu, temp.data() + offset * mu); |
38 | 40 | add_matrix_product_row_major('N', 1, temp.data(), 0, out, mu); |
39 | 41 | }; |
@@ -75,16 +77,22 @@ class PyVirtualLocalOperator : public VirtualLocalOperatorPython<CoefficientPrec |
75 | 77 | } |
76 | 78 | }; |
77 | 79 |
|
78 | | -template <typename CoefficientPrecision> |
| 80 | +template <typename CoefficientPrecision, typename CoordinatePrecision = htool::underlying_type<CoefficientPrecision>> |
79 | 81 | void declare_virtual_local_operator(py::module &m, const std::string &className, const std::string &base_class_name) { |
80 | 82 | using BaseClass = htool::VirtualLocalOperator<CoefficientPrecision>; |
81 | 83 | py::class_<BaseClass>(m, (base_class_name).c_str()); |
82 | 84 |
|
83 | 85 | using Class = VirtualLocalOperatorPython<CoefficientPrecision>; |
84 | 86 | py::class_<Class, PyVirtualLocalOperator<CoefficientPrecision>, BaseClass> py_class(m, className.c_str()); |
85 | | - py_class.def(py::init<int, int, int, int>()); |
| 87 | + py_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, int, int>()); |
86 | 88 | py_class.def("local_add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true)); |
87 | 89 | py_class.def("local_add_matrix_product_row_major", &Class::add_matrix_product_row_major); |
| 90 | + py_class.def_readonly("target_offset", &Class::m_target_offset); |
| 91 | + py_class.def_readonly("target_size", &Class::m_target_size); |
| 92 | + py_class.def_readonly("global_source_offset", &Class::m_global_source_offset); |
| 93 | + py_class.def_readonly("global_source_size", &Class::m_global_source_size); |
| 94 | + py_class.def_readonly("local_source_offset", &Class::m_local_source_offset); |
| 95 | + py_class.def_readonly("local_source_size", &Class::m_local_source_size); |
88 | 96 | } |
89 | 97 |
|
90 | 98 | #endif |
0 commit comments