@@ -89,7 +89,13 @@ Base.@kwdef mutable struct DRLSState{R,Tx,TH}
8989 temp_x2:: Tx = similar (x)
9090end
9191
92- DRE (f_u:: Number , g_v:: Number , x, u, res, gamma) = f_u + g_v - real (dot (x - u, res)) / gamma + 1 / (2 * gamma) * norm (res)^ 2
92+ function DRE (f_u:: R , g_v, x, u, res, gamma) where R
93+ dot_product = R (0 )
94+ for (x_i, u_i, res_i) in zip (x, u, res)
95+ dot_product += (x_i - u_i) * res_i
96+ end
97+ return f_u + g_v - real (dot_product) / gamma + 1 / (2 * gamma) * norm (res)^ 2
98+ end
9399
94100DRE (state:: DRLSState ) = DRE (state. f_u, state. g_v, state. x, state. u, state. res, state. gamma)
95101
@@ -107,18 +113,25 @@ function Base.iterate(iter::DRLSIteration)
107113 return state, state
108114end
109115
110- set_next_direction! (:: QuasiNewtonStyle , :: DRLSIteration , state:: DRLSState ) = mul! (state. d, state. H, - state. res)
116+ function set_next_direction! (:: QuasiNewtonStyle , :: DRLSIteration , state:: DRLSState )
117+ mul! (state. d, state. H, state. res)
118+ state. d .*= - 1
119+ end
111120set_next_direction! (:: NesterovStyle , :: DRLSIteration , state:: DRLSState ) = state. d .= iterate (state. H)[1 ] .* (state. xbar .- state. xbar_prev) .+ (state. xbar .- state. x)
112121set_next_direction! (:: NoAccelerationStyle , :: DRLSIteration , state:: DRLSState ) = state. d .= state. xbar .- state. x
113122set_next_direction! (iter:: DRLSIteration , state:: DRLSState ) = set_next_direction! (acceleration_style (typeof (iter. directions)), iter, state)
114123
115- update_direction_state! (:: QuasiNewtonStyle , :: DRLSIteration , state:: DRLSState ) = update! (state. H, state. d, state. res - state. res_prev)
124+ function update_direction_state! (:: QuasiNewtonStyle , :: DRLSIteration , state:: DRLSState )
125+ state. res_prev .= state. res .- state. res_prev
126+ update! (state. H, state. d, state. res_prev)
127+ end
116128update_direction_state! (:: NesterovStyle , :: DRLSIteration , state:: DRLSState ) = return
117129update_direction_state! (:: NoAccelerationStyle , :: DRLSIteration , state:: DRLSState ) = return
118130update_direction_state! (iter:: DRLSIteration , state:: DRLSState ) = update_direction_state! (acceleration_style (typeof (iter. directions)), iter, state)
119131
120132function Base. iterate (iter:: DRLSIteration{R, Tx, Tf} , state:: DRLSState ) where {R, Tx, Tf}
121133 DRE_curr = DRE (state)
134+ threshold = iter. dre_sign * DRE_curr - iter. c / iter. gamma * norm (state. res)^ 2
122135
123136 set_next_direction! (iter, state)
124137
@@ -139,7 +152,7 @@ function Base.iterate(iter::DRLSIteration{R, Tx, Tf}, state::DRLSState) where {R
139152 a, b, c = R (0 ), R (0 ), R (0 )
140153
141154 for k in 1 : iter. max_backtracks
142- if iter. dre_sign * DRE (state) <= iter . dre_sign * DRE_curr - iter . c / iter . gamma * norm (state . res_prev) ^ 2
155+ if iter. dre_sign * DRE (state) <= threshold
143156 break
144157 end
145158
0 commit comments