Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .vscode/settings.json

This file was deleted.

11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,24 +1,31 @@
name = "GsvdInitialization"
uuid = "2ac24108-be9c-42b8-8d78-6a4f62a87e7d"
authors = ["youdongguo <1010705897@qq.com> and contributors"]
version = "1.0.0"
authors = ["youdongguo <1010705897@qq.com> and contributors"]

[deps]
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386"
NonNegLeastSquares = "b7351bd1-99d9-5c5d-8786-f205a815c4d7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TSVD = "9449cd9e-2762-5aa3-a617-5413e99d722e"

[compat]
FileIO = "1.18"
JLD2 = "0.6"
Kronecker = "0.5"
LinearAlgebra = "1"
NMF = "1"
NonNegLeastSquares = "0.4"
TSVD = "0.4"
julia = "1.10"

[extras]
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["NMF", "Test"]
test = ["NMF", "Test", "FileIO", "JLD2"]
4 changes: 4 additions & 0 deletions demo/generate_ground_truth.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,7 @@ function gaussiantemplate(T::Type, r::Real)
end
gaussiantemplate(r::Real) = gaussiantemplate(Float64, r)

# Read the `"svdX"` entry stored in `demo/svd_of_GT.jld2` and return it.
# The path is resolved relative to the parent of this file's directory.
# NOTE(review): presumably `svdX` is a precomputed SVD of the ground-truth data,
# checked in so callers can pass it as `initdata` without recomputing it —
# confirm and record the reason the binary file is committed.
# NOTE(review): `load` is not imported in the visible part of this file; the
# includer is expected to have it in scope (e.g. via FileIO) — confirm.
function load_svd_of_gt()
    path = joinpath(dirname(@__DIR__), "demo/svd_of_GT.jld2")
    contents = load(path)
    return contents["svdX"]
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will surely need some comment about why we are uploading it




Binary file added demo/svd_of_GT.jld2
Binary file not shown.
96 changes: 84 additions & 12 deletions src/GsvdInitialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module GsvdInitialization

using LinearAlgebra, NMF, TSVD
using NonNegLeastSquares
using Kronecker, SparseArrays

export gsvdnmf,
gsvdrecover
Expand Down Expand Up @@ -33,6 +34,9 @@ Other keyword arguments are passed to `NMF.nnmf`.
function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f;
n2 = size(first(f), 2),
tol_nmf=1e-4,
alg = :cd,
initW = :standard,
truncmult = 1e-5,
kwargs...)
n1 = size(W, 2)
kadd = n2 - n1
Expand All @@ -42,9 +46,12 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f;
if kadd == 0
return W, H
else
W_recover, H_recover = gsvdrecover(X, copy(W), copy(H), kadd, f)
result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=W_recover, H0=H_recover)
return result_recover.W, result_recover.H
W_recover, H_recover, _ = gsvdrecover(X, copy(W), copy(H), kadd, f; initW=initW)
if alg == :multmse
W_recover, H_recover = max.(W_recover, truncmult), max.(H_recover, truncmult)
end
result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=copy(W_recover), H0=copy(H_recover))
return result_recover, result_recover.W, result_recover.H
end
end
# Convenience method: compute the truncated SVD of `X` with `n2` components via
# `tsvd` and forward to the method that takes the factorization `f` explicitly.
# All keyword arguments are passed through unchanged.
function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, n2::Int; kwargs...)
    f = tsvd(X, n2)
    return gsvdnmf(X, W, H, f; kwargs...)
end
Expand Down Expand Up @@ -73,13 +80,13 @@ Keyword arguments:

Other keyword arguments are passed to `NMF.nnmf`.
"""
function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, kwargs...)
function gsvdnmf(X::AbstractMatrix, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=tol_final, initW = :standard, kwargs...)
n1, n2 = ncomponents
f = tsvd(X, n2)
W0, H0 = NMF.nndsvd(X, n1; initdata = (U = f[1], S = f[2], V = f[3]))
result_initial_nmf = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0))
W_initial_nmf, H_initial_nmf = result_initial_nmf.W, result_initial_nmf.H
return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final)
return gsvdnmf(X, W_initial_nmf, H_initial_nmf, f; kwargs..., n2=n2, tol_nmf=tol_final, initW=initW)
end
# Convenience method: when only the final number of components is given, grow
# the factorization by a single component, i.e. use
# `ncomponents_final-1 => ncomponents_final`. Keyword arguments are forwarded.
function gsvdnmf(X::AbstractMatrix, ncomponents_final::Integer; kwargs...)
    ncomponents = (ncomponents_final - 1) => ncomponents_final
    return gsvdnmf(X, ncomponents; kwargs...)
end

Expand Down Expand Up @@ -108,7 +115,7 @@ Arguments:

`f`: SVD (or Truncated SVD) of `X`
"""
function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple)
function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple; initW::Symbol = :standard, kwargs...)
m, n = size(W0)
kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components"))
if kadd == 0
Expand All @@ -117,15 +124,29 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad
U0, S0, V0 = f
U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n]
Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd)
Wadd, a = init_W(X, W0, H0, Hadd)
Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
cs = Wcols_modification(X, W0_1, H0_1)
W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1
if initW == :standard
Wadd, a = init_W(X, W0, H0, Hadd)
Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
cs = Wcols_modification(X, W0_1, H0_1)
W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1
elseif initW == :joint
W0_2, H0_2 = gsvdrecover_Wa(X, W0, H0, Hadd; kwargs...)
else
throw(ArgumentError("Unknown initW method: $initW"))
end
return abs.(W0_2), abs.(H0_2), Λ
end
end

