Skip to content

Commit 9de4d48

Browse files
authored
Reduce allocations in PANOC, PANOCplus, ZeroFPR, DRLS (#74)
1 parent 2d72fdd commit 9de4d48

File tree

4 files changed

+44
-12
lines changed

4 files changed

+44
-12
lines changed

src/algorithms/drls.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,13 @@ Base.@kwdef mutable struct DRLSState{R,Tx,TH}
8989
temp_x2::Tx = similar(x)
9090
end
9191

92-
DRE(f_u::Number, g_v::Number, x, u, res, gamma) = f_u + g_v - real(dot(x - u, res)) / gamma + 1 / (2 * gamma) * norm(res)^2
92+
function DRE(f_u::R, g_v, x, u, res, gamma) where R
93+
dot_product = R(0)
94+
for (x_i, u_i, res_i) in zip(x, u, res)
95+
dot_product += (x_i - u_i) * res_i
96+
end
97+
return f_u + g_v - real(dot_product) / gamma + 1 / (2 * gamma) * norm(res)^2
98+
end
9399

94100
DRE(state::DRLSState) = DRE(state.f_u, state.g_v, state.x, state.u, state.res, state.gamma)
95101

@@ -107,18 +113,25 @@ function Base.iterate(iter::DRLSIteration)
107113
return state, state
108114
end
109115

110-
set_next_direction!(::QuasiNewtonStyle, ::DRLSIteration, state::DRLSState) = mul!(state.d, state.H, -state.res)
116+
function set_next_direction!(::QuasiNewtonStyle, ::DRLSIteration, state::DRLSState)
117+
mul!(state.d, state.H, state.res)
118+
state.d .*= -1
119+
end
111120
set_next_direction!(::NesterovStyle, ::DRLSIteration, state::DRLSState) = state.d .= iterate(state.H)[1] .* (state.xbar .- state.xbar_prev) .+ (state.xbar .- state.x)
112121
set_next_direction!(::NoAccelerationStyle, ::DRLSIteration, state::DRLSState) = state.d .= state.xbar .- state.x
113122
set_next_direction!(iter::DRLSIteration, state::DRLSState) = set_next_direction!(acceleration_style(typeof(iter.directions)), iter, state)
114123

115-
update_direction_state!(::QuasiNewtonStyle, ::DRLSIteration, state::DRLSState) = update!(state.H, state.d, state.res - state.res_prev)
124+
function update_direction_state!(::QuasiNewtonStyle, ::DRLSIteration, state::DRLSState)
125+
state.res_prev .= state.res .- state.res_prev
126+
update!(state.H, state.d, state.res_prev)
127+
end
116128
update_direction_state!(::NesterovStyle, ::DRLSIteration, state::DRLSState) = return
117129
update_direction_state!(::NoAccelerationStyle, ::DRLSIteration, state::DRLSState) = return
118130
update_direction_state!(iter::DRLSIteration, state::DRLSState) = update_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)
119131

120132
function Base.iterate(iter::DRLSIteration{R, Tx, Tf}, state::DRLSState) where {R, Tx, Tf}
121133
DRE_curr = DRE(state)
134+
threshold = iter.dre_sign * DRE_curr - iter.c / iter.gamma * norm(state.res)^2
122135

123136
set_next_direction!(iter, state)
124137

@@ -139,7 +152,7 @@ function Base.iterate(iter::DRLSIteration{R, Tx, Tf}, state::DRLSState) where {R
139152
a, b, c = R(0), R(0), R(0)
140153

141154
for k in 1:iter.max_backtracks
142-
if iter.dre_sign * DRE(state) <= iter.dre_sign * DRE_curr - iter.c / iter.gamma * norm(state.res_prev)^2
155+
if iter.dre_sign * DRE(state) <= threshold
143156
break
144157
end
145158

src/algorithms/panoc.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,18 @@ function Base.iterate(iter::PANOCIteration{R}) where R
9898
return state, state
9999
end
100100

101-
set_next_direction!(::QuasiNewtonStyle, ::PANOCIteration, state::PANOCState) = mul!(state.d, state.H, -state.res)
101+
function set_next_direction!(::QuasiNewtonStyle, ::PANOCIteration, state::PANOCState)
102+
mul!(state.d, state.H, state.res)
103+
state.d .*= -1
104+
end
102105
set_next_direction!(::NoAccelerationStyle, ::PANOCIteration, state::PANOCState) = state.d .= .-state.res
103106
set_next_direction!(iter::PANOCIteration, state::PANOCState) = set_next_direction!(acceleration_style(typeof(iter.directions)), iter, state)
104107

105-
update_direction_state!(::QuasiNewtonStyle, ::PANOCIteration, state::PANOCState) = update!(state.H, state.x - state.x_prev, state.res - state.res_prev)
108+
function update_direction_state!(::QuasiNewtonStyle, ::PANOCIteration, state::PANOCState)
109+
state.x_prev .= state.x .- state.x_prev
110+
state.res_prev .= state.res .- state.res_prev
111+
update!(state.H, state.x_prev, state.res_prev)
112+
end
106113
update_direction_state!(::NoAccelerationStyle, ::PANOCIteration, state::PANOCState) = return
107114
update_direction_state!(iter::PANOCIteration, state::PANOCState) = update_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)
108115

