Skip to content
This repository was archived by the owner on Oct 6, 2022. It is now read-only.

Commit f2f4db1

Browse files
authored
Merge pull request #3 from mfherbst/dev-diis
improve structure of cDIIS.jl
2 parents 70cbc39 + 44e026f commit f2f4db1

File tree

1 file changed

+179
-110
lines changed

1 file changed

+179
-110
lines changed

src/algorithms/cDIIS.jl

Lines changed: 179 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -5,109 +5,146 @@
55
"""
66
Container type for the state of one spin type for cDIIS
77
"""
8-
mutable struct cDIISstate
9-
fock::CircularBuffer
8+
mutable struct DiisState
9+
iterate::CircularBuffer
1010
error::CircularBuffer
11+
12+
# errorOverlaps is a Circular Buffer containing already calculated rows of
13+
# the next iterate matrix. Each row is also stored as a Circular Buffer.
1114
errorOverlaps::CircularBuffer
15+
n_diis_size::Int
16+
17+
function DiisState(n_diis_size::Int)
18+
new(CircularBuffer{AbstractArray}(n_diis_size),
19+
CircularBuffer{AbstractArray}(n_diis_size),
20+
CircularBuffer{AbstractArray}(n_diis_size),
21+
n_diis_size
22+
)
23+
end
1224
end
1325

26+
1427
"""
1528
cDIIS
1629
"""
1730
mutable struct cDIIS <: Accelerator
18-
n_diis_size::Int
19-
state::Tuple{cDIISstate, cDIISstate}
31+
state::Tuple{DiisState, DiisState}
32+
sync_spins::Bool
2033
conditioning_threshold::Float64
2134
coefficient_threshold::Float64
2235

23-
function cDIIS(problem::ScfProblem; n_diis_size = 5, conditioning_threshold = 1e-14, coefficient_threshold = 1e-6, kwargs...)
24-
stateα = cDIISstate(CircularBuffer{AbstractArray}(n_diis_size),
25-
CircularBuffer{AbstractArray}(n_diis_size),
26-
CircularBuffer{AbstractArray}(sum(1:n_diis_size -1))
27-
)
28-
stateβ = deepcopy(stateα)
29-
new(n_diis_size, (stateα, stateβ), conditioning_threshold, coefficient_threshold)
36+
function cDIIS(problem::ScfProblem; n_diis_size = 5, sync_spins = true, conditioning_threshold = 1e-14, coefficient_threshold = 1e-6, kwargs...)
37+
stateα = DiisState(n_diis_size)
38+
stateβ = DiisState(n_diis_size)
39+
new((stateα, stateβ), sync_spins, conditioning_threshold, coefficient_threshold)
3040
end
3141
end
3242

3343
"""
3444
Helper function.
35-
pushes current fock and error matrices to states of both spin types
45+
pushes current iterate and error matrices to states of both spin types
3646
"""
37-
function push_iterate_to_state!(iterate::ScfIterState, states::Tuple)
38-
for i in 1:size(iterate.fock, 3)
39-
pushfirst!(states[i].fock, view(iterate.fock, :,:,i))
40-
pushfirst!(states[i].error, view(iterate.error_pulay, :,:,i))
41-
end
47+
function push_iterate!(state::DiisState, iterate::AbstractArray, error::Union{AbstractArray,Nothing} = nothing)
48+
pushfirst!(state.iterate, iterate)
49+
50+
# Push difference to previous iterate if no error given
51+
pushfirst!(state.error,
52+
error != nothing ? error : iterate - state.iterate[1])
53+
end
54+
55+
"""
56+
Helper functions to get views for specific spins
57+
"""
58+
function spin(obj::AbstractArray, dim::Int)
59+
view(obj, ntuple(x -> Colon(), ndims(obj) - 1)..., dim)
60+
end
61+
62+
function spincount(obj::AbstractArray)
63+
size(obj, ndims(obj))
64+
end
65+
66+
function spinloop(obj::Union{AbstractArray, Accelerator})
67+
typeof(obj) == Accelerator ?
68+
(1:spincount(obj.state.iterate)) :
69+
(1:spincount(obj))
4270
end
4371

