Skip to content

Commit 5e8eaed

Browse files
committed
Adding a generalized reshape function with new files
1 parent 0fbbda4 commit 5e8eaed

8 files changed

+169
-4
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ add_library(neural-fortran
4949
src/nf/nf_parallel.f90
5050
src/nf/nf_parallel_submodule.f90
5151
src/nf/nf_random.f90
52+
src/nf/nf_reshape_generalized.f90
53+
src/nf/nf_reshape_generalized_submodule.f90
5254
src/nf/nf_reshape_layer.f90
5355
src/nf/nf_reshape_layer_submodule.f90
5456
src/nf/io/nf_io_binary.f90

example/cnn_mnist_1d.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ program cnn_mnist
2020

2121
net = network([ &
2222
input(784), &
23-
reshape([1,784]), &
23+
reshape([1,28,28]), &
2424
locally_connected_1d(filters=8, kernel_size=2, activation=relu()), &
2525
dense(10, activation=softmax()) &
2626
])

src/nf/nf_layer_constructors.f90

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module nf_layer_constructors
88
implicit none
99

1010
private
11-
public :: conv2d, dense, flatten, input, locally_connected_1d, maxpool2d, reshape
11+
public :: conv2d, dense, flatten, input, locally_connected_1d, maxpool2d, reshape, reshape_generalized
1212

1313
interface input
1414

@@ -193,6 +193,12 @@ module function reshape(output_shape) result(res)
193193
!! Resulting layer instance
194194
end function reshape
195195

196+
module function reshape_generalized(output_shape) result(res)
197+
integer, intent(in) :: output_shape(:)
198+
type(layer) :: res
199+
200+
end function reshape_generalized
201+
196202
end interface
197203

198204
end module nf_layer_constructors

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use nf_locally_connected_1d_layer, only: locally_connected_1d_layer
1010
use nf_maxpool2d_layer, only: maxpool2d_layer
1111
use nf_reshape_layer, only: reshape3d_layer
12+
use nf_reshape_layer_generalized, only: reshape_generalized_layer
1213
use nf_activation, only: activation_function, relu, sigmoid
1314

1415
implicit none
@@ -160,4 +161,15 @@ module function reshape(output_shape) result(res)
160161

161162
end function reshape
162163

164+
module function reshape_generalized(output_shape) result(res)
165+
integer, intent(in) :: output_shape(:)
166+
type(layer) :: res
167+
168+
res % name = 'reshape_generalized'
169+
res % layer_shape = output_shape
170+
171+
allocate(res % p, source=reshape_generalized_layer(output_shape))
172+
173+
end function reshape_generalized
174+
163175
end submodule nf_layer_constructors_submodule

src/nf/nf_layer_submodule.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
use nf_flatten_layer, only: flatten_layer
77
use nf_input1d_layer, only: input1d_layer
88
use nf_input3d_layer, only: input3d_layer
9+
use nf_locally_connected_1d_layer, only: locally_connected_1d_layer
910
use nf_maxpool2d_layer, only: maxpool2d_layer
1011
use nf_reshape_layer, only: reshape3d_layer
1112
use nf_optimizers, only: optimizer_base_type

src/nf/nf_reshape_generalized.f90

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
module nf_reshape_layer_generalized
2+
3+
!! This module provides the concrete reshape layer type.
4+
!! It is used internally by the layer type.
5+
!! It is not intended to be used directly by the user.
6+
7+
use nf_base_layer, only: base_layer
8+
9+
implicit none
10+
11+
private
12+
public :: reshape_generalized_layer
13+
14+
type, extends(base_layer) :: reshape_generalized_layer
15+
16+
!! Concrete implementation of a reshape layer type
17+
!! It implements reshaping for arbitrary ranks.
18+
19+
integer, allocatable :: input_shape(:)
20+
integer, allocatable :: output_shape(:)
21+
real, allocatable :: gradient(:)
22+
real, allocatable :: output(:)
23+
24+
contains
25+
26+
procedure :: backward
27+
procedure :: forward
28+
procedure :: init
29+
30+
end type reshape_generalized_layer
31+
32+
interface reshape_generalized_layer
33+
pure module function reshape_layer_cons(output_shape) result(res)
34+
!! This function returns the `reshape_layer` instance.
35+
integer, intent(in) :: output_shape(:)
36+
!! The shape of the output
37+
type(reshape_generalized_layer) :: res
38+
!! reshape_layer instance
39+
end function reshape_layer_cons
40+
end interface reshape_generalized_layer
41+
42+
interface
43+
44+
pure module subroutine backward(self, input, gradient)
45+
!! Apply the backward pass for the reshape layer.
46+
!! This is just flattening to a rank-1 array.
47+
class(reshape_generalized_layer), intent(in out) :: self
48+
!! Dense layer instance
49+
real, intent(in) :: input(:)
50+
!! Input from the previous layer
51+
real, intent(in) :: gradient(..)
52+
!! Gradient from the next layer
53+
end subroutine backward
54+
55+
pure module subroutine forward(self, input)
56+
!! Apply the forward pass for the reshape layer.
57+
!! This is reshaping from input rank to output rank.
58+
class(reshape_generalized_layer), intent(in out) :: self
59+
!! Dense layer instance
60+
real, intent(in) :: input(:)
61+
!! Input from the previous layer
62+
end subroutine forward
63+
64+
module subroutine init(self, input_shape)
65+
!! Initialize the layer data structures.
66+
!!
67+
!! This is a deferred procedure from the `base_layer` abstract type.
68+
class(reshape_generalized_layer), intent(in out) :: self
69+
!! Dense layer instance
70+
integer, intent(in) :: input_shape(:)
71+
!! Shape of the input layer
72+
end subroutine init
73+
74+
end interface
75+
76+
end module nf_reshape_layer_generalized
77+
src/nf/nf_reshape_generalized_submodule.f90

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
submodule(nf_reshape_layer_generalized) nf_reshape_layer_generalized_submodule

  ! Note: nf_base_layer is available by host association from the parent
  ! module; the original `use` statement here was redundant and removed.

  implicit none

contains

  pure module function reshape_layer_cons(output_shape) result(res)
    !! Construct a generalized reshape layer with the requested output shape.
    integer, intent(in) :: output_shape(:)
    type(reshape_generalized_layer) :: res
    ! Reallocation-on-assignment sizes the component; no explicit
    ! allocate is needed.
    res % output_shape = output_shape
  end function reshape_layer_cons

  pure module subroutine backward(self, input, gradient)
    !! Flatten the incoming gradient (any rank up to 3) into the layer's
    !! rank-1 gradient buffer.
    class(reshape_generalized_layer), intent(in out) :: self
    real, intent(in) :: input(:)
    real, intent(in) :: gradient(..) ! assumed-rank gradient

    ! Dispatch on the actual rank of the gradient; reallocation-on-assignment
    ! sizes self % gradient to match in every branch.
    select rank (gradient)
      rank (0)
        self % gradient = [gradient]
      rank (1)
        self % gradient = gradient
      rank (2)
        self % gradient = reshape(gradient, [size(gradient)])
      rank (3)
        self % gradient = reshape(gradient, [size(gradient)])
      rank default
        error stop "Unsupported gradient rank in reshape layer"
    end select

  end subroutine backward

  pure module subroutine forward(self, input)
    !! Store the (already flat) input in the layer's flattened output buffer.
    class(reshape_generalized_layer), intent(in out) :: self
    real, intent(in) :: input(:)

    ! Whole-array assignment replaces the original element-by-element copy.
    ! It also (re)allocates the buffer to size(input), fixing a potential
    ! overrun when init had sized the buffer from output_shape but the
    ! incoming input was larger.
    self % output = input

  end subroutine forward

  module subroutine init(self, input_shape)
    !! Allocate and zero the gradient and output buffers.
    !! Deferred procedure from the `base_layer` abstract type.
    class(reshape_generalized_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    self % input_shape = input_shape

    ! Guard against repeated initialization: an unconditional allocate of an
    ! already-allocated component is a runtime error.
    if (allocated(self % gradient)) deallocate(self % gradient)
    allocate(self % gradient(product(input_shape)), source=0.0)

    if (allocated(self % output)) deallocate(self % output)
    allocate(self % output(product(self % output_shape)), source=0.0)

  end subroutine init

end submodule nf_reshape_layer_generalized_submodule

src/nf/nf_reshape_layer_submodule.f90

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,4 @@ module subroutine init(self, input_shape)
4848

4949
end subroutine init
5050

51-
52-
5351
end submodule nf_reshape_layer_submodule

0 commit comments

Comments
 (0)