|
4 | 4 | #include <htool/distributed_operator/interfaces/virtual_local_operator.hpp> |
5 | 5 | #include <pybind11/pybind11.h> |
6 | 6 |
|
7 | | -// template <typename CoefficientPrecision> |
8 | | -// class PyVirtualLocalOperator : public htool::VirtualLocalOperator<CoefficientPrecision> { |
9 | | -// public: |
10 | | -// }; |
| 7 | +template <typename CoefficientPrecision, typename CoordinatePrecision = CoefficientPrecision> |
| 8 | +class VirtualLocalOperatorPython : public htool::VirtualLocalOperator<CoefficientPrecision> { |
| 9 | + int m_target_offset; |
| 10 | + int m_target_size; |
| 11 | + int m_source_offset; |
| 12 | + int m_source_size; |
| 13 | + |
| 14 | + public: |
| 15 | + VirtualLocalOperatorPython(int target_offset, int target_size, int source_offset, int source_size) : m_target_offset(target_offset), m_target_size(target_size), m_source_offset(source_offset), m_source_size(source_size) {} |
| 16 | + |
| 17 | + int get_target_offset() const override { return m_target_offset; } |
| 18 | + int get_source_offset() const override { return m_source_offset; } |
| 19 | + int get_target_size() const override { return m_target_size; } |
| 20 | + int get_source_size() const override { return m_source_size; } |
| 21 | + |
| 22 | + void add_vector_product(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out) const override { |
| 23 | + py::array_t<CoefficientPrecision> input(std::array<long int, 1>{m_source_size}, in, py::capsule(in)); |
| 24 | + py::array_t<CoefficientPrecision> output(std::array<long int, 1>{m_target_size}, out, py::capsule(out)); |
| 25 | + |
| 26 | + local_add_vector_product(trans, alpha, input, beta, output); |
| 27 | + } |
| 28 | + void add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out, int mu) const override { |
| 29 | + py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{m_source_size, mu}, in, py::capsule(in)); |
| 30 | + py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{m_target_size, mu}, out, py::capsule(out)); |
| 31 | + |
| 32 | + local_add_matrix_product_row_major(trans, alpha, input, beta, output); |
| 33 | + } |
| 34 | + |
| 35 | + virtual void sub_matrix_product_to_local(const CoefficientPrecision *const in, CoefficientPrecision *const out, int mu, int offset, int size) const override { |
| 36 | + std::vector<CoefficientPrecision> temp(m_source_size * mu, 0); |
| 37 | + std::copy_n(in, size * mu, temp.data() + offset * mu); |
| 38 | + add_matrix_product_row_major('N', 1, temp.data(), 0, out, mu); |
| 39 | + }; |
| 40 | + |
| 41 | + virtual void local_add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE |
| 42 | + |
| 43 | + virtual void local_add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision, py::array::c_style> &out) const = 0; // LCOV_EXCL_LINE |
| 44 | +}; |
| 45 | + |
| 46 | +template <typename CoefficientPrecision> |
| 47 | +class PyVirtualLocalOperator : public VirtualLocalOperatorPython<CoefficientPrecision> { |
| 48 | + public: |
| 49 | + using VirtualLocalOperatorPython<CoefficientPrecision>::VirtualLocalOperatorPython; |
| 50 | + |
| 51 | + /* Trampoline (need one for each virtual function) */ |
| 52 | + virtual void local_add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const override { |
| 53 | + PYBIND11_OVERRIDE_PURE( |
| 54 | + void, /* Return type */ |
| 55 | + VirtualLocalOperatorPython<CoefficientPrecision>, /* Parent class */ |
| 56 | + add_vector_product, /* Name of function in C++ (must match Python name) */ |
| 57 | + trans, |
| 58 | + alpha, |
| 59 | + in, |
| 60 | + beta, |
| 61 | + out /* Argument(s) */ |
| 62 | + ); |
| 63 | + } |
| 64 | + virtual void local_add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision, py::array::c_style> &out) const override { |
| 65 | + PYBIND11_OVERRIDE_PURE( |
| 66 | + void, /* Return type */ |
| 67 | + VirtualLocalOperatorPython<CoefficientPrecision>, /* Parent class */ |
| 68 | + add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */ |
| 69 | + trans, |
| 70 | + alpha, |
| 71 | + in, |
| 72 | + beta, |
| 73 | + out /* Argument(s) */ |
| 74 | + ); |
| 75 | + } |
| 76 | +}; |
11 | 77 |
|
12 | 78 | template <typename CoefficientPrecision> |
13 | | -void declare_interface_local_operator(py::module &m, const std::string &class_name) { |
14 | | - using Class = htool::VirtualLocalOperator<CoefficientPrecision>; |
15 | | - py::class_<Class>(m, class_name.c_str()); |
| 79 | +void declare_virtual_local_operator(py::module &m, const std::string &className, const std::string &base_class_name) { |
| 80 | + using BaseClass = htool::VirtualLocalOperator<CoefficientPrecision>; |
| 81 | + py::class_<BaseClass>(m, (base_class_name).c_str()); |
| 82 | + |
| 83 | + using Class = VirtualLocalOperatorPython<CoefficientPrecision>; |
| 84 | + py::class_<Class, PyVirtualLocalOperator<CoefficientPrecision>, BaseClass> py_class(m, className.c_str()); |
| 85 | + py_class.def(py::init<int, int, int, int>()); |
| 86 | + py_class.def("local_add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true)); |
| 87 | + py_class.def("local_add_matrix_product_row_major", &Class::add_matrix_product_row_major); |
16 | 88 | } |
17 | 89 |
|
18 | 90 | #endif |
0 commit comments