Skip to content

Commit 2b7d7cf

Browse files
fix format
1 parent 9c00d7d commit 2b7d7cf

File tree

2 files changed

+59
-14
lines changed

2 files changed

+59
-14
lines changed

example/define_custom_local_operator.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(
99
target_cluster: Htool.Cluster,
1010
global_source_cluster: Htool.Cluster,
1111
local_source_offset: int,
12-
local_source_size:int,
12+
local_source_size: int,
1313
) -> None:
1414
super().__init__(
1515
target_cluster,
@@ -22,12 +22,10 @@ def __init__(
2222
self.data = np.zeros((target_size, local_source_size))
2323
generator.build_submatrix(
2424
target_cluster.get_permutation()[
25-
target_offset : target_offset
26-
+ target_size
25+
target_offset : target_offset + target_size
2726
],
2827
global_source_cluster.get_permutation()[
29-
local_source_offset : local_source_offset
30-
+ local_source_size
28+
local_source_offset : local_source_offset + local_source_size
3129
],
3230
self.data,
3331
)
@@ -38,20 +36,60 @@ def add_vector_product(
3836
# Beware, inplace operation needed for output to keep the underlying data
3937
output *= beta
4038
if trans == "N":
41-
output += alpha * self.data.dot(input[self.local_source_offset : self.local_source_offset+ self.local_source_size])
39+
output += alpha * self.data.dot(
40+
input[
41+
self.local_source_offset : self.local_source_offset
42+
+ self.local_source_size
43+
]
44+
)
4245
elif trans == "T":
43-
output += alpha * np.transpose(self.data).dot(input[self.local_source_offset : self.local_source_offset+ self.local_source_size])
46+
output += alpha * np.transpose(self.data).dot(
47+
input[
48+
self.local_source_offset : self.local_source_offset
49+
+ self.local_source_size
50+
]
51+
)
4452
elif trans == "C":
45-
output += alpha * np.vdot(np.transpose(self.data), input[self.local_source_offset : self.local_source_offset+ self.local_source_size])
53+
output += alpha * np.vdot(
54+
np.transpose(self.data),
55+
input[
56+
self.local_source_offset : self.local_source_offset
57+
+ self.local_source_size
58+
],
59+
)
4660

4761
def add_matrix_product_row_major(
4862
self, trans, alpha, input: np.array, beta, output: np.array
4963
) -> None:
5064
output *= beta
5165
if trans == "N":
52-
output += alpha * self.data @ input[self.local_source_offset : self.local_source_offset+ self.local_source_size,:]
66+
output += (
67+
alpha
68+
* self.data
69+
@ input[
70+
self.local_source_offset : self.local_source_offset
71+
+ self.local_source_size,
72+
:,
73+
]
74+
)
5375
elif trans == "T":
54-
output += alpha * np.transpose(self.data) @ input[self.local_source_offset : self.local_source_offset+ self.local_source_size,:]
76+
output += (
77+
alpha
78+
* np.transpose(self.data)
79+
@ input[
80+
self.local_source_offset : self.local_source_offset
81+
+ self.local_source_size,
82+
:,
83+
]
84+
)
5585
elif trans == "C":
56-
output += alpha * np.matrix.H(self.data) @ input[self.local_source_offset : self.local_source_offset+ self.local_source_size,:]
86+
output += (
87+
alpha
88+
* np.matrix.H(self.data)
89+
@ input[
90+
self.local_source_offset : self.local_source_offset
91+
+ self.local_source_size,
92+
:,
93+
]
94+
)
5795
output = np.asfortranarray(output)

example/use_local_hmatrix_compression.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,20 @@
8383
)
8484

8585
local_operator_2 = None
86-
if source_cluster.get_size()-local_source_cluster.get_size()-local_source_cluster.get_offset() > 0:
86+
if (
87+
source_cluster.get_size()
88+
- local_source_cluster.get_size()
89+
- local_source_cluster.get_offset()
90+
> 0
91+
):
8792
local_operator_2 = CustomLocalOperator(
8893
generator,
8994
local_target_cluster,
9095
source_cluster,
91-
local_source_cluster.get_size()+local_source_cluster.get_offset(),
92-
source_cluster.get_size()-local_source_cluster.get_size()-local_source_cluster.get_offset(),
96+
local_source_cluster.get_size() + local_source_cluster.get_offset(),
97+
source_cluster.get_size()
98+
- local_source_cluster.get_size()
99+
- local_source_cluster.get_offset(),
93100
)
94101

95102
if local_operator_1:

0 commit comments

Comments
 (0)