@@ -4,37 +4,80 @@ program test_embedding_layer
44 implicit none
55
66 logical :: ok = .true.
7- integer :: sample_input(3 ) = [2 , 1 , 3 ]
8- real :: sample_gradient(3 , 2 ) = reshape ([0.1 , 0.2 , 0.3 , 0.4 , 0.6 , 0.6 ], [3 , 2 ])
9- real :: output_flat(6 )
10- real :: expected_output_flat(6 ) = reshape ([0.3 , 0.1 , 0.5 , 0.4 , 0.2 , 0.6 ], [6 ])
11- real :: dw_flat(8 )
12- real :: expected_dw_flat(8 ) = reshape ([0.2 , 0.1 , 0.3 , 0 ., 0.6 , 0.4 , 0.6 , 0 .], [8 ])
13- type (embedding_layer) :: embedding
14-
15- embedding = embedding_layer(vocab_size= 4 , model_dimension= 2 )
16- call embedding % init([3 ])
17- embedding % weights = reshape ([0.1 , 0.3 , 0.5 , 0.7 , 0.2 , 0.4 , 0.6 , 0.8 ], [4 , 2 ])
18-
19- call embedding % forward(sample_input)
20-
21- output_flat = reshape (embedding % output, [6 ])
22- if (.not. all (output_flat.eq. expected_output_flat)) then
23- ok = .false.
24- write (stderr, ' (a)' ) ' forward returned incorrect values.. failed'
25- end if
267
27- call embedding % backward(sample_input, sample_gradient)
28- dw_flat = reshape (embedding % dw, shape (dw_flat))
29- if (.not. all (dw_flat.eq. expected_dw_flat)) then
30- ok = .false.
31- write (stderr, ' (a)' ) ' backward returned incorrect dw values.. failed'
32- end if
8+ call test_simple(ok)
9+ call test_positional(ok)
3310
3411 if (ok) then
3512 print ' (a)' , ' test_embedding_layer: All tests passed.'
3613 else
3714 write (stderr, ' (a)' ) ' test_embedding_layer: One or more tests failed.'
3815 stop 1
3916 end if
17+
18+ contains
19+ subroutine test_simple (ok )
20+ logical , intent (in out ) :: ok
21+
22+ integer :: sample_input(3 ) = [2 , 1 , 3 ]
23+ real :: sample_gradient(3 , 2 ) = reshape ([0.1 , 0.2 , 0.3 , 0.4 , 0.6 , 0.6 ], [3 , 2 ])
24+ real :: output_flat(6 )
25+ real :: expected_output_flat(6 ) = reshape ([0.3 , 0.1 , 0.5 , 0.4 , 0.2 , 0.6 ], [6 ])
26+ real :: dw_flat(8 )
27+ real :: expected_dw_flat(8 ) = reshape ([0.2 , 0.1 , 0.3 , 0 ., 0.6 , 0.4 , 0.6 , 0 .], [8 ])
28+ type (embedding_layer) :: embedding
29+
30+ embedding = embedding_layer(vocab_size= 4 , model_dimension= 2 )
31+ call embedding % init([3 ])
32+ embedding % weights = reshape ([0.1 , 0.3 , 0.5 , 0.7 , 0.2 , 0.4 , 0.6 , 0.8 ], [4 , 2 ])
33+
34+ call embedding % forward(sample_input)
35+
36+ output_flat = reshape (embedding % output, [6 ])
37+ if (.not. all (output_flat.eq. expected_output_flat)) then
38+ ok = .false.
39+ write (stderr, ' (a)' ) ' forward returned incorrect values.. failed'
40+ end if
41+
42+ call embedding % backward(sample_input, sample_gradient)
43+ dw_flat = reshape (embedding % dw, shape (dw_flat))
44+ if (.not. all (dw_flat.eq. expected_dw_flat)) then
45+ ok = .false.
46+ write (stderr, ' (a)' ) ' backward returned incorrect dw values.. failed'
47+ end if
48+ end subroutine test_simple
49+
50+ subroutine test_positional (ok )
51+ logical , intent (in out ) :: ok
52+
53+ integer :: sample_input(3 ) = [2 , 1 , 3 ]
54+ real :: output_flat(12 )
55+ real :: expected_output_flat(12 ) = reshape ([&
56+ 0.3 , 0.941471 , 1.4092975 ,&
57+ 1.3 , 0.64030236 , 0.08385316 ,&
58+ 0.3 , 0.10999984 , 0.51999867 ,&
59+ 1.3 , 1.09995 , 1.4998 &
60+ ], [12 ])
61+ type (embedding_layer) :: embedding
62+
63+ real :: theta
64+ integer :: i, pos
65+
66+ embedding = embedding_layer(vocab_size= 5 , model_dimension= 4 , positional= .true. )
67+ call embedding % init([3 ])
68+ embedding % weights = reshape ([&
69+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 ,&
70+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 ,&
71+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 ,&
72+ 0.1 , 0.3 , 0.5 , 0.7 , 0.2 &
73+ ], [5 , 4 ])
74+
75+ call embedding % forward(sample_input)
76+
77+ output_flat = reshape (embedding % output, [12 ])
78+ if (.not. all (abs (output_flat - expected_output_flat) <= (1e-06 + 1e-05 * abs (expected_output_flat)))) then
79+ ok = .false.
80+ write (stderr, ' (a)' ) ' positional encoding returned incorrect values.. failed'
81+ end if
82+ end subroutine test_positional
4083end program test_embedding_layer
0 commit comments