@@ -70,7 +70,18 @@ function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol=
7070end
7171
7272# # Residual Block
73- function ResidualBlockV1 (
"""
    ResidualBlock(conv1, conv2, dropout, downsample, norm1, norm2, norm3)

Container layer that stores the sub-layers of a residual block. All seven fields are
named in the field tuple passed to `Lux.AbstractExplicitContainerLayer`.

# Fields
- `conv1`, `conv2`: the two convolution layers, applied in that order.
- `dropout`: layer applied to the output of `conv2` on the main branch.
- `downsample`: layer applied to the output of `conv2` on the shortcut branch.
- `norm1`, `norm2`, `norm3`: normalization layers applied after `conv1`, after the
  dropout branch, and on the final activation respectively.
"""
struct ResidualBlock{TC1,TC2,TDr,TDo,TN1,TN2,TN3} <:
       Lux.AbstractExplicitContainerLayer{(:conv1, :conv2, :dropout, :downsample, :norm1, :norm2, :norm3)}
    conv1::TC1
    conv2::TC2
    dropout::TDr
    downsample::TDo
    norm1::TN1
    norm2::TN2
    norm3::TN3
end
83+
84+ function ResidualBlock (
7485 mapping;
7586 deq_expand:: Int = 5 ,
7687 num_gn_groups:: Int = 4 ,
@@ -92,74 +103,81 @@ function ResidualBlockV1(
92103 conv1, conv2
93104 end
94105
95- # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats)
96- # gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
97- # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
98- gn1 = BatchNorm (inner_planes, relu; affine= gn_affine, track_stats= gn_track_stats)
99- gn2 = BatchNorm (outplanes; affine= gn_affine, track_stats= gn_track_stats)
100- gn3 = BatchNorm (outplanes; affine= gn_affine, track_stats= gn_track_stats)
106+ # norm1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats)
107+ # norm2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
108+ # norm3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
109+ norm1 = BatchNorm (inner_planes, relu; affine= gn_affine, track_stats= gn_track_stats)
110+ norm2 = BatchNorm (outplanes; affine= gn_affine, track_stats= gn_track_stats)
111+ norm3 = BatchNorm (outplanes; affine= gn_affine, track_stats= gn_track_stats)
101112
102- dropout = iszero (dropout_rate) ? NoOpLayer () : VariationalHiddenDropout (dropout_rate)
113+ dropout = VariationalHiddenDropout (dropout_rate)
103114
104- return Chain (
105- Parallel (
106- reassociate, # Reassociate and Merge
107- Chain (conv1, gn1, conv2, BranchLayer (downsample, dropout)), # For x
108- NoOpLayer (), # For injection
109- ),
110- Parallel (
111- addrelu,
112- NoOpLayer (), # For y1
113- Chain (
114- WrappedFunction (addtuple), # Since injection could be a scalar
115- gn2,
116- ), # For (y2, injection)
115+ return ResidualBlock (conv1, conv2, dropout, downsample, norm1, norm2, norm3)
116+ end
117+
"""
    (rb::ResidualBlock)((x, y), ps, st)

Run the residual block on `x`, adding `y` to the dropout branch before `norm2`.

With `h = conv2(norm1(conv1(x)))` the result is
`norm3(relu.(norm2(dropout(h) .+ y) .+ downsample(h)))`.

# Arguments
- `(x, y)`: a 2-tuple of arrays. The two arrays may have different concrete types;
  they only need to be broadcast-compatible after `x` has gone through the block.
- `ps`, `st`: parameters and states with one entry per sub-layer field of `rb`.

# Returns
A tuple `(output, new_state)` where `new_state` is a `NamedTuple` with the keys
`conv1`, `conv2`, `dropout`, `downsample`, `norm1`, `norm2`, `norm3`.
"""
function (rb::ResidualBlock)(
    (x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st
)
    # NOTE: `NTuple{2,<:AbstractArray}` would require `x` and `y` to share one concrete
    # type (diagonal rule); the explicit 2-tuple form accepts mixed array types.
    h, st_conv1 = rb.conv1(x, ps.conv1, st.conv1)
    h, st_norm1 = rb.norm1(h, ps.norm1, st.norm1)
    h, st_conv2 = rb.conv2(h, ps.conv2, st.conv2)

    # Both branches consume the output of `conv2`.
    shortcut, st_downsample = rb.downsample(h, ps.downsample, st.downsample)
    dropped, st_dropout = rb.dropout(h, ps.dropout, st.dropout)

    injected, st_norm2 = rb.norm2(dropped .+ y, ps.norm2, st.norm2)

    out, st_norm3 = rb.norm3(relu.(injected .+ shortcut), ps.norm3, st.norm3)

    return (
        out,
        (
            conv1=st_conv1,
            conv2=st_conv2,
            dropout=st_dropout,
            downsample=st_downsample,
            norm1=st_norm1,
            norm2=st_norm2,
            norm3=st_norm3,
        ),
    )
end
121145
122- function ResidualBlockV2 (
123- mapping;
124- deq_expand:: Int = 1 ,
125- num_gn_groups:: Int = 4 ,
126- downsample= NoOpLayer (),
127- n_big_kernels:: Int = 0 ,
128- dropout_rate:: Real = 0.0f0 ,
129- gn_affine:: Bool = true ,
130- weight_norm:: Bool = true ,
131- gn_track_stats:: Bool = false ,
132- )
133- inplanes, outplanes = mapping
134- inner_planes = outplanes * deq_expand
135- conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias= false )
136- conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias= false )
137-
138- conv1, conv2 = if weight_norm
139- WeightNorm (conv1, (:weight ,), (4 ,)), WeightNorm (conv2, (:weight ,), (4 ,))
140- else
141- conv1, conv2
142- end
143-
144- # gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats)
145- # gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
146- # gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
147- gn1 = BatchNorm (inner_planes, relu; affine= gn_affine, track_stats= gn_track_stats)
148- gn2 = BatchNorm (outplanes; affine= gn_affine, track_stats= gn_track_stats)
149- gn3 = BatchNorm (outplanes; affine= gn_affine, track_stats= gn_track_stats)
150-
151- dropout = iszero (dropout_rate) ? NoOpLayer () : VariationalHiddenDropout (dropout_rate)
152-
153- return Chain (
154- conv1,
155- gn1,
156- conv2,
157- Parallel (addrelu, downsample, Chain (dropout, gn2)),
158- gn3,
"""
    (rb::ResidualBlock)(x::AbstractArray, ps, st)

Run the residual block on a single array `x` (no second array is added).

With `h = conv2(norm1(conv1(x)))` the result is
`norm3(relu.(norm2(dropout(h)) .+ downsample(h)))`.

Returns a tuple `(output, new_state)` where `new_state` is a `NamedTuple` with the
keys `conv1`, `conv2`, `dropout`, `downsample`, `norm1`, `norm2`, `norm3`.
"""
function (rb::ResidualBlock)(x::AbstractArray, ps, st)
    h, conv1_state = rb.conv1(x, ps.conv1, st.conv1)
    h, norm1_state = rb.norm1(h, ps.norm1, st.norm1)
    h, conv2_state = rb.conv2(h, ps.conv2, st.conv2)

    # Shortcut and main branch both start from the output of `conv2`.
    shortcut, downsample_state = rb.downsample(h, ps.downsample, st.downsample)

    main, dropout_state = rb.dropout(h, ps.dropout, st.dropout)
    main, norm2_state = rb.norm2(main, ps.norm2, st.norm2)

    out, norm3_state = rb.norm3(relu.(main .+ shortcut), ps.norm3, st.norm3)

    new_state = (
        conv1=conv1_state,
        conv2=conv2_state,
        dropout=dropout_state,
        downsample=downsample_state,
        norm1=norm1_state,
        norm2=norm2_state,
        norm3=norm3_state,
    )
    return (out, new_state)
end
161172
162- function BottleneckBlockV1 (mapping:: Pair , expansion:: Int = 4 ; bn_track_stats:: Bool = false , bn_affine:: Bool = true )
173+ # Bottleneck Block
"""
    BottleneckBlock(rescale, conv, mapping)

Container layer that stores the three sub-layers of a bottleneck block. `rescale`,
`conv` and `mapping` are all named in the field tuple passed to
`Lux.AbstractExplicitContainerLayer`.
"""
struct BottleneckBlock{TR,TC,TM} <:
       Lux.AbstractExplicitContainerLayer{(:rescale, :conv, :mapping)}
    rescale::TR
    conv::TC
    mapping::TM
end
179+
180+ function BottleneckBlock (mapping:: Pair , expansion:: Int = 4 ; bn_track_stats:: Bool = true , bn_affine:: Bool = true )
163181 rescale = if first (mapping) != last (mapping) * expansion
164182 Chain (
165183 conv1x1 (first (mapping) => last (mapping) * expansion),
@@ -169,48 +187,48 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool
169187 NoOpLayer ()
170188 end
171189
172- return Chain (
173- Parallel (reassociate, BranchLayer (rescale, conv1x1 (mapping)), NoOpLayer ()),
174- Parallel (
175- addrelu,
176- NoOpLayer (),
177- Chain (
178- WrappedFunction (addtuple), # Since injection could be a scalar
179- Chain (
180- BatchNorm (last (mapping), relu; affine= bn_affine, track_stats= bn_track_stats),
181- conv3x3 (last (mapping) => last (mapping)),
182- BatchNorm (last (mapping), relu; track_stats= bn_track_stats, affine= bn_affine),
183- conv1x1 (last (mapping) => last (mapping) * expansion),
184- BatchNorm (last (mapping) * expansion; track_stats= bn_track_stats, affine= bn_affine),
185- ),
186- ),
187- ),
190+ return BottleneckBlock (
191+ rescale,
192+ conv1x1 (mapping),
193+ Chain (
194+ BatchNorm (last (mapping), relu; affine= bn_affine, track_stats= bn_track_stats),
195+ conv3x3 (last (mapping) => last (mapping)),
196+ BatchNorm (last (mapping), relu; track_stats= bn_track_stats, affine= bn_affine),
197+ conv1x1 (last (mapping) => last (mapping) * expansion),
198+ BatchNorm (last (mapping) * expansion; track_stats= bn_track_stats, affine= bn_affine)
199+ )
188200 )
189201end
190202
191- function BottleneckBlockV2 (mapping:: Pair , expansion:: Int = 4 ; bn_track_stats:: Bool = true , bn_affine:: Bool = true )
192- rescale = if first (mapping) != last (mapping) * expansion
193- Chain (
194- conv1x1 (first (mapping) => last (mapping) * expansion),
195- BatchNorm (last (mapping) * expansion; track_stats= bn_track_stats, affine= bn_affine),
"""
    (bn::BottleneckBlock)((x, y), ps, st)

Run the bottleneck block on `x`, adding `y` to the output of `conv` before `mapping`.

The result is `relu.(mapping(y .+ conv(x)) .+ rescale(x))`.

# Arguments
- `(x, y)`: a 2-tuple of arrays. The two arrays may have different concrete types;
  `y` only needs to be broadcast-compatible with `conv(x)`.
- `ps`, `st`: parameters and states with the entries `rescale`, `conv`, `mapping`.

# Returns
A tuple `(output, new_state)` where `new_state` is a `NamedTuple` with the keys
`rescale`, `conv`, `mapping`.
"""
function (bn::BottleneckBlock)(
    (x, y)::Tuple{<:AbstractArray,<:AbstractArray}, ps, st
)
    x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale)
    # `conv` must consume the block input `x`, not the rescaled shortcut `x_r`: the two
    # branches run in parallel on the same input, and `x_r` generally has a different
    # channel count than `conv` expects.
    x_m, st_conv = bn.conv(x, ps.conv, st.conv)

    x_m, st_mapping = bn.mapping(y .+ x_m, ps.mapping, st.mapping)

    return (
        relu.(x_m .+ x_r),
        (
            rescale=st_rescale,
            conv=st_conv,
            mapping=st_mapping,
        ),
    )
end
200219
201- return Chain (
202- Parallel (
203- addrelu,
204- rescale,
205- Chain (
206- conv1x1 (mapping),
207- BatchNorm (last (mapping), relu; affine= bn_affine, track_stats= bn_track_stats),
208- conv3x3 (last (mapping) => last (mapping)),
209- BatchNorm (last (mapping), relu; track_stats= bn_track_stats, affine= bn_affine),
210- conv1x1 (last (mapping) => last (mapping) * expansion),
211- BatchNorm (last (mapping) * expansion; track_stats= bn_track_stats, affine= bn_affine),
212- ),
213- ),
"""
    (bn::BottleneckBlock)(x::AbstractArray, ps, st)

Run the bottleneck block on a single array `x`.

The result is `relu.(mapping(conv(x)) .+ rescale(x))`.

Returns a tuple `(output, new_state)` where `new_state` is a `NamedTuple` with the
keys `rescale`, `conv`, `mapping`.
"""
function (bn::BottleneckBlock)(x::AbstractArray, ps, st)
    x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale)
    # `conv` must consume the block input `x`, not the rescaled shortcut `x_r`: the two
    # branches run in parallel on the same input, and `x_r` generally has a different
    # channel count than `conv` expects.
    x_m, st_conv = bn.conv(x, ps.conv, st.conv)
    x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping)

    return (
        relu.(x_m .+ x_r),
        (
            rescale=st_rescale,
            conv=st_conv,
            mapping=st_mapping,
        ),
    )
end
216234
@@ -219,7 +237,7 @@ function get_model(
219237 config:: NamedTuple ;
220238 device= gpu,
221239 warmup:: Bool = true , # Helps reduce Zygote compile times
222- loss_function= nothing
240+ loss_function= nothing ,
223241)
224242 @assert ! warmup || loss_function != = nothing
225243
@@ -254,7 +272,7 @@ function get_model(
254272 initial_layers = Chain (downsample, stage0)
255273
256274 main_layers = Tuple (
257- ResidualBlockV1 (
275+ ResidualBlock (
258276 config. num_channels[i] => config. num_channels[i];
259277 deq_expand= config. expansion_factor,
260278 dropout_rate= config. dropout_rate,
@@ -295,7 +313,7 @@ function get_model(
295313
296314 increment_modules = Parallel (
297315 nothing ,
298- [BottleneckBlockV2 (config. num_channels[i] => config. head_channels[i]) for i in 1 : (config. num_branches)]. .. ,
316+ [BottleneckBlock (config. num_channels[i] => config. head_channels[i]) for i in 1 : (config. num_branches)]. .. ,
299317 )
300318
301319 downsample_modules = PairwiseFusion (
@@ -322,8 +340,8 @@ function get_model(
322340 ContinuousDEQSolver (
323341 config. ode_solver;
324342 mode= config. stop_mode,
325- abstol= 1.0f-3 ,
326- reltol= 1.0f-3 ,
343+ abstol= config . abstol ,
344+ reltol= config . reltol ,
327345 abstol_termination= config. abstol,
328346 reltol_termination= config. reltol,
329347 )
@@ -342,7 +360,9 @@ function get_model(
342360
343361 deq = if config. model_type ∈ (:SKIP , :SKIPV2 )
344362 shortcut = if config. model_type == :SKIP
345- slayers = Lux. AbstractExplicitLayer[ResidualBlockV2 (config. num_channels[1 ] => config. num_channels[1 ], weight_norm= true )]
363+ slayers = Lux. AbstractExplicitLayer[ResidualBlock (
364+ config. num_channels[1 ] => config. num_channels[1 ]; weight_norm= true
365+ )]
346366 for i in 1 : (config. num_branches - 1 )
347367 push! (
348368 slayers,
0 commit comments