Skip to content

Commit 0ef3ffa

Browse files
committed
Update model
1 parent 9fa1c32 commit 0ef3ffa

File tree

2 files changed

+129
-109
lines changed

2 files changed

+129
-109
lines changed

examples/src/config.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,15 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:TINY})
3434
fwd_maxiters=18,
3535
bwd_maxiters=20,
3636
continuous=true,
37-
stop_mode=:rel_deq_best,
37+
stop_mode=:rel_norm,
3838
nepochs=50,
3939
jfb=false,
4040
augment=false,
4141
model_type=:VANILLA,
4242
abstol=5.0f-2,
4343
reltol=5.0f-2,
4444
ode_solver=VCABM3(),
45-
pretrain_epochs=8,
45+
pretrain_epochs=5,
4646
lr_scheduler=:COSINE,
4747
optimiser=:ADAM,
4848
eta=0.001f0 * scaling_factor(),
@@ -72,7 +72,7 @@ function get_default_experiment_configuration(::Val{:CIFAR10}, ::Val{:LARGE})
7272
fwd_maxiters=18,
7373
bwd_maxiters=20,
7474
continuous=true,
75-
stop_mode=:rel_deq_best,
75+
stop_mode=:rel_norm,
7676
nepochs=220,
7777
jfb=false,
7878
augment=true,
@@ -110,7 +110,7 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:SMALL})
110110
fwd_maxiters=27,
111111
bwd_maxiters=28,
112112
continuous=true,
113-
stop_mode=:rel_deq_best,
113+
stop_mode=:rel_norm,
114114
nepochs=100,
115115
jfb=false,
116116
model_type=:VANILLA,
@@ -150,7 +150,7 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:LARGE})
150150
fwd_maxiters=27,
151151
bwd_maxiters=28,
152152
continuous=true,
153-
stop_mode=:rel_deq_best,
153+
stop_mode=:rel_norm,
154154
nepochs=100,
155155
jfb=false,
156156
model_type=:VANILLA,
@@ -190,7 +190,7 @@ function get_default_experiment_configuration(::Val{:IMAGENET}, ::Val{:XL})
190190
fwd_maxiters=27,
191191
bwd_maxiters=28,
192192
continuous=true,
193-
stop_mode=:rel_deq_best,
193+
stop_mode=:rel_norm,
194194
nepochs=100,
195195
jfb=false,
196196
model_type=:VANILLA,

examples/src/models.jl

Lines changed: 123 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,18 @@ function upsample_module(mapping, level_diff, activation; upsample_mode::Symbol=
7070
end
7171

7272
## Residual Block
73-
function ResidualBlockV1(
73+
struct ResidualBlock{C1,C2,Dr,Do,N1,N2,N3} <:
74+
Lux.AbstractExplicitContainerLayer{(:conv1, :conv2, :dropout, :downsample, :norm1, :norm2, :norm3)}
75+
conv1::C1
76+
conv2::C2
77+
dropout::Dr
78+
downsample::Do
79+
norm1::N1
80+
norm2::N2
81+
norm3::N3
82+
end
83+
84+
function ResidualBlock(
7485
mapping;
7586
deq_expand::Int=5,
7687
num_gn_groups::Int=4,
@@ -92,74 +103,81 @@ function ResidualBlockV1(
92103
conv1, conv2
93104
end
94105

95-
# gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats)
96-
# gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
97-
# gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
98-
gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats)
99-
gn2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats)
100-
gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats)
106+
# norm1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats)
107+
# norm2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
108+
# norm3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
109+
norm1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats)
110+
norm2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats)
111+
norm3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats)
101112

102-
dropout = iszero(dropout_rate) ? NoOpLayer() : VariationalHiddenDropout(dropout_rate)
113+
dropout = VariationalHiddenDropout(dropout_rate)
103114