"""
    gsvdrecover_Wa(X, W0, H0, Hadd)

Build an augmented factorization from `W0`, `H0` and the candidate rows `Hadd`.

Each row of `Hadd` is first replaced by either its positive or its negative part
(see `truncatepos`). The new columns of `W` and one scale factor per column of
`W0` are then obtained together from `init_Wa`.

Returns `(W, H)` where `W = [scaled W0  Wadd]` and `H = [H0; Hadd_nn]`, both
passed through `abs.`.
"""
function gsvdrecover_Wa(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, Hadd::AbstractArray)
    nrows = size(W0, 1)
    # `truncatepos` works column-wise, so hand it the transpose and flip back.
    Hnew = truncatepos(Hadd', X, W0, H0)'
    Wnew, scales = init_Wa(X, W0, H0, Hnew)
    # Multiply column j of W0 by scales[j].
    Wscaled = repeat(scales', nrows, 1) .* W0
    Waug = [Wscaled Wnew]
    Haug = [H0; Hnew]
    return abs.(Waug), abs.(Haug)
end

function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int)
_, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), (U0'*W0)*(H0*V0));
inv_RQt = inv(R*Q')
Expand All @@ -138,7 +159,41 @@ function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::Abs
H_index = sortperm(Λ, rev = true)[1:kadd]
Hadd = inv_RQt[:, H_index]
Hadd_1 = V0*Hadd
return Hadd_1', Λ[H_index]
return Hadd_1', Λ
end

"""
    init_Wa(X, W0, H0, Hadd)

Jointly solve for the new columns `Wadd` and the per-column scale factors `α`
of `W0`.

The unknowns are stacked as `θ = [vec(Wadd); α]` and obtained from `nonneg_lsq`
(algorithm `:fnnls`, called with `gram=true`) using the matrix from `gram_sp_C`
and the vector from `gram_b`.

Returns `(Wadd, α)` with `Wadd` of size `size(X, 1) × size(Hadd, 1)` and `α`
holding the remaining entries of `θ`.
"""
function init_Wa(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}) where T
    nrows, nnew = size(X, 1), size(Hadd, 1)
    gram = first(gram_sp_C(W0, H0, Hadd))
    rhs = gram_b(X, W0, H0, Hadd)
    sol = nonneg_lsq(gram, rhs; alg=:fnnls, gram=true)
    # The leading nrows*nnew entries are vec(Wadd); the rest are the scales.
    nW = nrows * nnew
    Wadd = reshape(sol[1:nW], nrows, nnew)
    α = sol[nW+1:end]
    return Wadd, α
end

"""
    gram_sp_C(W0, H0, Hadd)

Assemble the block matrix `G = [G11 G12; G12' G22]` of the quadratic form in
`θ = [vec(Wadd); α]` used by `init_Wa`.

With `m, r0 = size(W0)` and `k = size(Hadd, 1)`:

- `G11 = kronecker(Hadd*Hadd', I)` with an `m × m` sparse identity,
- `G12` is `m*k × r0`; column `j` is `vec(W0[:, j] * (Hadd*H0')[:, j]')`,
- `G22 = (W0'*W0) .* (H0*H0')`, stored sparse.

Returns `(G, G11, G12, G22)`.
"""
function gram_sp_C(W0, H0, Hadd)
    m, r0 = size(W0)
    k = size(Hadd, 1)
    cross = Hadd * H0'
    G22 = sparse((W0' * W0) .* (H0 * H0'))
    # Fill the coupling block densely, one outer product per column of W0.
    G12dense = zeros(Float64, m * k, r0)
    for (j, (w, p)) in enumerate(zip(eachcol(W0), eachcol(cross)))
        G12dense[:, j] .= vec(w * p')
    end
    G12 = sparse(G12dense)
    G11 = kronecker(Hadd * Hadd', sparse(I, m, m))
    G = [G11 G12; G12' G22]
    return G, G11, G12, G22
end

"""
    gram_b(X, W0, H0, Hadd)

Build the linear-term vector paired with `gram_sp_C`.

The first part is `vec(X * Hadd')` (one entry per element of `Wadd`); the second
part is `diag(W0' * X * H0')` (one entry per column of `W0`).
"""
function gram_b(X, W0, H0, Hadd)
    new_part = vec(X * Hadd')
    old_part = diag(W0' * X * H0')
    return [new_part; old_part]
end

function init_W(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; α = nothing) where T
Expand Down Expand Up @@ -176,4 +231,21 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac
return β[:]
end

"""
    truncatepos(Y, X, W, H)

Return an array shaped like `Y` in which every column is replaced by either its
positive part `max.(y, 0)` or its negative part `max.(-y, 0)`.

For each column the choice is made by comparing `sum(ΔX * part)` for the two
candidates, where `ΔX = max.(0, X - W*H)`; the positive part wins ties.
"""
function truncatepos(Y, X, W, H)
    residual = max.(zero(eltype(X)), X - W*H)
    Yout = similar(Y)
    for j in axes(Y, 2)
        col = view(Y, :, j)
        z = zero(eltype(col))
        pos = max.(col, z)
        neg = max.(-col, z)
        Yout[:, j] = sum(residual * pos) >= sum(residual * neg) ? pos : neg
    end
    return Yout
end


end
44 changes: 36 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
using GsvdInitialization
using Test

using LinearAlgebra, NMF
using LinearAlgebra, NMF, FileIO

include(joinpath(dirname(@__DIR__), "demo/generate_ground_truth.jl"))

W_GT, H_GT = generate_ground_truth()
svdX = load_svd_of_gt()

@testset "test top wrapper" begin
W = W_GT
H = H_GT
X = W*H
standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, initdata = svd(float(X)))
W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4);
img_tol_int = sum(abs2, X)
standard_nmf = nnmf(X, 10; alg = :cd, init=:nndsvd, tol=1e-4, maxiter = 10^5, initdata = svdX)
_, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4);
@test size(W_gsvd, 2) == 10
@test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X)
@test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10
@test sum(abs2, X-standard_nmf.W*standard_nmf.H)/sum(abs2, X) > sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X)

