Skip to content

Commit 57a4d1c

Browse files
committed
Add more ZeroRule tests
Test results were derived on paper
1 parent d616d50 commit 57a4d1c

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

test/test_rules.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,34 @@ using LinearAlgebra
55
using ReferenceTests
66
using Random
77

8+
## Hand-written tests
9+
@testset "ZeroRule analytic" begin
10+
rule = ZeroRule()
11+
12+
## Simple dense layer
13+
Rₖ₊₁ = [1/3, 2/3]
14+
aₖ = [1.0, 2.0]
15+
W = [3.0 4.0; 5.0 6.0]
16+
b = [7.0, 8.0]
17+
Rₖ = [17/90, 316/675] # expected output
18+
19+
layer = Dense(W, b, relu)
20+
@test rule(layer, aₖ, Rₖ₊₁) Rₖ
21+
22+
## Pooling layer
23+
Rₖ₊₁ = Float32.([1 2; 3 4]//30)
24+
aₖ = Float32.([1 2 3; 10 5 6; 7 8 9])
25+
Rₖ = Float32.([0 0 0; 4 0 2; 0 0 4]//30) # expected output
26+
27+
# Repeat in color channel dim and add batch dim
28+
Rₖ₊₁ = reshape(repeat(Rₖ₊₁, 1, 3), 2, 2, 3, 1)
29+
aₖ = reshape(repeat(aₖ,1, 3), 3, 3, 3, 1)
30+
Rₖ = reshape(repeat(Rₖ,1, 3), 3, 3, 3, 1)
31+
32+
layer = MaxPool((2,2), stride=(1,1))
33+
@test rule(layer, aₖ, Rₖ₊₁) Rₖ
34+
end
35+
836
# Fixed pseudo-random numbers
937
T = Float32
1038
pseudorandn(dims...) = randn(MersenneTwister(123), T, dims...)

0 commit comments

Comments
 (0)