We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
ZeroRule
1 parent d616d50 commit 57a4d1cCopy full SHA for 57a4d1c
test/test_rules.jl
@@ -5,6 +5,34 @@ using LinearAlgebra
5
using ReferenceTests
6
using Random
7
8
+## Hand-written tests
9
+@testset "ZeroRule analytic" begin
10
+ rule = ZeroRule()
11
+
12
+ ## Simple dense layer
13
+ Rₖ₊₁ = [1/3, 2/3]
14
+ aₖ = [1.0, 2.0]
15
+ W = [3.0 4.0; 5.0 6.0]
16
+ b = [7.0, 8.0]
17
+ Rₖ = [17/90, 316/675] # expected output
18
19
+ layer = Dense(W, b, relu)
20
+ @test rule(layer, aₖ, Rₖ₊₁) ≈ Rₖ
21
22
+ ## Pooling layer
23
+ Rₖ₊₁ = Float32.([1 2; 3 4]//30)
24
+ aₖ = Float32.([1 2 3; 10 5 6; 7 8 9])
25
+ Rₖ = Float32.([0 0 0; 4 0 2; 0 0 4]//30) # expected output
26
27
+ # Repeat in color channel dim and add batch dim
28
+ Rₖ₊₁ = reshape(repeat(Rₖ₊₁, 1, 3), 2, 2, 3, 1)
29
+ aₖ = reshape(repeat(aₖ,1, 3), 3, 3, 3, 1)
30
+ Rₖ = reshape(repeat(Rₖ,1, 3), 3, 3, 3, 1)
31
32
+ layer = MaxPool((2,2), stride=(1,1))
33
34
+end
35
36
# Fixed pseudo-random numbers
37
T = Float32
38
pseudorandn(dims...) = randn(MersenneTwister(123), T, dims...)
0 commit comments