src/algorithms/panocplus.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,18 @@ function Base.iterate(iter::PANOCplusIteration{R}) where {R}
103103
return state, state
104104
end
105105

106-
set_next_direction!(::QuasiNewtonStyle, ::PANOCplusIteration, state::PANOCplusState) = mul!(state.d, state.H, -state.res_prev)
106+
function set_next_direction!(::QuasiNewtonStyle, ::PANOCplusIteration, state::PANOCplusState)
107+
mul!(state.d, state.H, state.res_prev)
108+
state.d .*= -1
109+
end
107110
set_next_direction!(::NoAccelerationStyle, ::PANOCplusIteration, state::PANOCplusState) = state.d .= .-state.res_prev
108111
set_next_direction!(iter::PANOCplusIteration, state::PANOCplusState) = set_next_direction!(acceleration_style(typeof(iter.directions)), iter, state)
109112

110-
update_direction_state!(::QuasiNewtonStyle, ::PANOCplusIteration, state::PANOCplusState) = update!(state.H, state.x - state.x_prev, state.res - state.res_prev)
113+
function update_direction_state!(::QuasiNewtonStyle, ::PANOCplusIteration, state::PANOCplusState)
114+
state.x_prev .= state.x .- state.x_prev
115+
state.res_prev .= state.res .- state.res_prev
116+
update!(state.H, state.x_prev, state.res_prev)
117+
end
111118
update_direction_state!(::NoAccelerationStyle, ::PANOCplusIteration, state::PANOCplusState) = return
112119
update_direction_state!(iter::PANOCplusIteration, state::PANOCplusState) = update_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)
113120

@@ -116,8 +123,6 @@ reset_direction_state!(::NoAccelerationStyle, ::PANOCplusIteration, state::PANOC
116123
reset_direction_state!(iter::PANOCplusIteration, state::PANOCplusState) = reset_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)
117124

118125
function Base.iterate(iter::PANOCplusIteration{R}, state::PANOCplusState) where R
119-
f_Az, a, b, c = R(Inf), R(Inf), R(Inf), R(Inf)
120-
121126
# store iterate and residual for metric update later on
122127
state.x_prev .= state.x
123128
state.res_prev .= state.res

src/algorithms/zerofpr.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,18 @@ function Base.iterate(iter::ZeroFPRIteration{R}) where R
9696
return state, state
9797
end
9898

99-
set_next_direction!(::QuasiNewtonStyle, ::ZeroFPRIteration, state::ZeroFPRState) = mul!(state.d, state.H, -state.res_xbar)
99+
function set_next_direction!(::QuasiNewtonStyle, ::ZeroFPRIteration, state::ZeroFPRState)
100+
mul!(state.d, state.H, state.res_xbar)
101+
state.d .*= -1
102+
end
100103
set_next_direction!(::NoAccelerationStyle, ::ZeroFPRIteration, state::ZeroFPRState) = state.d .= .-state.res
101104
set_next_direction!(iter::ZeroFPRIteration, state::ZeroFPRState) = set_next_direction!(acceleration_style(typeof(iter.directions)), iter, state)
102105

103-
update_direction_state!(::QuasiNewtonStyle, ::ZeroFPRIteration, state::ZeroFPRState) = update!(state.H, state.xbar - state.xbar_prev, state.res_xbar - state.res_xbar_prev)
106+
function update_direction_state!(::QuasiNewtonStyle, ::ZeroFPRIteration, state::ZeroFPRState)
107+
state.xbar_prev .= state.xbar .- state.xbar_prev
108+
state.res_xbar_prev .= state.res_xbar .- state.res_xbar_prev
109+
update!(state.H, state.xbar_prev, state.res_xbar_prev)
110+
end
104111
update_direction_state!(::NoAccelerationStyle, ::ZeroFPRIteration, state::ZeroFPRState) = return
105112
update_direction_state!(iter::ZeroFPRIteration, state::ZeroFPRState) = update_direction_state!(acceleration_style(typeof(iter.directions)), iter, state)
106113

0 commit comments

Comments
 (0)