4472
"""
4573
Computes next iterate using cDIIS
4674
"""
47-
function compute_next_iterate(acc::cDIIS, iterate::ScfIterState)
75+
function compute_next_iterate(acc::cDIIS, iterstate::ScfIterState)
76+
# Push iterate and error to state
77+
map-> push_iterate!(acc.state[σ], spin(iterstate.fock, σ), spin(iterstate.error_pulay, σ)), spinloop(iterstate.fock))
78+
4879
# Check if the number of known fock and error matrices is equal for both
4980
# spins before doing anything
50-
history_size = acc.state[1].fock.length
51-
for i in 1:size(iterate.fock, 3)
52-
@assert acc.state[i].fock.length == history_size
53-
@assert acc.state[i].error.length == history_size
81+
history_size = acc.state[1].iterate.length
82+
for σ in spinloop(iterstate.fock)
83+
@assert acc.state[σ].iterate.length == history_size
84+
@assert acc.state[σ].error.length == history_size
5485
end
5586

56-
# Store the current fock and error matrices first
57-
push_iterate_to_state!(iterate, acc.state)
87+
# To save memory we store only new_iterate once and pass views of it to the
88+
# computation routines that write directly into the view.
89+
new_iterate = zeros(size(iterstate.fock))
5890

59-
# Calculate the new fock matrix for each spin type separately
60-
# and write them in the result matrix directly to save memory
61-
new_iterate_fock = zeros(size(iterate.fock))
62-
for i in 1:size(iterate.fock, 3)
63-
compute_next_iterate_fock!(acc.state[i], acc.n_diis_size, view(new_iterate_fock, :,:,i), acc.conditioning_threshold, acc.coefficient_threshold)
64-
end
91+
# Defining anonymous functions with given arguments improves readability later on.
92+
matrix(σ) = diis_build_matrix(acc.state[σ])
93+
coefficients(A) = diis_solve_coefficients(A, acc.conditioning_threshold)
94+
compute(c, σ) = compute_next_iterate_matrix!(acc.state[σ], c, spin(new_iterate, σ), acc.coefficient_threshold)
6595

66-
return update_iterate_matrix(iterate, new_iterate_fock)
96+
# If sync_spins is enabled, we need to calculate the coefficients using the
97+
# merged matrix. This also means we need to remove the same number of
98+
# matrices from both histrories.
99+
if acc.sync_spins & (spincount(iterstate.fock) == 2)
100+
A = merge_matrices(matrix(1), matrix(2))
101+
c, matrixpurgecount = coefficients(A)
102+
map-> compute(c, σ), spinloop(iterstate.fock))
103+
map-> purge_matrix_from_state(acc.state[σ], matrixpurgecount), spinloop(iterstate.fock))
104+
else
105+
# If we are calculating the spins separately, each spin has its own coefficients.
106+
for σ in spinloop(iterstate.fock)
107+
c, matrixpurgecount = coefficients(matrix(σ))
108+
compute(c, σ)
109+
purge_matrix_from_state(acc.state[σ], matrixpurgecount)
110+
end
111+
end
112+
return update_iterate_matrix(iterstate, new_iterate)
67113
end
68114

69115
"""
70-
Computes a new matrix to be used as Fock Matrix in next iteration
71-
The function writes the result into the passed argument 'fock'
116+
When synchronizing spins the resulting DIIS matrices have to be added
117+
together but the constraint must be kept as is.
72118
"""
73-
function compute_next_iterate_fock!(state::cDIISstate, n_diis_size::Int, fock::SubArray, conditioning_threshold::Float64, coefficient_threshold::Float64)
74-
@assert n_diis_size > 0
75-
@assert state.fock.length > 0
76-
77-
# The Fock Matrix in the next Iteration is a linear combination of
78-
# previous Fock Matrices.
79-
#
80-
# The linear system has dimension m <= n_diis_size, since in the
81-
# beginning of the iteration we do not have the full number of
82-
# previous Fock Matrices yet.
83-
84-
m = min(n_diis_size, state.fock.length)
85-
86-
# Build the linear system we need to solve in order to obtain
87-
# the neccessary coefficients c_i.
88-
89-
A = diis_build_matrix(n_diis_size, m, state)
90-
rhs = diis_build_rhs(n_diis_size, m, state)
91-
(c, bad_condition) = diis_solve_coefficients(A, rhs, conditioning_threshold)
119+
function merge_matrices(A1::AbstractArray, A2::AbstractArray)
120+
view(A1, 1:size(A1, 1) - 1, 1:size(A1, 2) - 1) .+ view(A2, 1:size(A2, 1) - 1, 1:size(A2, 2) - 1)
121+
return A1
122+
end
92123

93-
# In some cases the matrix so badly conditioned, that we cannot produce a
94-
# sane new iterate. In this case we need to reuse the newest fock matrix
95-
if bad_condition
96-
fock .= state.fock[1]
97-
return
124+
function purge_matrix_from_state(state::DiisState, count::Int)
125+
for i in 1:2*count
126+
pop!(state.iterate)
127+
pop!(state.error)
128+
pop!(state.errorOverlaps)
98129
end
130+
end
99131

