Skip to content

Commit 1a3a515

Browse files
committed
add and export function dmatrix_ldb_flatten()
1 parent 48b4dc3 commit 1a3a515

File tree

2 files changed

+66
-1
lines changed

2 files changed

+66
-1
lines changed

src/MultiscaleGraphSignalTransforms.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ using Reexport
1919
@reexport using .GraphSignal, .GraphPartition, .BasisSpecification, .GHWT, .GHWT_2d, .GHWT_tf_1d, .GHWT_tf_2d, .HGLET
2020

2121
export dvec2dmatrix, dmatrix2dvec, levlist2levlengths!, bsfull, bs_haar, bs_level, bs_walsh, dvec_Threshold, rs_to_region, GraphSig_Plot, gplot, gplot!, partition_fiedler
22-
export cost_functional, dmatrix_flatten
22+
export cost_functional, dmatrix_flatten, dmatrix_ldb_flatten
2323

2424
## export functions of NGWP.jl
2525
using LinearAlgebra, SparseArrays, LightGraphs, SimpleWeightedGraphs, Clustering

src/utils.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,71 @@ function dmatrix_flatten(dmatrix::Array{Float64,3}, flatten::Any)
167167
return dmatrix
168168
end # of function dmatrix_flatten!
169169

170+
"""
171+
function ldb_discriminant_measure(p::Vector{Float64}, q::Vector{Float64}; dm::Symbol = :KLdivergence)
172+
173+
Discriminat measure in LDB.
174+
175+
### Input Arguments
176+
* `p,q::Vector{Float64}`: probability mass functions.
177+
* `dm::Symbol`: discriminant measure. Options: `:KLdivergence`(default),
178+
`:Jdivergence`, `:l1`, `:l2`, and `:Hellinger`.
179+
"""
180+
function ldb_discriminant_measure(p::Vector{Float64}, q::Vector{Float64}; dm::Symbol = :KLdivergence)
181+
@assert all(p .>= 0) && all(q .>= 0)
182+
@assert length(p) == length(q)
183+
if dm == :KLdivergence
184+
ind = findall((p .> 0) .& (q .> 0))
185+
return Distances.kl_divergence(p[ind], q[ind])
186+
elseif dm == :Jdivergence
187+
ind = findall((p .> 0) .& (q .> 0))
188+
return Distances.kl_divergence(p[ind], q[ind]) + Distances.kl_divergence(q[ind], p[ind])
189+
elseif dm == :l1
190+
return norm(p - q, 1)
191+
elseif dm == :l2
192+
return norm(p - q, 2)
193+
elseif dm == :Hellinger
194+
return hellinger(p, q)
195+
else
196+
error("This discriminat measure $(dm) is not supported! ")
197+
end
198+
end
199+
200+
"""
201+
function dmatrix_ldb_flatten(dmatrix::Array{Float64,3}...; dm::Symbol = :KLdivergence)
202+
203+
Flatten dmatrix using the LDB method; after this function is called, it becomes
204+
the size of (~, ~, 1). Example usage: dmatrix_ldb_flatten(dmatrix1, dmatrix2, dmatrix3),
205+
each argument is the expansion coefficient matrix of a class of signals.
206+
207+
### Input Arguments
208+
* `dmatrix::Array{Float64,3}`: the matrix of LDB expansion coefficients in one class
209+
* `dm::Symbol`: discriminant measure. Options: `:KLdivergence`(default),
210+
`:Jdivergence`, `:l1`, `:l2`, and `:Hellinger`.
211+
"""
212+
function dmatrix_ldb_flatten(dmatrix::Array{Float64,3}...; dm::Symbol = :KLdivergence)
213+
C = length(dmatrix) # number of signal classes
214+
if C < 2
215+
error("Input should contain at least two classes of signals.")
216+
end
217+
N, jmax, _ = Base.size(dmatrix[1])
218+
res = zeros(N, jmax)
219+
for u = 1:(C - 1), v = (i + 1):C
220+
for j = 1:jmax
221+
for i = 1:N
222+
h1 = ash(dmatrix[u][i, j, :])
223+
p = h1.density / norm(h1.density, 1)
224+
h2 = ash(dmatrix[v][i, j, :])
225+
q = h2.density / norm(h2.density, 1)
226+
res[i, j] += ldb_discriminant_measure(p, q; dm = dm)
227+
end
228+
end
229+
end
230+
res = reshape(res[:, :, 1], N, jmax, 1)
231+
return res
232+
end
233+
234+
170235
"""
171236
costfun = cost_functional(cfspec::Any)
172237

0 commit comments

Comments
 (0)