1+ program test_conv1d_network
2+
3+ use iso_fortran_env, only: stderr = > error_unit
4+ use nf, only: conv1d, input, network, dense, sgd, maxpool1d
5+
6+ implicit none
7+
8+ type (network) :: net
9+ real , allocatable :: sample_input(:,:), output(:,:), o(:)
10+ logical :: ok = .true.
11+
12+ ! 3-layer convolutional network
13+ net = network([ &
14+ input(3 , 32 ), &
15+ conv1d(filters= 16 , kernel_size= 3 ), &
16+ conv1d(filters= 32 , kernel_size= 3 ) &
17+ ])
18+
19+ if (.not. size (net % layers) == 3 ) then
20+ write (stderr, ' (a)' ) ' conv2d network should have 3 layers.. failed'
21+ ok = .false.
22+ end if
23+
24+ ! Test for output shape
25+ allocate (sample_input(3 , 32 ))
26+ sample_input = 0
27+
28+ call net % forward(sample_input)
29+ call net % layers(3 ) % get_output(output)
30+
31+ if (.not. all (shape (output) == [32 , 28 ])) then
32+ write (stderr, ' (a)' ) ' conv1d network output should have correct shape.. failed'
33+ ok = .false.
34+ end if
35+
36+ deallocate (sample_input, output)
37+
38+ training1: block
39+
40+ type (network) :: cnn
41+ real :: y(1 )
42+ real :: tolerance = 1e-4
43+ integer :: n
44+ integer , parameter :: num_iterations = 1000
45+
46+ ! Test training of a minimal constant mapping
47+ allocate (sample_input(1 , 5 ))
48+ call random_number (sample_input)
49+
50+ cnn = network([ &
51+ input(1 , 5 ), &
52+ conv1d(filters= 1 , kernel_size= 3 ), &
53+ conv1d(filters= 1 , kernel_size= 3 ), &
54+ dense(1 ) &
55+ ])
56+
57+ y = [0.1234567 ]
58+
59+ do n = 1 , num_iterations
60+ call cnn % forward(sample_input)
61+ call cnn % backward(y)
62+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
63+ o = cnn % layers(2 ) % get_params()
64+ print * , o
65+ if (all (abs (cnn % predict(sample_input) - y) < tolerance)) exit
66+ end do
67+
68+ if (.not. n <= num_iterations) then
69+ write (stderr, ' (a)' ) &
70+ ' convolutional network 1 should converge in simple training.. failed'
71+ ok = .false.
72+ end if
73+
74+ end block training1
75+
76+ training2: block
77+
78+ type (network) :: cnn
79+ real :: x(1 , 8 )
80+ real :: y(1 )
81+ real :: tolerance = 1e-4
82+ integer :: n
83+ integer , parameter :: num_iterations = 1000
84+
85+ call random_number (x)
86+ y = [0.1234567 ]
87+
88+ cnn = network([ &
89+ input(1 , 8 ), &
90+ conv1d(filters= 1 , kernel_size= 3 ), &
91+ maxpool1d(pool_size= 2 ), &
92+ conv1d(filters= 1 , kernel_size= 3 ), &
93+ dense(1 ) &
94+ ])
95+
96+ do n = 1 , num_iterations
97+ call cnn % forward(x)
98+ call cnn % backward(y)
99+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
100+ if (all (abs (cnn % predict(x) - y) < tolerance)) exit
101+ end do
102+
103+ if (.not. n <= num_iterations) then
104+ write (stderr, ' (a)' ) &
105+ ' convolutional network 2 should converge in simple training.. failed'
106+ ok = .false.
107+ end if
108+
109+ end block training2
110+
111+ training3: block
112+
113+ type (network) :: cnn
114+ real :: x(1 , 12 )
115+ real :: y(9 )
116+ real :: tolerance = 1e-4
117+ integer :: n
118+ integer , parameter :: num_iterations = 5000
119+
120+ call random_number (x)
121+ y = [0.12345 , 0.23456 , 0.34567 , 0.45678 , 0.56789 , 0.67890 , 0.78901 , 0.89012 , 0.90123 ]
122+
123+ cnn = network([ &
124+ input(1 , 12 ), &
125+ conv1d(filters= 1 , kernel_size= 3 ), & ! 1x12x12 input, 1x10x10 output
126+ maxpool1d(pool_size= 2 ), & ! 1x10x10 input, 1x5x5 output
127+ conv1d(filters= 1 , kernel_size= 3 ), & ! 1x5x5 input, 1x3x3 output
128+ dense(9 ) & ! 9 outputs
129+ ])
130+
131+ do n = 1 , num_iterations
132+ call cnn % forward(x)
133+ call cnn % backward(y)
134+ call cnn % update(optimizer= sgd(learning_rate= 1 .))
135+ if (all (abs (cnn % predict(x) - y) < tolerance)) exit
136+ end do
137+
138+ if (.not. n <= num_iterations) then
139+ write (stderr, ' (a)' ) &
140+ ' convolutional network 3 should converge in simple training.. failed'
141+ ok = .false.
142+ end if
143+
144+ end block training3
145+
146+
147+ if (ok) then
148+ print ' (a)' , ' test_conv1d_network: All tests passed.'
149+ else
150+ write (stderr, ' (a)' ) ' test_conv1d_network: One or more tests failed.'
151+ stop 1
152+ end if
153+
154+ end program test_conv1d_network
155+
0 commit comments