Skip to content

Commit d4a3fa3

Browse files
update
1 parent 47108fd commit d4a3fa3

File tree

12 files changed

+153
-120
lines changed

12 files changed

+153
-120
lines changed

example/define_custom_local_operator.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,32 @@
22
import numpy as np
33

44

5-
class CustomLocalOperator(Htool.LocalOperator):
5+
class CustomLocalOperator(Htool.VirtualLocalOperator):
66
def __init__(
77
self,
88
generator: Htool.VirtualGenerator,
9-
target_cluster: Htool.Cluster,
10-
source_cluster: Htool.Cluster,
11-
symmetry: str = "N",
12-
UPLO: str = "N",
13-
target_use_permutation_to_mvprod: bool = False,
14-
source_use_permutation_to_mvprod: bool = False,
9+
target_offset: int,
10+
target_size: int,
11+
target_permutation,
12+
source_offset: int,
13+
source_size: int,
14+
source_permutation,
1515
) -> None:
1616
super().__init__(
17-
target_cluster,
18-
source_cluster,
19-
symmetry,
20-
UPLO,
21-
target_use_permutation_to_mvprod,
22-
source_use_permutation_to_mvprod,
17+
target_offset,
18+
target_size,
19+
source_offset,
20+
source_size,
2321
)
24-
self.data = np.zeros((target_cluster.get_size(), source_cluster.get_size()))
22+
self.data = np.zeros((target_size, source_size))
2523
generator.build_submatrix(
26-
target_cluster.get_permutation()[
27-
target_cluster.get_offset() : target_cluster.get_offset()
28-
+ target_cluster.get_size()
24+
target_permutation[
25+
target_offset : target_offset
26+
+ target_size
2927
],
30-
source_cluster.get_permutation()[
31-
source_cluster.get_offset() : source_cluster.get_offset()
32-
+ source_cluster.get_size()
28+
source_permutation[
29+
source_offset : source_offset
30+
+ source_size
3331
],
3432
self.data,
3533
)

example/use_custom_local_operator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@
3838
# Build local operator
3939
local_operator = CustomLocalOperator(
4040
generator,
41-
target_cluster.get_cluster_on_partition(mpi4py.MPI.COMM_WORLD.rank),
42-
source_cluster,
43-
"N",
44-
"N",
45-
False,
46-
False,
41+
target_cluster.get_cluster_on_partition(mpi4py.MPI.COMM_WORLD.rank).get_offset(),
42+
target_cluster.get_cluster_on_partition(mpi4py.MPI.COMM_WORLD.rank).get_size(),
43+
target_cluster.get_cluster_on_partition(mpi4py.MPI.COMM_WORLD.rank).get_permutation(),
44+
source_cluster.get_offset(),
45+
source_cluster.get_size(),
46+
source_cluster.get_permutation(),
4747
)
4848

4949
# Build distributed operator
5050
custom_local_approximation = Htool.CustomApproximationBuilder(
51-
target_cluster, source_cluster, "N", "N", mpi4py.MPI.COMM_WORLD, local_operator
51+
target_cluster, source_cluster, mpi4py.MPI.COMM_WORLD, local_operator
5252
)
5353
distributed_operator = custom_local_approximation.distributed_operator
5454

example/use_local_hmatrix_compression.py

Lines changed: 20 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -72,49 +72,28 @@
7272
hmatrix = default_local_approximation.hmatrix
7373
Htool.recompression(hmatrix)
7474

75-
76-
# Build off diagonal operators
77-
off_diagonal_nc_1 = local_source_cluster.get_offset()
78-
off_diagonal_nc_2 = (
79-
source_cluster.get_size()
80-
- local_source_cluster.get_size()
81-
- local_source_cluster.get_offset()
82-
)
83-
local_nc = local_source_cluster.get_size()
84-
85-
off_diagonal_partition = np.zeros((2, 2), dtype=int)
86-
off_diagonal_partition[0, 0] = 0
87-
off_diagonal_partition[1, 0] = off_diagonal_nc_1
88-
off_diagonal_partition[0, 1] = off_diagonal_nc_1 + local_nc
89-
off_diagonal_partition[1, 1] = off_diagonal_nc_2
90-
off_diagonal_cluster: Htool.Cluster = cluster_builder.create_cluster_tree(
91-
permuted_source_points, number_of_children, 2, partition=off_diagonal_partition
92-
)
93-
94-
off_diagonal_generator = CustomGenerator(target_points, permuted_source_points)
95-
9675
local_operator_1 = None
97-
if off_diagonal_nc_1 > 0:
76+
if local_source_cluster.get_offset() > 0:
9877
local_operator_1 = CustomLocalOperator(
99-
off_diagonal_generator,
100-
local_target_cluster,
101-
off_diagonal_cluster.get_cluster_on_partition(0),
102-
"N",
103-
"N",
104-
False,
105-
True,
78+
generator,
79+
local_target_cluster.get_offset(),
80+
local_target_cluster.get_size(),
81+
local_target_cluster.get_permutation(),
82+
0,
83+
local_source_cluster.get_offset(),
84+
source_cluster.get_permutation(),
10685
)
10786

10887
local_operator_2 = None
109-
if off_diagonal_nc_2 > 0:
88+
if source_cluster.get_size()-local_source_cluster.get_size()-local_source_cluster.get_offset() > 0:
11089
local_operator_2 = CustomLocalOperator(
111-
off_diagonal_generator,
112-
local_target_cluster,
113-
off_diagonal_cluster.get_cluster_on_partition(1),
114-
"N",
115-
"N",
116-
False,
117-
True,
90+
generator,
91+
local_target_cluster.get_offset(),
92+
local_target_cluster.get_size(),
93+
local_target_cluster.get_permutation(),
94+
local_source_cluster.get_size()+local_source_cluster.get_offset(),
95+
source_cluster.get_size()-local_source_cluster.get_size()-local_source_cluster.get_offset(),
96+
source_cluster.get_permutation(),
11897
)
11998

12099
if local_operator_1:
@@ -152,26 +131,16 @@
152131
if dimension == 2:
153132
ax1 = fig.add_subplot(2, 2, 1)
154133
ax2 = fig.add_subplot(2, 2, 2)
155-
ax3 = fig.add_subplot(2, 2, 3)
156-
ax4 = fig.add_subplot(2, 2, 4)
134+
ax3 = fig.add_subplot(2, 2, 4)
157135
elif dimension == 3:
158136
ax1 = fig.add_subplot(2, 2, 1, projection="3d")
159137
ax2 = fig.add_subplot(2, 2, 2, projection="3d")
160-
ax3 = fig.add_subplot(2, 2, 3, projection="3d")
161-
ax4 = fig.add_subplot(2, 2, 4)
138+
ax3 = fig.add_subplot(2, 2, 4)
162139

163140
ax1.set_title("source cluster at depth 1")
164141
ax2.set_title("source cluster at depth 2")
165-
ax3.set_title("off diagonal cluster on rank 0 at depth 2")
166-
ax4.set_title("Hmatrix on rank 0")
142+
ax3.set_title("Hmatrix on rank 0")
167143
Htool.plot(ax1, source_cluster, source_points, 1)
168144
Htool.plot(ax2, source_cluster, source_points, 2)
169-
if mpi4py.MPI.COMM_WORLD.Get_size() > 1:
170-
Htool.plot(
171-
ax3,
172-
off_diagonal_cluster.get_cluster_on_partition(1),
173-
permuted_source_points,
174-
2,
175-
)
176-
Htool.plot(ax4, hmatrix)
145+
Htool.plot(ax3, hmatrix)
177146
plt.show()

lib/htool

Submodule htool updated 86 files

src/htool/distributed_operator/distributed_operator.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33
#include "../misc/utility.hpp"
44
#include "../misc/wrapper_mpi.hpp"
55
#include <htool/distributed_operator/distributed_operator.hpp>
6+
#include <htool/distributed_operator/linalg/add_distributed_operator_matrix_product_global_to_global.hpp>
7+
#include <htool/distributed_operator/linalg/add_distributed_operator_vector_product_global_to_global.hpp>
68
#include <htool/distributed_operator/utility.hpp>
9+
#include <htool/matrix/matrix_view.hpp>
710
#include <pybind11/pybind11.h>
811

912
template <typename CoefficientPrecision>
1013
void declare_distributed_operator(py::module &m, const std::string &class_name) {
1114
using Class = DistributedOperator<CoefficientPrecision>;
1215

1316
py::class_<Class> py_class(m, class_name.c_str());
14-
py_class.def(py::init<VirtualPartition<CoefficientPrecision> &, VirtualPartition<CoefficientPrecision> &, char, char, MPI_Comm_wrapper>(), py::keep_alive<1, 2>(), py::keep_alive<1, 3>());
17+
py_class.def(py::init<VirtualPartition<CoefficientPrecision> &, VirtualPartition<CoefficientPrecision> &, MPI_Comm_wrapper>(), py::keep_alive<1, 2>(), py::keep_alive<1, 3>());
1518
py_class.def("add_local_operator", &Class::add_local_operator, py::keep_alive<1, 2>());
1619

1720
// Linear algebra
@@ -25,7 +28,7 @@ void declare_distributed_operator(py::module &m, const std::string &class_name)
2528
}
2629
py::array_t<CoefficientPrecision, py::array::f_style> result(std::array<long int, 1>{self.get_target_partition().get_global_size()});
2730
std::fill_n(result.mutable_data(), self.get_target_partition().get_global_size(), CoefficientPrecision(0));
28-
self.vector_product_global_to_global(input.data(), result.mutable_data());
31+
htool::add_distributed_operator_vector_product_global_to_global<CoefficientPrecision>('N', 1, self, input.data(), 0, result.mutable_data(), nullptr);
2932

3033
return result;
3134
},
@@ -46,7 +49,10 @@ void declare_distributed_operator(py::module &m, const std::string &class_name)
4649
std::array<long int, 2> shape{self.get_target_partition().get_global_size(), mu};
4750
py::array_t<CoefficientPrecision, py::array::f_style> result(shape);
4851
std::fill_n(result.mutable_data(), self.get_target_partition().get_global_size() * mu, CoefficientPrecision(0));
49-
self.matrix_product_global_to_global(input.data(), result.mutable_data(), mu);
52+
MatrixView<const CoefficientPrecision> input_view(self.get_source_partition().get_global_size(), mu, input.data());
53+
MatrixView<CoefficientPrecision> output_view(self.get_target_partition().get_global_size(), mu, result.mutable_data());
54+
CoefficientPrecision *work = nullptr;
55+
add_distributed_operator_matrix_product_global_to_global('N', CoefficientPrecision(1), self, input_view, CoefficientPrecision(0), output_view, work);
5056

5157
return result;
5258
},