X = rand(30, 20)
W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd)
W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd)
_, W_gsvd_1, H_gsvd_1 = gsvdnmf(X, 10; alg=:cd)
_, W_gsvd_2, H_gsvd_2 = gsvdnmf(X, 9 => 10; alg=:cd)
@test sum(abs2, W_gsvd_1-W_gsvd_2) <= 1e-12
@test sum(abs2, H_gsvd_1-H_gsvd_2) <= 1e-12
end

@testset "GsvdInitialization.jl" begin
@testset "GsvdInitialization" begin
W, H = rand(10, 3), rand(3, 8)
X = W*H
U, S, V = svd(X)
Expand Down Expand Up @@ -54,3 +54,31 @@ end
@test β.*β0 ≈ ones(3)

end

# Tests for the `initW=:joint` path, where the new columns of W and the scale
# factors of the existing columns are solved for together.
@testset "joint optimize W and alpha" begin
# End-to-end: grow a 9-component factorization of the ground-truth data to 10
# components and require a small relative reconstruction error.
W = W_GT
H = H_GT
X = W*H
_, W_gsvd, H_gsvd = gsvdnmf(X, 9=>10; alg = :cd, maxiter = 10^5, tol_final=1e-4, tol_intermediate = 1e-4, initW=:joint);
@test size(W_gsvd, 2) == 10
@test sum(abs2, X-W_gsvd*H_gsvd)/sum(abs2, X) < 2e-10

# Exact case: X equals W0*H0, so the expected solution keeps every scale at 1
# and makes the added columns (numerically) zero.
W, H = rand(10, 3), rand(3, 8)
X = W*H
# NOTE(review): U, S, V are not used anywhere below in this test set.
U, S, V = svd(X)

W0, H0 = copy(W), copy(H)
Hadd = rand(2, 8)
Wadd, a = GsvdInitialization.init_Wa(X, W0, H0, Hadd)
@test a ≈ ones(size(W0, 2))
@test norm(Wadd) <= 1e-8

# Consistency of the quadratic form: for arbitrary θ = [vec(Wadd); α],
# θ'Gθ - 2b'θ + ‖X‖² must equal the squared reconstruction error of
# [α-scaled W0  Wadd] * [H0; Hadd].
G = GsvdInitialization.gram_sp_C(W0, H0, Hadd)[1]
b = GsvdInitialization.gram_b(X, W0, H0, Hadd)
Wadd = rand(10, 2)
α = rand(3)
θ = vcat(vec(Wadd), α)
E = θ'*G*θ-2*b'*θ+sum(abs2, X)
@test abs(E-sum(abs2, X-[repeat(α', size(W0, 1)).*W0 Wadd]*[H0;Hadd])) <= 1e-12

end