@@ -35,6 +35,9 @@ def admm_solve(
3535 nmajor = 200 ,
3636 linsys_solver_kwargs = {},
3737 linsys_solver = "lsmr" ,
38+ batch_size : int = None ,
39+ batch_fraction : float = None ,
40+ random_seed : int = None ,
3841):
3942 if A .shape [1 ] != x0 .shape [0 ]:
4043 raise ValueError ("Number of columns in interpolation matrix does not match x0" )
@@ -51,6 +54,31 @@ def admm_solve(
5154 if A .shape [0 ] != b .shape [0 ]:
5255 raise ValueError ("Number of rows in interpolation matrix and b are different" )
5356 n_ie = bounds .shape [0 ]
57+
58+ # Setup batch mode
59+ use_batch_mode = False
60+ n_batch = n_ie
61+ rng = None
62+
63+ if batch_size is not None and batch_fraction is not None :
64+ raise ValueError ("Cannot specify both batch_size and batch_fraction" )
65+
66+ if batch_size is not None :
67+ if batch_size < 1 or batch_size > n_ie :
68+ raise ValueError (f"batch_size must be between 1 and { n_ie } " )
69+ n_batch = batch_size
70+ use_batch_mode = True
71+ elif batch_fraction is not None :
72+ if batch_fraction <= 0 or batch_fraction > 1 :
73+ raise ValueError ("batch_fraction must be between 0 and 1" )
74+ n_batch = max (1 , int (n_ie * batch_fraction ))
75+ use_batch_mode = True
76+
77+ if use_batch_mode and random_seed is not None :
78+ rng = np .random .RandomState (random_seed )
79+ elif use_batch_mode :
80+ rng = np .random .RandomState ()
81+
5482 qx_val = np .zeros ((Q .shape [0 ], 1 ))
5583 model = np .zeros (A .shape [1 ])
5684 model [:] = x0 [:]
@@ -71,38 +99,75 @@ def admm_solve(
7199 if not hasattr (linsys_solver_kwargs [k ], '__len__' ) or len (linsys_solver_kwargs [k ]) != nmajor :
72100 linsys_solver_kwargs [k ] = [linsys_solver_kwargs [k ]] * nmajor
73101 for _i in tqdm .tqdm (range (nmajor )):
102+ # Sample batch of inequality constraints if batch mode is enabled
103+ if use_batch_mode :
104+ batch_idx = rng .choice (n_ie , size = n_batch , replace = False )
105+ batch_idx = np .sort (batch_idx ) # Sort for consistent sparse matrix operations
106+ Q_batch = Q [batch_idx , :]
107+ xmin_batch = xmin [batch_idx , :]
108+ xmax_batch = xmax [batch_idx , :]
109+ matrix = vstack ([A , Q_batch ])
110+ b = np .zeros (A .shape [0 ] + n_batch )
111+
112+ # Create a temporary ADMM object for the batch if needed
113+ # This maintains z and u variables only for the sampled constraints
114+ admm_batch = ADMM (n_batch )
115+ admm_batch .z = admm_method .z [batch_idx ].copy ()
116+ admm_batch .u = admm_method .u [batch_idx ].copy ()
117+ else :
118+ batch_idx = None
119+
74120 # current model value
75121 Mx = matrix @ model # np.dot(A, model)
76122 b [:A_size ] = b0 [:A_size ] - Mx [:A_size ]
77123
78124 if Q .shape [0 ] > 0 :
79-
80- qx_val [:, 0 ] = Mx [A_size :,] / admm_weight
81- x0_ADMM = admm_method .admm_method_iterate_admm_array (xmin , xmax , qx_val )
82- # print(x0_ADMM, qx_val.shape)
83- # raise Exception
84- b [A_size :] = - admm_weight * (qx_val [:, 0 ] - x0_ADMM )
125+ if use_batch_mode :
126+ qx_val_batch = np .zeros ((n_batch , 1 ))
127+ qx_val_batch [:, 0 ] = Mx [A_size :] / admm_weight
128+ x0_ADMM_batch = admm_batch .admm_method_iterate_admm_array (xmin_batch , xmax_batch , qx_val_batch )
129+ b [A_size :] = - admm_weight * (qx_val_batch [:, 0 ] - x0_ADMM_batch )
130+
131+ # Update the main ADMM state with the batch results
132+ admm_method .z [batch_idx ] = admm_batch .z
133+ admm_method .u [batch_idx ] = admm_batch .u
134+ else :
135+ qx_val [:, 0 ] = Mx [A_size :,] / admm_weight
136+ x0_ADMM = admm_method .admm_method_iterate_admm_array (xmin , xmax , qx_val )
137+ # print(x0_ADMM, qx_val.shape)
138+ # raise Exception
139+ b [A_size :] = - admm_weight * (qx_val [:, 0 ] - x0_ADMM )
85140 # cost_data1 = np.linalg.norm(b[:A_size])
86141 # cost_data2 = np.linalg.norm(b0[A_size:])
87142 # model_norm = np.linalg.norm(model)
88143 if Config .verbose :
89- cost_data = - 1.0
90- cost_data_model = 0.0
91- if cost_data2 > 0 :
92- cost_data = cost_data1 / cost_data2
93- if model_norm > 0 :
94- cost_data_model = cost_data1 / model_norm
95- cost_admm1 = np .linalg .norm (qx_val - admm_method .z )
96- cost_admm2 = np .linalg .norm (admm_method .z )
97- cost_admm = - 1.0
98- if cost_admm2 > 0 :
99- cost_admm = cost_admm1 / cost_admm2
100- print ("----------------------------------------" )
101- print (f"it = { _i } " )
102- print ("cost_data = " , cost_data )
103- print ("cost_data_model = " , cost_data_model )
104- print ("cost_admm = " , cost_admm )
105- print ("----------------------------------------" )
144+ if use_batch_mode :
145+ # In batch mode, compute metrics on the full constraint set
146+ Qx_full = Q @ model
147+ cost_admm1 = np .linalg .norm (Qx_full / admm_weight - admm_method .z )
148+ cost_admm2 = np .linalg .norm (admm_method .z )
149+ print ("----------------------------------------" )
150+ print (f"it = { _i } (batch mode: { n_batch } /{ n_ie } constraints)" )
151+ print (f"cost_admm = { cost_admm1 / cost_admm2 if cost_admm2 > 0 else - 1.0 } " )
152+ print ("----------------------------------------" )
153+ else :
154+ cost_data = - 1.0
155+ cost_data_model = 0.0
156+ # if cost_data2 > 0:
157+ # cost_data = cost_data1 / cost_data2
158+ # if model_norm > 0:
159+ # cost_data_model = cost_data1 / model_norm
160+ cost_admm1 = np .linalg .norm (qx_val - admm_method .z )
161+ cost_admm2 = np .linalg .norm (admm_method .z )
162+ cost_admm = - 1.0
163+ if cost_admm2 > 0 :
164+ cost_admm = cost_admm1 / cost_admm2
165+ print ("----------------------------------------" )
166+ print (f"it = { _i } " )
167+ print ("cost_data = " , cost_data )
168+ print ("cost_data_model = " , cost_data_model )
169+ print ("cost_admm = " , cost_admm )
170+ print ("----------------------------------------" )
106171 linsys_kwargs = {k :v [_i ] for k ,v in linsys_solver_kwargs .items ()}
107172 x = lsmr (matrix , b , ** linsys_kwargs )
108173 model += x [0 ]
0 commit comments