Skip to content

Commit b4e2303

Browse files
committed
little modifies
1 parent 5e8eaed commit b4e2303

File tree

4 files changed

+104
-87
lines changed

4 files changed

+104
-87
lines changed

src/nf.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module nf
33
use nf_datasets_mnist, only: label_digits, load_mnist
44
use nf_layer, only: layer
55
use nf_layer_constructors, only: &
6-
conv2d, dense, flatten, input, maxpool2d, reshape, locally_connected_1d
6+
conv2d, dense, flatten, input, maxpool2d, reshape, reshape_generalized, locally_connected_1d
77
use nf_loss, only: mse, quadratic
88
use nf_metrics, only: corr, maxabs
99
use nf_network, only: network

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,16 @@ module function reshape(output_shape) result(res)
162162
end function reshape
163163

164164
module function reshape_generalized(output_shape) result(res)
165-
integer, intent(in) :: output_shape(:)
165+
integer, intent(in) :: output_shape(:) !! Always treat as an array
166166
type(layer) :: res
167-
167+
168168
res % name = 'reshape_generalized'
169169
res % layer_shape = output_shape
170170

171171
allocate(res % p, source=reshape_generalized_layer(output_shape))
172-
172+
173173
end function reshape_generalized
174+
175+
174176

175177
end submodule nf_layer_constructors_submodule

src/nf/nf_reshape_generalized.f90

Lines changed: 74 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,76 @@
11
module nf_reshape_layer_generalized
22

3-
!! This module provides the concrete reshape layer type.
4-
!! It is used internally by the layer type.
5-
!! It is not intended to be used directly by the user.
6-
7-
use nf_base_layer, only: base_layer
8-
9-
implicit none
10-
11-
private
12-
public :: reshape_generalized_layer
13-
14-
type, extends(base_layer) :: reshape_generalized_layer
15-
16-
!! Concrete implementation of a reshape layer type
17-
!! It implements reshaping for arbitrary ranks.
18-
19-
integer, allocatable :: input_shape(:)
20-
integer, allocatable :: output_shape(:)
21-
real, allocatable :: gradient(:)
22-
real, allocatable :: output(:)
23-
24-
contains
25-
26-
procedure :: backward
27-
procedure :: forward
28-
procedure :: init
29-
30-
end type reshape_generalized_layer
31-
32-
interface reshape_generalized_layer
33-
pure module function reshape_layer_cons(output_shape) result(res)
34-
!! This function returns the `reshape_layer` instance.
35-
integer, intent(in) :: output_shape(:)
36-
!! The shape of the output
37-
type(reshape_generalized_layer) :: res
38-
!! reshape_layer instance
39-
end function reshape_layer_cons
40-
end interface reshape_generalized_layer
41-
42-
interface
43-
44-
pure module subroutine backward(self, input, gradient)
45-
!! Apply the backward pass for the reshape layer.
46-
!! This is just flattening to a rank-1 array.
47-
class(reshape_generalized_layer), intent(in out) :: self
48-
!! Dense layer instance
49-
real, intent(in) :: input(:)
50-
!! Input from the previous layer
51-
real, intent(in) :: gradient(..)
52-
!! Gradient from the next layer
53-
end subroutine backward
54-
55-
pure module subroutine forward(self, input)
56-
!! Apply the forward pass for the reshape layer.
57-
!! This is reshaping from input rank to output rank.
58-
class(reshape_generalized_layer), intent(in out) :: self
59-
!! Dense layer instance
60-
real, intent(in) :: input(:)
61-
!! Input from the previous layer
62-
end subroutine forward
63-
64-
module subroutine init(self, input_shape)
65-
!! Initialize the layer data structures.
66-
!!
67-
!! This is a deferred procedure from the `base_layer` abstract type.
68-
class(reshape_generalized_layer), intent(in out) :: self
69-
!! Dense layer instance
70-
integer, intent(in) :: input_shape(:)
71-
!! Shape of the input layer
72-
end subroutine init
73-
74-
end interface
75-
76-
end module nf_reshape_layer_generalized
77-
3+
!! This module provides the concrete reshape layer type.
4+
!! It is used internally by the layer type.
5+
!! It is not intended to be used directly by the user.
6+
7+
use nf_base_layer, only: base_layer
8+
9+
implicit none
10+
11+
private
12+
public :: reshape_generalized_layer
13+
14+
type, extends(base_layer) :: reshape_generalized_layer
15+
16+
!! Concrete implementation of a reshape layer type
17+
!! It implements reshaping for arbitrary ranks.
18+
19+
integer, allocatable :: input_shape(:)
20+
integer, allocatable :: output_shape(:)
21+
real, allocatable :: gradient(:)
22+
real, allocatable :: output(:)
23+
24+
contains
25+
26+
procedure :: backward
27+
procedure :: forward
28+
procedure :: init
29+
30+
end type reshape_generalized_layer
31+
32+
interface reshape_generalized_layer
33+
pure module function reshape_layer_cons(output_shape) result(res)
34+
!! This function returns the `reshape_layer` instance.
35+
integer, intent(in) :: output_shape(:)
36+
!! The shape of the output
37+
type(reshape_generalized_layer) :: res
38+
!! reshape_layer instance
39+
end function reshape_layer_cons
40+
end interface reshape_generalized_layer
41+
42+
interface
43+
44+
pure module subroutine backward(self, input, gradient)
45+
!! Apply the backward pass for the reshape layer.
46+
!! This is just flattening to a rank-1 array.
47+
class(reshape_generalized_layer), intent(in out) :: self
48+
!! Dense layer instance
49+
real, intent(in) :: input(:)
50+
!! Input from the previous layer
51+
real, intent(in) :: gradient(..)
52+
!! Gradient from the next layer
53+
end subroutine backward
54+
55+
pure module subroutine forward(self, input)
56+
!! Apply the forward pass for the reshape layer.
57+
!! This is reshaping from input rank to output rank.
58+
class(reshape_generalized_layer), intent(in out) :: self
59+
!! Dense layer instance
60+
real, intent(in) :: input(:)
61+
!! Input from the previous layer
62+
end subroutine forward
63+
64+
module subroutine init(self, input_shape)
65+
!! Initialize the layer data structures.
66+
!! This is a deferred procedure from the `base_layer` abstract type.
67+
class(reshape_generalized_layer), intent(in out) :: self
68+
!! Dense layer instance
69+
integer, intent(in) :: input_shape(:)
70+
!! Shape of the input layer
71+
end subroutine init
72+
73+
end interface
74+
75+
end module nf_reshape_layer_generalized
76+