src/htool/distributed_operator/utility.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ void declare_distributed_operator_utility(py::module &m, std::string prefix = ""
1717
std::string default_local_approximation_name = prefix + "DefaultLocalApproximationBuilder";
1818

1919
py::class_<CustomApproximation> custom_approximation_class(m, custom_approximation_name.c_str());
20-
custom_approximation_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, char, char, MPI_Comm_wrapper, const VirtualLocalOperator<CoefficientPrecision> &>());
20+
custom_approximation_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, MPI_Comm_wrapper, const VirtualLocalOperator<CoefficientPrecision> &>());
2121
custom_approximation_class.def_property_readonly(
2222
"distributed_operator", [](const CustomApproximation &self) { return &self.distributed_operator; }, py::return_value_policy::reference_internal);
2323

src/htool/local_operator/local_operator.hpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision, Co
99
public:
1010
using htool::LocalOperator<CoefficientPrecision, CoordinatePrecision>::LocalOperator;
1111

12-
LocalOperatorPython(const Cluster<CoordinatePrecision> &cluster_tree_target, const Cluster<CoordinatePrecision> &cluster_tree_source, char symmetry = 'N', char UPLO = 'N', bool target_use_permutation_to_mvprod = false, bool source_use_permutation_to_mvprod = false) : LocalOperator<CoefficientPrecision, CoordinatePrecision>(cluster_tree_target, cluster_tree_source, symmetry, UPLO, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
12+
LocalOperatorPython(const Cluster<CoordinatePrecision> &cluster_tree_target, const Cluster<CoordinatePrecision> &cluster_tree_source, bool target_use_permutation_to_mvprod = false, bool source_use_permutation_to_mvprod = false) : LocalOperator<CoefficientPrecision, CoordinatePrecision>(cluster_tree_target, cluster_tree_source, target_use_permutation_to_mvprod, source_use_permutation_to_mvprod) {}
1313

1414
void local_add_vector_product(char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out) const override {
1515

@@ -19,14 +19,6 @@ class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision, Co
1919
add_vector_product(trans, alpha, input, beta, output);
2020
}
2121

22-
void local_add_vector_product_symmetric(char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out, char UPLO, char symmetry) const override {
23-
24-
py::array_t<CoefficientPrecision> input(std::array<long int, 1>{this->m_source_cluster.get_size()}, in, py::capsule(in));
25-
py::array_t<CoefficientPrecision> output(std::array<long int, 1>{this->m_target_cluster.get_size()}, out, py::capsule(out));
26-
27-
add_vector_product(trans, alpha, input, beta, output);
28-
}
29-
3022
void local_add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out, int mu) const override {
3123

3224
py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{this->m_source_cluster.get_size(), mu}, in, py::capsule(in));
@@ -35,14 +27,6 @@ class LocalOperatorPython : public htool::LocalOperator<CoefficientPrecision, Co
3527
add_matrix_product_row_major(trans, alpha, input, beta, output);
3628
}
3729