132+
"""
133+
Computes a new matrix to be used as Fock Matrix in next iteration
134+
The function writes the result into the passed argument 'fock'
135+
"""
136+
function compute_next_iterate_matrix!(state::DiisState, c::AbstractArray, iterate::SubArray, coefficient_threshold::Float64)
100137
# add very small coefficients to the largest one but always use the most
101-
# recent fock matrix regardless of the coefficient value
138+
# recent iterate matrix regardless of the coefficient value
102139
mask = map(x -> norm(x) > coefficient_threshold, c)
103140
mask[1] = true
104141
c[argmax(c)] += sum(c[ .! mask])
105142

106143
# Construct new Fock Matrix using obtained coefficients
107-
# and write it to the given fock matrix. We assume, that
108-
# fock is a matrix of zeros.
144+
# and write it to the given iterate matrix. We assume, that
145+
# iterate is a matrix of zeros.
109146
for i in eachindex(c)[mask]
110-
fock .+= c[i] * state.fock[i]
147+
iterate .+= c[i] * state.iterate[i]
111148
end
112149
end
113150

@@ -118,83 +155,115 @@ end
118155
Returns the vector c and and a boolean value representing if the matrix A
119156
is so badly conditioned, that the previous fock matrix should be used.
120157
"""
121-
function diis_solve_coefficients(A, rhs, threshold)
122-
println(cond(A))
158+
function diis_solve_coefficients(A::AbstractArray, threshold::Float64)
159+
# Right hand side of the equation
160+
rhs = diis_build_rhs(size(A, 1))
161+
123162
# calculate the eigenvalues of A and select sufficiently large eigenvalues
124163
λ, U = eigen(A)
125164
mask = map(x -> norm(x) > threshold, λ)
126165

127-
# if all eigenvalues are under the threshold, return and indicate, that the
128-
# previous fock matrix should be used without modifications.
166+
if !all(mask)
167+
println(" Removing ", count(.! mask), " of ", length(mask), " eigenvalues from DIIS linear system.")
168+
end
169+
170+
# if all eigenvalues are under the threshold, we cannot calculate sane
171+
# coefficients. The current fock matrix should be used without
172+
# modifications.
129173
if all( .! mask)
130174
println("All eigenvalues are under the threshold! Skipping iterate modification…")
131-
return (nothing, true)
175+
c = zeros(size(A, 1) - 1)
176+
c[1] = 1
177+
return c
132178
end
133179

134180
# Obtain the solution of the linear system A * c = rhs using the inverse
135181
# matrix of A constructed using the above decomposition
136182
c = U[:,mask] * Diagonal(1 ./ λ[mask]) * U[:,mask]' * rhs
137183

138-
# Warning: Note that c has size (n_diis_size + 1) since
139-
# the last element is the lagrange multiplier
140-
# corresponding to the constraint
141-
# \sum_{i=1}^n_diis_size c_i = 1
142-
# We need to remove this element!
143-
144-
return (c[1:length(c) - 1], false)
184+
# Note that c has size (n_diis_size + 1) since the last element is the
185+
# lagrange multiplier corresponding to the constraint
186+
# \sum_{i=1}^n_diis_size c_i = 1 We need to remove this element!
187+
return c[1:length(c) - 1], count(.! mask)
145188
end
146189

147190
"""
148191
Linear System Matrix for the cDIIM accelerator.
149192
This is a hermitian matrix containing error overlaps B
150193
and ones in the form
151194
152-
B 1
153-
1† 0
195+
A = B 1
196+
1† 0
154197
"""
155-
function diis_build_matrix(n_diis_size::Int, m::Int, state::cDIISstate)
156-
A = zeros(m +1,m +1)
198+
function diis_build_matrix(state::DiisState)
199+
@assert state.n_diis_size > 0
200+
@assert state.iterate.length > 0
201+
202+
# The Fock Matrix in the next Iteration is a linear combination of
203+
# previous Fock Matrices.
204+
#
205+
# The linear system has dimension m <= state.n_diis_size, since in the
206+
# beginning of the iteration we do not have the full number of
207+
# previous Fock Matrices yet.
157208

158-
# Since the Matrix A is Hermitian, we only have to calculate
159-
# the upper triagonal and can use the Julia function 'Hermitian'
160-
# to fill the lower triagonal of the matrix. This also allowes
161-
# Julia to use optimized algorithems for Hermitian matrices.
209+
m = min(state.n_diis_size, length(state.iterate))
162210

163-
# We can reuse most of the matrix B from the last iteration,
164-
# since the values do not change and only have to calculate
165-
# the first lign of B.
166-
# We accomplish this using a circular buffer of size n_diis_size,
167-
# which holds circular buffers of size n_diis_size to store already
168-
# calculated elements.
211+
A = zeros(m +1,m +1)
212+
213+
# Since the Matrix A is symmetric, we only have to calculate
214+
# the upper triagonal and can use the Julia object 'Symmetric'
215+
# to fill the lower triagonal of the matrix. This also allows
216+
# Julia to use optimized algorithems for symmetric matrices.
217+
#
218+
# The matrix A is filled in the following form:
219+
#
220+
# B[1,1] B[1,2] B[1,3] … B[1,m] 1
221+
# 0 B[2,2] B[2,3] … B[2,m] 1
222+
# 0 0 B[3,3] … B[3,m] 1
223+
# ⋮ ⋮ ⋮ … ⋮ ⋮
224+
# 0 0 0 … B[m,m] 1
225+
# 0 0 0 … 0 0
226+
#
227+
# Note, that the values A[1,1] … A[1,m-1] will become the values
228+
# of A[2,2] … A[2,m] in the next iteration. This means, that we effectively
229+
# only have to calculate the first row and can reuse all consecutive ones.
230+
#
231+
# We would like to store the rows in such a way, that the storage variable
232+
# can be mapped to a row immediately. Since the values are shifted to the
233+
# right in each iteration and the last element is discarded it makes sense
234+
# to use a Circular Buffer to hold each row.
235+
#
236+
# Since the last row is also discarded after the new row is calculated we
237+
# use a Circular Buffer again to store the rows.
169238

170239
# Fill the first row with newly calculated values and cache them
171240
# in a newly created Circular Buffer
172-
newValues = CircularBuffer{Any}(n_diis_size)
173-
for j in 1:m
174-
A[1,j] = tr(state.error[1]' * state.error[j])
175-
push!(newValues, A[1,j])
176-
end
177-
# Since we want to use this buffer as the 2nd row of A in the next
178-
# iteration we need the following layout of the buffer
179-
# 0 a1 a2 … am
180-
# so we need to push a 0 at the beginning
181-
pushfirst!(newValues, 0)
182-
183-
# The last element of each row has to be 1, see above for a
184-
# detailed explanation.
241+
newValues = CircularBuffer{Any}(state.n_diis_size)
242+
map(j -> push!(newValues,
243+
tr(state.error[1]' * state.error[j])), 1:m)
244+
fill!(newValues, 0)
245+
246+
# Push newly calculated row to the row buffer.
247+
pushfirst!(state.errorOverlaps, newValues)
248+
249+
# The last element of each row of A has to be 1. After calling Symmetric(A)
250+
# the copy of these 1s in the bottom row of A defines the constraint
251+
# sum(c) == 1.
185252
A[1, m + 1] = 1
186253

187-
# Now fill the rest of the lines using cached values,
254+
# Now fill all rows with cached values,
188255
# push a '0' on each buffer to prepare it for the next iteration
189256
# and set the last element of each row to 1.
190-
for i in 2:m
191-
A[i,1:m] = state.errorOverlaps[i-1][1:m]
192-
pushfirst!(state.errorOverlaps[i-1], 0)
257+
for i in 1:m
258+
A[i,1:m] = state.errorOverlaps[i][1:m]
193259
A[i, m + 1] = 1
194-
end
195260

196-
# Push newly calculated row to the row buffer.
197-
pushfirst!(state.errorOverlaps, newValues)
261+
# Since we want to use this buffer as the 2nd row of A in the next
262+
# iteration we need the following layout of the buffer
263+
# 0 A[1,1] A[1,2] … A[1,m-1]
264+
# so we need to push a 0 to the beginning
265+
pushfirst!(state.errorOverlaps[i], 0)
266+
end
198267

199268
return Symmetric(A)
200269
end
@@ -204,8 +273,8 @@ end
204273
This is a vector of size m+1 containing only ones
205274
except for the last element which is zero.
206275
"""
207-
function diis_build_rhs(n_diis_size::Int, m::Int, state::cDIISstate)
208-
rhs = zeros(m + 1)
209-
rhs[m + 1] = 1
276+
function diis_build_rhs(vectorsize::Int)
277+
rhs = zeros(vectorsize)
278+
rhs[end] = 1
210279
return rhs
211280
end

0 commit comments

Comments
 (0)