104-
return Chain(
105-
Parallel(
106-
reassociate, # Reassociate and Merge
107-
Chain(conv1, gn1, conv2, BranchLayer(downsample, dropout)), # For x
108-
NoOpLayer(), # For injection
109-
),
110-
Parallel(
111-
addrelu,
112-
NoOpLayer(), # For y1
113-
Chain(
114-
WrappedFunction(addtuple), # Since injection could be a scalar
115-
gn2,
116-
), # For (y2, injection)
115+
return ResidualBlock(conv1, conv2, dropout, downsample, norm1, norm2, norm3)
116+
end
117+
118+
function (rb::ResidualBlock)((x, y)::NTuple{2,<:AbstractArray}, ps, st)
119+
x, st_conv1 = rb.conv1(x, ps.conv1, st.conv1)
120+
x, st_norm1 = rb.norm1(x, ps.norm1, st.norm1)
121+
x, st_conv2 = rb.conv2(x, ps.conv2, st.conv2)
122+
123+
x_do, st_downsample = rb.downsample(x, ps.downsample, st.downsample)
124+
x_dr, st_dropout = rb.dropout(x, ps.dropout, st.dropout)
125+
126+
y_ = x_dr .+ y
127+
y_, st_norm2 = rb.norm2(y_, ps.norm2, st.norm2)
128+
129+
y__ = relu.(y_ .+ x_do)
130+
y__, st_norm3 = rb.norm3(y__, ps.norm3, st.norm3)
131+
132+
return (
133+
y__,
134+
(
135+
conv1=st_conv1,
136+
conv2=st_conv2,
137+
dropout=st_dropout,
138+
downsample=st_downsample,
139+
norm1=st_norm1,
140+
norm2=st_norm2,
141+
norm3=st_norm3,
117142
),
118-
gn3,
119143
)
120144
end
121145

122-
function ResidualBlockV2(
123-
mapping;
124-
deq_expand::Int=1,
125-
num_gn_groups::Int=4,
126-
downsample=NoOpLayer(),
127-
n_big_kernels::Int=0,
128-
dropout_rate::Real=0.0f0,
129-
gn_affine::Bool=true,
130-
weight_norm::Bool=true,
131-
gn_track_stats::Bool=false,
132-
)
133-
inplanes, outplanes = mapping
134-
inner_planes = outplanes * deq_expand
135-
conv1 = (n_big_kernels >= 1 ? conv5x5 : conv3x3)(inplanes => inner_planes; bias=false)
136-
conv2 = (n_big_kernels >= 2 ? conv5x5 : conv3x3)(inner_planes => outplanes; bias=false)
137-
138-
conv1, conv2 = if weight_norm
139-
WeightNorm(conv1, (:weight,), (4,)), WeightNorm(conv2, (:weight,), (4,))
140-
else
141-
conv1, conv2
142-
end
143-
144-
# gn1 = GroupNorm(inner_planes, num_gn_groups, relu; affine=gn_affine, track_stats=gn_track_stats)
145-
# gn2 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
146-
# gn3 = GroupNorm(outplanes, num_gn_groups; affine=gn_affine, track_stats=gn_track_stats)
147-
gn1 = BatchNorm(inner_planes, relu; affine=gn_affine, track_stats=gn_track_stats)
148-
gn2 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats)
149-
gn3 = BatchNorm(outplanes; affine=gn_affine, track_stats=gn_track_stats)
150-
151-
dropout = iszero(dropout_rate) ? NoOpLayer() : VariationalHiddenDropout(dropout_rate)
152-
153-
return Chain(
154-
conv1,
155-
gn1,
156-
conv2,
157-
Parallel(addrelu, downsample, Chain(dropout, gn2)),
158-
gn3,
146+
function (rb::ResidualBlock)(x::AbstractArray, ps, st)
147+
x, st_conv1 = rb.conv1(x, ps.conv1, st.conv1)
148+
x, st_norm1 = rb.norm1(x, ps.norm1, st.norm1)
149+
x, st_conv2 = rb.conv2(x, ps.conv2, st.conv2)
150+
151+
x_do, st_downsample = rb.downsample(x, ps.downsample, st.downsample)
152+
153+
x_dr, st_dropout = rb.dropout(x, ps.dropout, st.dropout)
154+
x_dr, st_norm2 = rb.norm2(x_dr, ps.norm2, st.norm2)
155+
156+
y__ = relu.(x_dr .+ x_do)
157+
y__, st_norm3 = rb.norm3(y__, ps.norm3, st.norm3)
158+
159+
return (
160+
y__,
161+
(
162+
conv1=st_conv1,
163+
conv2=st_conv2,
164+
dropout=st_dropout,
165+
downsample=st_downsample,
166+
norm1=st_norm1,
167+
norm2=st_norm2,
168+
norm3=st_norm3,
169+
),
159170
)
160171
end
161172

162-
function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=false, bn_affine::Bool=true)
173+
# Bottleneck Block
174+
struct BottleneckBlock{R,C,M} <: Lux.AbstractExplicitContainerLayer{(:rescale, :conv, :mapping)}
175+
rescale::R
176+
conv::C
177+
mapping::M
178+
end
179+
180+
function BottleneckBlock(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true)
163181
rescale = if first(mapping) != last(mapping) * expansion
164182
Chain(
165183
conv1x1(first(mapping) => last(mapping) * expansion),
@@ -169,48 +187,48 @@ function BottleneckBlockV1(mapping::Pair, expansion::Int=4; bn_track_stats::Bool
169187
NoOpLayer()
170188
end
171189

172-
return Chain(
173-
Parallel(reassociate, BranchLayer(rescale, conv1x1(mapping)), NoOpLayer()),
174-
Parallel(
175-
addrelu,
176-
NoOpLayer(),
177-
Chain(
178-
WrappedFunction(addtuple), # Since injection could be a scalar
179-
Chain(
180-
BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats),
181-
conv3x3(last(mapping) => last(mapping)),
182-
BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine),
183-
conv1x1(last(mapping) => last(mapping) * expansion),
184-
BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine),
185-
),
186-
),
187-
),
190+
return BottleneckBlock(
191+
rescale,
192+
conv1x1(mapping),
193+
Chain(
194+
BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats),
195+
conv3x3(last(mapping) => last(mapping)),
196+
BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine),
197+
conv1x1(last(mapping) => last(mapping) * expansion),
198+
BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine)
199+
)
188200
)
189201
end
190202