src/nf/nf_reshape_generalized_submodule.f90

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,18 @@
99
pure module function reshape_layer_cons(output_shape) result(res)
1010
integer, intent(in) :: output_shape(:)
1111
type(reshape_generalized_layer) :: res
12-
allocate(res % output_shape(size(output_shape)))
13-
res % output_shape = output_shape
12+
13+
! Check if output_shape is scalar (size 1)
14+
if (size(output_shape) == 1) then
15+
allocate(res % output_shape(1))
16+
res % output_shape = output_shape
17+
else
18+
allocate(res % output_shape(size(output_shape)))
19+
res % output_shape = output_shape
20+
end if
1421
end function reshape_layer_cons
1522

23+
1624
pure module subroutine backward(self, input, gradient)
1725
class(reshape_generalized_layer), intent(in out) :: self
1826
real, intent(in) :: input(:)
@@ -56,14 +64,22 @@ module subroutine init(self, input_shape)
5664

5765
self % input_shape = input_shape
5866

59-
! Allocate gradient buffer based on input size
60-
allocate(self % gradient(product(input_shape)))
67+
!! Handle scalar input (size 1) or non-scalar
68+
if (size(input_shape) == 1) then
69+
allocate(self % gradient(1))
70+
else
71+
allocate(self % gradient(product(input_shape)))
72+
end if
6173
self % gradient = 0
6274

63-
! Allocate output buffer based on output_shape
64-
allocate(self % output(product(self % output_shape)))
75+
!! Handle scalar output_shape (size 1) or non-scalar
76+
if (size(self % output_shape) == 1) then
77+
allocate(self % output(1))
78+
else
79+
allocate(self % output(product(self % output_shape)))
80+
end if
6581
self % output = 0
66-
67-
end subroutine init
82+
end subroutine init
83+
6884

6985
end submodule nf_reshape_layer_generalized_submodule

0 commit comments

Comments
 (0)