33from dataclasses import dataclass
44from scipy .sparse .linalg import lsmr
55from scipy .sparse import vstack , csr_matrix
6-
6+ import tqdm
77
88@dataclass
99class Config :
@@ -33,7 +33,8 @@ def admm_solve(
3333 x0 : np .ndarray ,
3434 admm_weight : float = 0.1 ,
3535 nmajor = 200 ,
36- linsys_solver_kwargs = {"maxiter" : 100 },
36+ linsys_solver_kwargs = {},
37+ linsys_solver = "lsmr" ,
3738):
3839 if A .shape [1 ] != x0 .shape [0 ]:
3940 raise ValueError ("Number of columns in interpolation matrix does not match x0" )
@@ -49,7 +50,6 @@ def admm_solve(
4950 raise ValueError ("Bounds must have two columns" )
5051 if A .shape [0 ] != b .shape [0 ]:
5152 raise ValueError ("Number of rows in interpolation matrix and b are different" )
52-
5353 n_ie = bounds .shape [0 ]
5454 qx_val = np .zeros ((Q .shape [0 ], 1 ))
5555 model = np .zeros (A .shape [1 ])
@@ -67,19 +67,24 @@ def admm_solve(
6767 # scale the Q matrix by the admm f
6868 Q *= admm_weight
6969 matrix = vstack ([A , Q ])
70- for _i in progressbar (range (nmajor )):
70+ for k in linsys_solver_kwargs :
71+ if not hasattr (linsys_solver_kwargs [k ], '__len__' ) or len (linsys_solver_kwargs [k ]) != nmajor :
72+ linsys_solver_kwargs [k ] = [linsys_solver_kwargs [k ]] * nmajor
73+ for _i in tqdm .tqdm (range (nmajor )):
7174 # current model value
7275 Mx = matrix @ model # np.dot(A, model)
73-
74- qx_val [:, 0 ] = Mx [A_size :,] / admm_weight
75- x0_ADMM = admm_method .admm_method_iterate_admm_array (xmin , xmax , qx_val )
76- # print(x0_ADMM, qx_val.shape)
77- # raise Exception
7876 b [:A_size ] = b0 [:A_size ] - Mx [:A_size ]
79- b [A_size :] = - admm_weight * (qx_val [:, 0 ] - x0_ADMM )
80- cost_data1 = np .linalg .norm (b [:A_size ])
81- cost_data2 = np .linalg .norm (b0 [A_size :])
82- model_norm = np .linalg .norm (model )
77+
78+ if Q .shape [0 ] > 0 :
79+
80+ qx_val [:, 0 ] = Mx [A_size :,] / admm_weight
81+ x0_ADMM = admm_method .admm_method_iterate_admm_array (xmin , xmax , qx_val )
82+ # print(x0_ADMM, qx_val.shape)
83+ # raise Exception
84+ b [A_size :] = - admm_weight * (qx_val [:, 0 ] - x0_ADMM )
85+ # cost_data1 = np.linalg.norm(b[:A_size])
86+ # cost_data2 = np.linalg.norm(b0[A_size:])
87+ # model_norm = np.linalg.norm(model)
8388 if Config .verbose :
8489 cost_data = - 1.0
8590 cost_data_model = 0.0
@@ -98,6 +103,7 @@ def admm_solve(
98103 print ("cost_data_model = " , cost_data_model )
99104 print ("cost_admm = " , cost_admm )
100105 print ("----------------------------------------" )
101- x = lsmr (matrix , b , ** linsys_solver_kwargs )
106+ linsys_kwargs = {k :v [_i ] for k ,v in linsys_solver_kwargs .items ()}
107+ x = lsmr (matrix , b , ** linsys_kwargs )
102108 model += x [0 ]
103109 return model
0 commit comments