Skip to content

Commit 270fbe5

Browse files
committed
fix: implement batch mode for ADMM solver with validation checks
1 parent adc49f1 commit 270fbe5

2 files changed

Lines changed: 184 additions & 23 deletions

File tree

loopsolver/admm_solver.py

Lines changed: 88 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def admm_solve(
3535
nmajor=200,
3636
linsys_solver_kwargs={},
3737
linsys_solver="lsmr",
38+
batch_size: int = None,
39+
batch_fraction: float = None,
40+
random_seed: int = None,
3841
):
3942
if A.shape[1] != x0.shape[0]:
4043
raise ValueError("Number of columns in interpolation matrix does not match x0")
@@ -51,6 +54,31 @@ def admm_solve(
5154
if A.shape[0] != b.shape[0]:
5255
raise ValueError("Number of rows in interpolation matrix and b are different")
5356
n_ie = bounds.shape[0]
57+
58+
# Setup batch mode
59+
use_batch_mode = False
60+
n_batch = n_ie
61+
rng = None
62+
63+
if batch_size is not None and batch_fraction is not None:
64+
raise ValueError("Cannot specify both batch_size and batch_fraction")
65+
66+
if batch_size is not None:
67+
if batch_size < 1 or batch_size > n_ie:
68+
raise ValueError(f"batch_size must be between 1 and {n_ie}")
69+
n_batch = batch_size
70+
use_batch_mode = True
71+
elif batch_fraction is not None:
72+
if batch_fraction <= 0 or batch_fraction > 1:
73+
raise ValueError("batch_fraction must be between 0 and 1")
74+
n_batch = max(1, int(n_ie * batch_fraction))
75+
use_batch_mode = True
76+
77+
if use_batch_mode and random_seed is not None:
78+
rng = np.random.RandomState(random_seed)
79+
elif use_batch_mode:
80+
rng = np.random.RandomState()
81+
5482
qx_val = np.zeros((Q.shape[0], 1))
5583
model = np.zeros(A.shape[1])
5684
model[:] = x0[:]
@@ -71,38 +99,75 @@ def admm_solve(
7199
if not hasattr(linsys_solver_kwargs[k], '__len__') or len(linsys_solver_kwargs[k]) != nmajor:
72100
linsys_solver_kwargs[k] = [linsys_solver_kwargs[k]] * nmajor
73101
for _i in tqdm.tqdm(range(nmajor)):
102+
# Sample batch of inequality constraints if batch mode is enabled
103+
if use_batch_mode:
104+
batch_idx = rng.choice(n_ie, size=n_batch, replace=False)
105+
batch_idx = np.sort(batch_idx) # Sort for consistent sparse matrix operations
106+
Q_batch = Q[batch_idx, :]
107+
xmin_batch = xmin[batch_idx, :]
108+
xmax_batch = xmax[batch_idx, :]
109+
matrix = vstack([A, Q_batch])
110+
b = np.zeros(A.shape[0] + n_batch)
111+
112+
# Create a temporary ADMM object for the batch if needed
113+
# This maintains z and u variables only for the sampled constraints
114+
admm_batch = ADMM(n_batch)
115+
admm_batch.z = admm_method.z[batch_idx].copy()
116+
admm_batch.u = admm_method.u[batch_idx].copy()
117+
else:
118+
batch_idx = None
119+
74120
# current model value
75121
Mx = matrix @ model # np.dot(A, model)
76122
b[:A_size] = b0[:A_size] - Mx[:A_size]
77123

78124
if Q.shape[0] > 0:
79-
80-
qx_val[:, 0] = Mx[A_size:,] / admm_weight
81-
x0_ADMM = admm_method.admm_method_iterate_admm_array(xmin, xmax, qx_val)
82-
# print(x0_ADMM, qx_val.shape)
83-
# raise Exception
84-
b[A_size:] = -admm_weight * (qx_val[:, 0] - x0_ADMM)
125+
if use_batch_mode:
126+
qx_val_batch = np.zeros((n_batch, 1))
127+
qx_val_batch[:, 0] = Mx[A_size:] / admm_weight
128+
x0_ADMM_batch = admm_batch.admm_method_iterate_admm_array(xmin_batch, xmax_batch, qx_val_batch)
129+
b[A_size:] = -admm_weight * (qx_val_batch[:, 0] - x0_ADMM_batch)
130+
131+
# Update the main ADMM state with the batch results
132+
admm_method.z[batch_idx] = admm_batch.z
133+
admm_method.u[batch_idx] = admm_batch.u
134+
else:
135+
qx_val[:, 0] = Mx[A_size:,] / admm_weight
136+
x0_ADMM = admm_method.admm_method_iterate_admm_array(xmin, xmax, qx_val)
137+
# print(x0_ADMM, qx_val.shape)
138+
# raise Exception
139+
b[A_size:] = -admm_weight * (qx_val[:, 0] - x0_ADMM)
85140
# cost_data1 = np.linalg.norm(b[:A_size])
86141
# cost_data2 = np.linalg.norm(b0[A_size:])
87142
# model_norm = np.linalg.norm(model)
88143
if Config.verbose:
89-
cost_data = -1.0
90-
cost_data_model = 0.0
91-
if cost_data2 > 0:
92-
cost_data = cost_data1 / cost_data2
93-
if model_norm > 0:
94-
cost_data_model = cost_data1 / model_norm
95-
cost_admm1 = np.linalg.norm(qx_val - admm_method.z)
96-
cost_admm2 = np.linalg.norm(admm_method.z)
97-
cost_admm = -1.0
98-
if cost_admm2 > 0:
99-
cost_admm = cost_admm1 / cost_admm2
100-
print("----------------------------------------")
101-
print(f"it = {_i}")
102-
print("cost_data = ", cost_data)
103-
print("cost_data_model = ", cost_data_model)
104-
print("cost_admm = ", cost_admm)
105-
print("----------------------------------------")
144+
if use_batch_mode:
145+
# In batch mode, compute metrics on the full constraint set
146+
Qx_full = Q @ model
147+
cost_admm1 = np.linalg.norm(Qx_full / admm_weight - admm_method.z)
148+
cost_admm2 = np.linalg.norm(admm_method.z)
149+
print("----------------------------------------")
150+
print(f"it = {_i} (batch mode: {n_batch}/{n_ie} constraints)")
151+
print(f"cost_admm = {cost_admm1 / cost_admm2 if cost_admm2 > 0 else -1.0}")
152+
print("----------------------------------------")
153+
else:
154+
cost_data = -1.0
155+
cost_data_model = 0.0
156+
# if cost_data2 > 0:
157+
# cost_data = cost_data1 / cost_data2
158+
# if model_norm > 0:
159+
# cost_data_model = cost_data1 / model_norm
160+
cost_admm1 = np.linalg.norm(qx_val - admm_method.z)
161+
cost_admm2 = np.linalg.norm(admm_method.z)
162+
cost_admm = -1.0
163+
if cost_admm2 > 0:
164+
cost_admm = cost_admm1 / cost_admm2
165+
print("----------------------------------------")
166+
print(f"it = {_i}")
167+
print("cost_data = ", cost_data)
168+
print("cost_data_model = ", cost_data_model)
169+
print("cost_admm = ", cost_admm)
170+
print("----------------------------------------")
106171
linsys_kwargs = {k:v[_i] for k,v in linsys_solver_kwargs.items()}
107172
x = lsmr(matrix, b, **linsys_kwargs)
108173
model += x[0]

tests/test_batch_mode.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""
2+
Unit tests for ADMM batch mode.
3+
"""
4+
import numpy as np
5+
from scipy.sparse import csr_matrix
6+
7+
from loopsolver import admm_solve
8+
9+
10+
def _build_problem(seed=42, n_data=30, n_vars=15, n_ineq=200):
11+
rng = np.random.RandomState(seed)
12+
A = csr_matrix(rng.randn(n_data, n_vars))
13+
x_true = rng.randn(n_vars)
14+
b = A @ x_true + 0.1 * rng.randn(n_data)
15+
16+
Q = csr_matrix(rng.randn(n_ineq, n_vars))
17+
Q_x_true = Q @ x_true
18+
bounds = np.column_stack([
19+
Q_x_true - 0.5,
20+
Q_x_true + 0.5,
21+
])
22+
23+
x0 = np.zeros(n_vars)
24+
return A, b, Q, bounds, x0
25+
26+
27+
def test_admm_batch_size_runs():
28+
A, b, Q, bounds, x0 = _build_problem()
29+
30+
result = admm_solve(
31+
A,
32+
b,
33+
Q,
34+
bounds,
35+
x0,
36+
admm_weight=0.1,
37+
nmajor=10,
38+
batch_size=50,
39+
random_seed=123,
40+
linsys_solver_kwargs={"atol": 1e-6, "btol": 1e-6},
41+
)
42+
43+
assert result.shape == (A.shape[1],)
44+
assert np.all(np.isfinite(result))
45+
46+
47+
def test_admm_batch_fraction_runs():
48+
A, b, Q, bounds, x0 = _build_problem(seed=7)
49+
50+
result = admm_solve(
51+
A,
52+
b,
53+
Q,
54+
bounds,
55+
x0,
56+
admm_weight=0.1,
57+
nmajor=10,
58+
batch_fraction=0.25,
59+
random_seed=456,
60+
linsys_solver_kwargs={"atol": 1e-6, "btol": 1e-6},
61+
)
62+
63+
assert result.shape == (A.shape[1],)
64+
assert np.all(np.isfinite(result))
65+
66+
67+
def test_admm_batch_mode_matches_full_shape():
68+
A, b, Q, bounds, x0 = _build_problem(seed=99)
69+
70+
full_result = admm_solve(
71+
A,
72+
b,
73+
Q,
74+
bounds,
75+
x0,
76+
admm_weight=0.1,
77+
nmajor=5,
78+
linsys_solver_kwargs={"atol": 1e-6, "btol": 1e-6},
79+
)
80+
81+
batch_result = admm_solve(
82+
A,
83+
b,
84+
Q,
85+
bounds,
86+
x0,
87+
admm_weight=0.1,
88+
nmajor=5,
89+
batch_size=40,
90+
random_seed=789,
91+
linsys_solver_kwargs={"atol": 1e-6, "btol": 1e-6},
92+
)
93+
94+
assert full_result.shape == batch_result.shape
95+
assert np.all(np.isfinite(full_result))
96+
assert np.all(np.isfinite(batch_result))

0 commit comments

Comments
 (0)