38-
void local_add_matrix_product_symmetric_row_major(char trans, CoefficientPrecision alpha, const CoefficientPrecision *in, CoefficientPrecision beta, CoefficientPrecision *out, int mu, char UPLO, char symmetry) const override {
39-
40-
py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{this->m_source_cluster.get_size(), 1}, in, py::capsule(in));
41-
py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{this->m_target_cluster.get_size(), 1}, out, py::capsule(out));
42-
43-
add_matrix_product_row_major(trans, alpha, input, beta, output);
44-
}
45-
4630
// lcov does not see it because of trampoline I assume
4731
virtual void add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE
4832
// virtual void local_add_vector_product_symmetric(char trans, CoefficientPrecision alpha, const std::vector<CoefficientPrecision> &in, CoefficientPrecision beta, std::vector<CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE
@@ -92,7 +76,7 @@ void declare_local_operator(py::module &m, const std::string &class_name) {
9276

9377
using Class = LocalOperatorPython<CoefficientPrecision, CoordinatePrecision>;
9478
py::class_<Class, PyLocalOperator<CoefficientPrecision, CoordinatePrecision>, BaseClass> py_class(m, class_name.c_str());
95-
py_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, char, char, bool, bool>());
79+
py_class.def(py::init<const Cluster<CoordinatePrecision> &, const Cluster<CoordinatePrecision> &, bool, bool>());
9680
py_class.def("add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true));
9781
py_class.def("add_matrix_product_row_major", &Class::add_matrix_product_row_major);
9882
}

