Skip to content

Commit b39e6da

Browse files
milancurcicOneAdder
authored andcommitted
Plumbing of linear2d with input2d and linear2d
1 parent f1a01a6 commit b39e6da

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

src/nf/nf_layer_submodule.f90

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,20 @@ pure module subroutine backward_2d(self, previous, gradient)
6464
real, intent(in) :: gradient(:,:)
6565

6666
! Backward pass from a 2-d layer downstream currently implemented
67-
! only for dense and flatten layers
68-
! CURRENTLY NO LAYERS, tbd: pull/197 and pull/199
67+
! only for input2d and linear2d layers
68+
select type(this_layer => self % p)
69+
70+
type is(linear2d_layer)
71+
72+
select type(prev_layer => previous % p)
73+
type is(input2d_layer)
74+
call this_layer % backward(prev_layer % output, gradient)
75+
type is(linear2d_layer)
76+
call this_layer % backward(prev_layer % output, gradient)
77+
end select
78+
79+
end select
80+
6981
end subroutine backward_2d
7082

7183

@@ -119,12 +131,6 @@ pure module subroutine backward_3d(self, previous, gradient)
119131
call this_layer % backward(prev_layer % output, gradient)
120132
end select
121133

122-
! type is(linear2d_layer)
123-
! select type(prev_layer => previous % p)
124-
! type is(input3d_layer)
125-
! call this_layer % backward(prev_layer % output, gradient)
126-
! end select
127-
128134
end select
129135

130136
end subroutine backward_3d
@@ -207,13 +213,15 @@ pure module subroutine forward(self, input)
207213
call this_layer % forward(prev_layer % output)
208214
end select
209215

210-
! type is(linear2d_layer)
211-
! select type(prev_layer => input % p)
212-
! type is(input3d_layer)
213-
! call this_layer % forward(prev_layer % output)
214-
! type is(linear2d_layer)
215-
! call this_layer % forward(prev_layer % output)
216-
! end select
216+
type is(linear2d_layer)
217+
218+
! Upstream layers permitted: input2d, linear2d
219+
select type(prev_layer => input % p)
220+
type is(input2d_layer)
221+
call this_layer % forward(prev_layer % output)
222+
type is(linear2d_layer)
223+
call this_layer % forward(prev_layer % output)
224+
end select
217225

218226
end select
219227

@@ -250,8 +258,10 @@ pure module subroutine get_output_2d(self, output)
250258

251259
type is(input2d_layer)
252260
allocate(output, source=this_layer % output)
261+
type is(linear2d_layer)
262+
allocate(output, source=this_layer % output)
253263
class default
254-
error stop '1-d output can only be read from an input1d, dense, or flatten layer.'
264+
error stop '2-d output can only be read from an input2d or linear2d layer.'
255265

256266
end select
257267

@@ -347,8 +357,8 @@ elemental module function get_num_params(self) result(num_params)
347357
num_params = 0
348358
type is (reshape3d_layer)
349359
num_params = 0
350-
! type is (linear2d_layer)
351-
! num_params = this_layer % get_num_params()
360+
type is (linear2d_layer)
361+
num_params = this_layer % get_num_params()
352362
class default
353363
error stop 'Unknown layer type.'
354364
end select
@@ -376,8 +386,8 @@ module function get_params(self) result(params)
376386
! No parameters to get.
377387
type is (reshape3d_layer)
378388
! No parameters to get.
379-
! type is (linear2d_layer)
380-
! params = this_layer % get_params()
389+
type is (linear2d_layer)
390+
params = this_layer % get_params()
381391
class default
382392
error stop 'Unknown layer type.'
383393
end select
@@ -405,8 +415,8 @@ module function get_gradients(self) result(gradients)
405415
! No gradients to get.
406416
type is (reshape3d_layer)
407417
! No gradients to get.
408-
! type is (linear2d_layer)
409-
! gradients = this_layer % get_gradients()
418+
type is (linear2d_layer)
419+
gradients = this_layer % get_gradients()
410420
class default
411421
error stop 'Unknown layer type.'
412422
end select
@@ -454,6 +464,9 @@ module subroutine set_params(self, params)
454464
type is (conv2d_layer)
455465
call this_layer % set_params(params)
456466

467+
type is (linear2d_layer)
468+
call this_layer % set_params(params)
469+
457470
type is (maxpool2d_layer)
458471
! No parameters to set.
459472
write(stderr, '(a)') 'Warning: calling set_params() ' &
@@ -472,8 +485,6 @@ module subroutine set_params(self, params)
472485
class default
473486
error stop 'Unknown layer type.'
474487

475-
! type is (linear2d_layer)
476-
! call this_layer % set_params(params)
477488
end select
478489

479490
end subroutine set_params

src/nf/nf_network_submodule.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ module subroutine backward(self, output, loss)
157157

158158
type is(reshape3d_layer)
159159
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
160-
! type is(linear2d_layer)
161-
! call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
160+
type is(linear2d_layer)
161+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
162162
end select
163163
end if
164164

0 commit comments

Comments
 (0)