1+ module nf_reshape_layer_generalized
2+
3+ ! ! This module provides the concrete reshape layer type.
4+ ! ! It is used internally by the layer type.
5+ ! ! It is not intended to be used directly by the user.
6+
7+ use nf_base_layer, only: base_layer
8+
9+ implicit none
10+
11+ private
12+ public :: reshape_generalized_layer
13+
14+ type, extends(base_layer) :: reshape_generalized_layer
15+
16+ ! ! Concrete implementation of a reshape layer type
17+ ! ! It implements reshaping for arbitrary ranks.
18+
19+ integer , allocatable :: input_shape(:)
20+ integer , allocatable :: output_shape(:)
21+ real , allocatable :: gradient(:)
22+ real , allocatable :: output(:)
23+
24+ contains
25+
26+ procedure :: backward
27+ procedure :: forward
28+ procedure :: init
29+
30+ end type reshape_generalized_layer
31+
32+ interface reshape_generalized_layer
33+ pure module function reshape_layer_cons(output_shape) result(res)
34+ ! ! This function returns the `reshape_layer` instance.
35+ integer , intent (in ) :: output_shape(:)
36+ ! ! The shape of the output
37+ type (reshape_generalized_layer) :: res
38+ ! ! reshape_layer instance
39+ end function reshape_layer_cons
40+ end interface reshape_generalized_layer
41+
42+ interface
43+
44+ pure module subroutine backward(self, input, gradient)
45+ ! ! Apply the backward pass for the reshape layer.
46+ ! ! This is just flattening to a rank-1 array.
47+ class(reshape_generalized_layer), intent (in out ) :: self
48+ ! ! Dense layer instance
49+ real , intent (in ) :: input(:)
50+ ! ! Input from the previous layer
51+ real , intent (in ) :: gradient(..)
52+ ! ! Gradient from the next layer
53+ end subroutine backward
54+
55+ pure module subroutine forward(self, input)
56+ ! ! Apply the forward pass for the reshape layer.
57+ ! ! This is reshaping from input rank to output rank.
58+ class(reshape_generalized_layer), intent (in out ) :: self
59+ ! ! Dense layer instance
60+ real , intent (in ) :: input(:)
61+ ! ! Input from the previous layer
62+ end subroutine forward
63+
64+ module subroutine init (self , input_shape )
65+ ! ! Initialize the layer data structures.
66+ ! !
67+ ! ! This is a deferred procedure from the `base_layer` abstract type.
68+ class(reshape_generalized_layer), intent (in out ) :: self
69+ ! ! Dense layer instance
70+ integer , intent (in ) :: input_shape(:)
71+ ! ! Shape of the input layer
72+ end subroutine init
73+
74+ end interface
75+
76+ end module nf_reshape_layer_generalized
77+
0 commit comments