191-
function BottleneckBlockV2(mapping::Pair, expansion::Int=4; bn_track_stats::Bool=true, bn_affine::Bool=true)
192-
rescale = if first(mapping) != last(mapping) * expansion
193-
Chain(
194-
conv1x1(first(mapping) => last(mapping) * expansion),
195-
BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine),
203+
function (bn::BottleneckBlock)((x, y)::NTuple{2,<:AbstractArray}, ps, st)
204+
x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale)
205+
x_m, st_conv1 = bn.conv(x_r, ps.conv, st.conv)
206+
207+
x_m = y .+ x_m
208+
x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping)
209+
210+
return (
211+
relu.(x_m .+ x_r),
212+
(
213+
rescale=st_rescale,
214+
conv=st_conv1,
215+
mapping=st_mapping,
196216
)
197-
else
198-
NoOpLayer()
199-
end
217+
)
218+
end
200219

201-
return Chain(
202-
Parallel(
203-
addrelu,
204-
rescale,
205-
Chain(
206-
conv1x1(mapping),
207-
BatchNorm(last(mapping), relu; affine=bn_affine, track_stats=bn_track_stats),
208-
conv3x3(last(mapping) => last(mapping)),
209-
BatchNorm(last(mapping), relu; track_stats=bn_track_stats, affine=bn_affine),
210-
conv1x1(last(mapping) => last(mapping) * expansion),
211-
BatchNorm(last(mapping) * expansion; track_stats=bn_track_stats, affine=bn_affine),
212-
),
213-
),
220+
function (bn::BottleneckBlock)(x::AbstractArray, ps, st)
221+
x_r, st_rescale = bn.rescale(x, ps.rescale, st.rescale)
222+
x_m, st_conv1 = bn.conv(x_r, ps.conv, st.conv)
223+
x_m, st_mapping = bn.mapping(x_m, ps.mapping, st.mapping)
224+
225+
return (
226+
relu.(x_m .+ x_r),
227+
(
228+
rescale=st_rescale,
229+
conv=st_conv1,
230+
mapping=st_mapping,
231+
)
214232
)
215233
end
216234

@@ -219,7 +237,7 @@ function get_model(
219237
config::NamedTuple;
220238
device=gpu,
221239
warmup::Bool=true, # Helps reduce Zygote compile times
222-
loss_function=nothing
240+
loss_function=nothing,
223241
)
224242
@assert !warmup || loss_function !== nothing
225243

@@ -254,7 +272,7 @@ function get_model(
254272
initial_layers = Chain(downsample, stage0)
255273

256274
main_layers = Tuple(
257-
ResidualBlockV1(
275+
ResidualBlock(
258276
config.num_channels[i] => config.num_channels[i];
259277
deq_expand=config.expansion_factor,
260278
dropout_rate=config.dropout_rate,
@@ -295,7 +313,7 @@ function get_model(
295313

296314
increment_modules = Parallel(
297315
nothing,
298-
[BottleneckBlockV2(config.num_channels[i] => config.head_channels[i]) for i in 1:(config.num_branches)]...,
316+
[BottleneckBlock(config.num_channels[i] => config.head_channels[i]) for i in 1:(config.num_branches)]...,
299317
)
300318

301319
downsample_modules = PairwiseFusion(
@@ -322,8 +340,8 @@ function get_model(
322340
ContinuousDEQSolver(
323341
config.ode_solver;
324342
mode=config.stop_mode,
325-
abstol=1.0f-3,
326-
reltol=1.0f-3,
343+
abstol=config.abstol,
344+
reltol=config.reltol,
327345
abstol_termination=config.abstol,
328346
reltol_termination=config.reltol,
329347
)
@@ -342,7 +360,9 @@ function get_model(
342360

343361
deq = if config.model_type ∈ (:SKIP, :SKIPV2)
344362
shortcut = if config.model_type == :SKIP
345-
slayers = Lux.AbstractExplicitLayer[ResidualBlockV2(config.num_channels[1] => config.num_channels[1], weight_norm=true)]
363+
slayers = Lux.AbstractExplicitLayer[ResidualBlock(
364+
config.num_channels[1] => config.num_channels[1]; weight_norm=true
365+
)]
346366
for i in 1:(config.num_branches - 1)
347367
push!(
348368
slayers,

0 commit comments

Comments
 (0)