src/htool/local_operator/virtual_local_operator.hpp

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,87 @@
44
#include <htool/distributed_operator/interfaces/virtual_local_operator.hpp>
55
#include <pybind11/pybind11.h>
66

7-
// template <typename CoefficientPrecision>
8-
// class PyVirtualLocalOperator : public htool::VirtualLocalOperator<CoefficientPrecision> {
9-
// public:
10-
// };
7+
template <typename CoefficientPrecision, typename CoordinatePrecision = CoefficientPrecision>
8+
class VirtualLocalOperatorPython : public htool::VirtualLocalOperator<CoefficientPrecision> {
9+
int m_target_offset;
10+
int m_target_size;
11+
int m_source_offset;
12+
int m_source_size;
13+
14+
public:
15+
VirtualLocalOperatorPython(int target_offset, int target_size, int source_offset, int source_size) : m_target_offset(target_offset), m_target_size(target_size), m_source_offset(source_offset), m_source_size(source_size) {}
16+
17+
int get_target_offset() const override { return m_target_offset; }
18+
int get_source_offset() const override { return m_source_offset; }
19+
int get_target_size() const override { return m_target_size; }
20+
int get_source_size() const override { return m_source_size; }
21+
22+
void add_vector_product(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out) const override {
23+
py::array_t<CoefficientPrecision> input(std::array<long int, 1>{m_source_size}, in, py::capsule(in));
24+
py::array_t<CoefficientPrecision> output(std::array<long int, 1>{m_target_size}, out, py::capsule(out));
25+
26+
local_add_vector_product(trans, alpha, input, beta, output);
27+
}
28+
void add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const CoefficientPrecision *const in, CoefficientPrecision beta, CoefficientPrecision *const out, int mu) const override {
29+
py::array_t<CoefficientPrecision, py::array::c_style> input(std::array<long int, 2>{m_source_size, mu}, in, py::capsule(in));
30+
py::array_t<CoefficientPrecision, py::array::c_style> output(std::array<long int, 2>{m_target_size, mu}, out, py::capsule(out));
31+
32+
local_add_matrix_product_row_major(trans, alpha, input, beta, output);
33+
}
34+
35+
virtual void sub_matrix_product_to_local(const CoefficientPrecision *const in, CoefficientPrecision *const out, int mu, int offset, int size) const override {
36+
std::vector<CoefficientPrecision> temp(m_source_size * mu, 0);
37+
std::copy_n(in, size * mu, temp.data() + offset * mu);
38+
add_matrix_product_row_major('N', 1, temp.data(), 0, out, mu);
39+
};
40+
41+
virtual void local_add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const = 0; // LCOV_EXCL_LINE
42+
43+
virtual void local_add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision, py::array::c_style> &out) const = 0; // LCOV_EXCL_LINE
44+
};
45+
46+
template <typename CoefficientPrecision>
47+
class PyVirtualLocalOperator : public VirtualLocalOperatorPython<CoefficientPrecision> {
48+
public:
49+
using VirtualLocalOperatorPython<CoefficientPrecision>::VirtualLocalOperatorPython;
50+
51+
/* Trampoline (need one for each virtual function) */
52+
virtual void local_add_vector_product(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision> &out) const override {
53+
PYBIND11_OVERRIDE_PURE(
54+
void, /* Return type */
55+
VirtualLocalOperatorPython<CoefficientPrecision>, /* Parent class */
56+
add_vector_product, /* Name of function in C++ (must match Python name) */
57+
trans,
58+
alpha,
59+
in,
60+
beta,
61+
out /* Argument(s) */
62+
);
63+
}
64+
virtual void local_add_matrix_product_row_major(char trans, CoefficientPrecision alpha, const py::array_t<CoefficientPrecision, py::array::c_style> &in, CoefficientPrecision beta, py::array_t<CoefficientPrecision, py::array::c_style> &out) const override {
65+
PYBIND11_OVERRIDE_PURE(
66+
void, /* Return type */
67+
VirtualLocalOperatorPython<CoefficientPrecision>, /* Parent class */
68+
add_matrix_product_row_major, /* Name of function in C++ (must match Python name) */
69+
trans,
70+
alpha,
71+
in,
72+
beta,
73+
out /* Argument(s) */
74+
);
75+
}
76+
};
1177

1278
template <typename CoefficientPrecision>
13-
void declare_interface_local_operator(py::module &m, const std::string &class_name) {
14-
using Class = htool::VirtualLocalOperator<CoefficientPrecision>;
15-
py::class_<Class>(m, class_name.c_str());
79+
void declare_virtual_local_operator(py::module &m, const std::string &className, const std::string &base_class_name) {
80+
using BaseClass = htool::VirtualLocalOperator<CoefficientPrecision>;
81+
py::class_<BaseClass>(m, (base_class_name).c_str());
82+
83+
using Class = VirtualLocalOperatorPython<CoefficientPrecision>;
84+
py::class_<Class, PyVirtualLocalOperator<CoefficientPrecision>, BaseClass> py_class(m, className.c_str());
85+
py_class.def(py::init<int, int, int, int>());
86+
py_class.def("local_add_vector_product", &Class::add_vector_product, py::arg("trans"), py::arg("alpha"), py::arg("in").noconvert(true), py::arg("beta"), py::arg("out").noconvert(true));
87+
py_class.def("local_add_matrix_product_row_major", &Class::add_matrix_product_row_major);
1688
}
1789

1890
#endif

0 commit comments

Comments
 (0)