Skip to content

Commit 3273424

Browse files
committed
Plumbing of linear2d with input2d and linear2d
1 parent 45b27cf commit 3273424

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

src/nf/nf_layer_submodule.f90

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,20 @@ pure module subroutine backward_2d(self, previous, gradient)
6262
real, intent(in) :: gradient(:,:)
6363

6464
! Backward pass from a 2-d layer downstream currently implemented
65-
! only for dense and flatten layers
66-
! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199
65+
! only for input2d and linear2d layers
66+
select type(this_layer => self % p)
67+
68+
type is(linear2d_layer)
69+
70+
select type(prev_layer => previous % p)
71+
type is(input2d_layer)
72+
call this_layer % backward(prev_layer % output, gradient)
73+
type is(linear2d_layer)
74+
call this_layer % backward(prev_layer % output, gradient)
75+
end select
76+
77+
end select
78+
6779
end subroutine backward_2d
6880

6981

@@ -117,12 +129,6 @@ pure module subroutine backward_3d(self, previous, gradient)
117129
call this_layer % backward(prev_layer % output, gradient)
118130
end select
119131

120-
! type is(linear2d_layer)
121-
! select type(prev_layer => previous % p)
122-
! type is(input3d_layer)
123-
! call this_layer % backward(prev_layer % output, gradient)
124-
! end select
125-
126132
end select
127133

128134
end subroutine backward_3d
@@ -203,13 +209,15 @@ pure module subroutine forward(self, input)
203209
call this_layer % forward(prev_layer % output)
204210
end select
205211

206-
! type is(linear2d_layer)
207-
! select type(prev_layer => input % p)
208-
! type is(input3d_layer)
209-
! call this_layer % forward(prev_layer % output)
210-
! type is(linear2d_layer)
211-
! call this_layer % forward(prev_layer % output)
212-
! end select
212+
type is(linear2d_layer)
213+
214+
! Upstream layers permitted: input2d, linear2d
215+
select type(prev_layer => input % p)
216+
type is(input2d_layer)
217+
call this_layer % forward(prev_layer % output)
218+
type is(linear2d_layer)
219+
call this_layer % forward(prev_layer % output)
220+
end select
213221

214222
end select
215223

@@ -246,8 +254,10 @@ pure module subroutine get_output_2d(self, output)
246254

247255
type is(input2d_layer)
248256
allocate(output, source=this_layer % output)
257+
type is(linear2d_layer)
258+
allocate(output, source=this_layer % output)
249259
class default
250-
error stop '1-d output can only be read from an input1d, dense, or flatten layer.'
260+
error stop '2-d output can only be read from an input2d or linear2d layer.'
251261

252262
end select
253263

@@ -343,8 +353,8 @@ elemental module function get_num_params(self) result(num_params)
343353
num_params = 0
344354
type is (reshape3d_layer)
345355
num_params = 0
346-
! type is (linear2d_layer)
347-
! num_params = this_layer % get_num_params()
356+
type is (linear2d_layer)
357+
num_params = this_layer % get_num_params()
348358
class default
349359
error stop 'Unknown layer type.'
350360
end select
@@ -372,8 +382,8 @@ module function get_params(self) result(params)
372382
! No parameters to get.
373383
type is (reshape3d_layer)
374384
! No parameters to get.
375-
! type is (linear2d_layer)
376-
! params = this_layer % get_params()
385+
type is (linear2d_layer)
386+
params = this_layer % get_params()
377387
class default
378388
error stop 'Unknown layer type.'
379389
end select
@@ -401,8 +411,8 @@ module function get_gradients(self) result(gradients)
401411
! No gradients to get.
402412
type is (reshape3d_layer)
403413
! No gradients to get.
404-
! type is (linear2d_layer)
405-
! gradients = this_layer % get_gradients()
414+
type is (linear2d_layer)
415+
gradients = this_layer % get_gradients()
406416
class default
407417
error stop 'Unknown layer type.'
408418
end select
@@ -450,6 +460,9 @@ module subroutine set_params(self, params)
450460
type is (conv2d_layer)
451461
call this_layer % set_params(params)
452462

463+
type is (linear2d_layer)
464+
call this_layer % set_params(params)
465+
453466
type is (maxpool2d_layer)
454467
! No parameters to set.
455468
write(stderr, '(a)') 'Warning: calling set_params() ' &
@@ -468,8 +481,6 @@ module subroutine set_params(self, params)
468481
class default
469482
error stop 'Unknown layer type.'
470483

471-
! type is (linear2d_layer)
472-
! call this_layer % set_params(params)
473484
end select
474485

475486
end subroutine set_params

src/nf/nf_network_submodule.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ module subroutine backward(self, output, loss)
149149
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
150150
type is(reshape3d_layer)
151151
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
152-
! type is(linear2d_layer)
153-
! call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
152+
type is(linear2d_layer)
153+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
154154
end select
155155
end if
156156

0 commit comments

Comments
 (0)