Skip to content

Commit ad7f57f

Browse files
Merge pull request #20 from haotian127/master
LDB for graph signals
2 parents cad1308 + 8a0a887 commit ad7f57f

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

docs/src/functions/Utils.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,6 @@ cost_functional
1616
```@docs
1717
dmatrix_flatten
1818
```
19+
```@docs
20+
dmatrix_ldb_flatten
21+
```

src/MultiscaleGraphSignalTransforms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using Reexport
1919
@reexport using .GraphSignal, .GraphPartition, .BasisSpecification, .GHWT, .GHWT_2d, .GHWT_tf_1d, .GHWT_tf_2d, .HGLET
2020

2121
export dvec2dmatrix, dmatrix2dvec, levlist2levlengths!, bsfull, bs_haar, bs_level, bs_walsh, dvec_Threshold, rs_to_region, GraphSig_Plot, gplot, gplot!, partition_fiedler
22-
export cost_functional, dmatrix_flatten
22+
export cost_functional, dmatrix_flatten, dmatrix_ldb_flatten
2323

2424
## export functions of NGWP.jl
2525
using LinearAlgebra, SparseArrays, LightGraphs, SimpleWeightedGraphs, Clustering

src/utils.jl

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearAlgebra
1+
using LinearAlgebra, AverageShiftedHistograms, Distances
22
"""
33
T = ind_class(N::Int)
44
@@ -167,6 +167,77 @@ function dmatrix_flatten(dmatrix::Array{Float64,3}, flatten::Any)
167167
return dmatrix
168168
end # of function dmatrix_flatten!
169169

170+
"""
171+
function ldb_discriminant_measure(p::Vector{Float64}, q::Vector{Float64}; dm::Symbol = :KLdivergence)
172+
173+
Discriminat measure in LDB.
174+
175+
### Input Arguments
176+
* `p,q::Vector{Float64}`: probability mass functions.
177+
* `dm::Symbol`: discriminant measure. Options: `:KLdivergence`(default),
178+
`:Jdivergence`, `:l1`, `:l2`, and `:Hellinger`.
179+
"""
180+
function ldb_discriminant_measure(p::Vector{Float64}, q::Vector{Float64}; dm::Symbol = :KLdivergence)
181+
@assert all(p .>= 0) && all(q .>= 0)
182+
@assert length(p) == length(q)
183+
if dm == :KLdivergence
184+
ind = findall((p .> 0) .& (q .> 0))
185+
return Distances.kl_divergence(p[ind], q[ind])
186+
elseif dm == :Jdivergence
187+
ind = findall((p .> 0) .& (q .> 0))
188+
return Distances.kl_divergence(p[ind], q[ind]) + Distances.kl_divergence(q[ind], p[ind])
189+
elseif dm == :l1
190+
return norm(p - q, 1)
191+
elseif dm == :l2
192+
return norm(p - q, 2)
193+
elseif dm == :Hellinger
194+
return hellinger(p, q)
195+
else
196+
error("This discriminat measure $(dm) is not supported! ")
197+
end
198+
end
199+
200+
"""
201+
function dmatrix_ldb_flatten(dmatrix::Array{Float64,3}...; dm::Symbol = :KLdivergence)
202+
203+
Flatten dmatrices using the LDB method; after this function is called, it returns
204+
a matrix of size (~, ~, 1).
205+
206+
### Input Arguments
207+
* `dmatrix::Array{Float64,3}`: the matrix of LDB expansion coefficients in one class.
208+
* `dm::Symbol`: discriminant measure. Options: `:KLdivergence` (default),
209+
`:Jdivergence`, `:l1`, `:l2`, and `:Hellinger`.
210+
211+
### Example Usage:
212+
`dmatrix_ldb_flatten(dmatrix1, dmatrix2, dmatrix3)`,
213+
each argument is the expansion coefficient matrix of a class of signals. It uses
214+
the default discriminant measure KL divergence to flatten these matrices.
215+
In other words, it flattens these expansion coefficent matrices by computing and
216+
summing "statistical distances" among them.
217+
"""
218+
function dmatrix_ldb_flatten(dmatrix::Array{Float64,3}...; dm::Symbol = :KLdivergence)
219+
C = length(dmatrix) # number of signal classes
220+
if C < 2
221+
error("Input should contain at least two classes of signals.")
222+
end
223+
N, jmax, _ = Base.size(dmatrix[1])
224+
res = zeros(N, jmax)
225+
for u = 1:(C - 1), v = (u + 1):C
226+
for j = 1:jmax
227+
for i = 1:N
228+
h1 = ash(dmatrix[u][i, j, :])
229+
p = h1.density / norm(h1.density, 1)
230+
h2 = ash(dmatrix[v][i, j, :])
231+
q = h2.density / norm(h2.density, 1)
232+
res[i, j] += ldb_discriminant_measure(p, q; dm = dm)
233+
end
234+
end
235+
end
236+
res = reshape(res[:, :, 1], N, jmax, 1)
237+
return res
238+
end
239+
240+
170241
"""
171242
costfun = cost_functional(cfspec::Any)
172243

0 commit comments

Comments
 (0)