Skip to content

Commit adc49f1

Browse files
committed
fix: add linsys solver arg
1 parent 2d15db4 commit adc49f1

1 file changed

Lines changed: 20 additions & 14 deletions

File tree

loopsolver/admm_solver.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from dataclasses import dataclass
44
from scipy.sparse.linalg import lsmr
55
from scipy.sparse import vstack, csr_matrix
6-
6+
import tqdm
77

88
@dataclass
99
class Config:
@@ -33,7 +33,8 @@ def admm_solve(
3333
x0: np.ndarray,
3434
admm_weight: float = 0.1,
3535
nmajor=200,
36-
linsys_solver_kwargs={"maxiter": 100},
36+
linsys_solver_kwargs={},
37+
linsys_solver="lsmr",
3738
):
3839
if A.shape[1] != x0.shape[0]:
3940
raise ValueError("Number of columns in interpolation matrix does not match x0")
@@ -49,7 +50,6 @@ def admm_solve(
4950
raise ValueError("Bounds must have two columns")
5051
if A.shape[0] != b.shape[0]:
5152
raise ValueError("Number of rows in interpolation matrix and b are different")
52-
5353
n_ie = bounds.shape[0]
5454
qx_val = np.zeros((Q.shape[0], 1))
5555
model = np.zeros(A.shape[1])
@@ -67,19 +67,24 @@ def admm_solve(
6767
# scale the Q matrix by the admm f
6868
Q *= admm_weight
6969
matrix = vstack([A, Q])
70-
for _i in progressbar(range(nmajor)):
70+
for k in linsys_solver_kwargs:
71+
if not hasattr(linsys_solver_kwargs[k], '__len__') or len(linsys_solver_kwargs[k]) != nmajor:
72+
linsys_solver_kwargs[k] = [linsys_solver_kwargs[k]] * nmajor
73+
for _i in tqdm.tqdm(range(nmajor)):
7174
# current model value
7275
Mx = matrix @ model # np.dot(A, model)
73-
74-
qx_val[:, 0] = Mx[A_size:,] / admm_weight
75-
x0_ADMM = admm_method.admm_method_iterate_admm_array(xmin, xmax, qx_val)
76-
# print(x0_ADMM, qx_val.shape)
77-
# raise Exception
7876
b[:A_size] = b0[:A_size] - Mx[:A_size]
79-
b[A_size:] = -admm_weight * (qx_val[:, 0] - x0_ADMM)
80-
cost_data1 = np.linalg.norm(b[:A_size])
81-
cost_data2 = np.linalg.norm(b0[A_size:])
82-
model_norm = np.linalg.norm(model)
77+
78+
if Q.shape[0] > 0:
79+
80+
qx_val[:, 0] = Mx[A_size:,] / admm_weight
81+
x0_ADMM = admm_method.admm_method_iterate_admm_array(xmin, xmax, qx_val)
82+
# print(x0_ADMM, qx_val.shape)
83+
# raise Exception
84+
b[A_size:] = -admm_weight * (qx_val[:, 0] - x0_ADMM)
85+
# cost_data1 = np.linalg.norm(b[:A_size])
86+
# cost_data2 = np.linalg.norm(b0[A_size:])
87+
# model_norm = np.linalg.norm(model)
8388
if Config.verbose:
8489
cost_data = -1.0
8590
cost_data_model = 0.0
@@ -98,6 +103,7 @@ def admm_solve(
98103
print("cost_data_model = ", cost_data_model)
99104
print("cost_admm = ", cost_admm)
100105
print("----------------------------------------")
101-
x = lsmr(matrix, b, **linsys_solver_kwargs)
106+
linsys_kwargs = {k:v[_i] for k,v in linsys_solver_kwargs.items()}
107+
x = lsmr(matrix, b, **linsys_kwargs)
102108
model += x[0]
103109
return model

0 commit comments

Comments
 (0)