Skip to content

Commit 293f19a

Browse files
authored
Fix channels_first and add preprocess_input (#46)
* Fix channels_first, add preprocess_input * Fix for theano backend
1 parent c993591 commit 293f19a

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

efficientnet/model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from keras_applications.imagenet_utils import decode_predictions
3939

4040
from . import get_submodules_from_kwargs
41+
from .preprocessing import preprocess_input
4142

4243
backend = None
4344
layers = None
@@ -189,7 +190,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
189190
has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
190191
bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
191192

192-
# workaround over non working dropout in tf.keras
193+
# workaround over non working dropout with None in noise_shape in tf.keras
193194
Dropout = get_dropout(
194195
backend=backend,
195196
layers=layers,
@@ -226,8 +227,9 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
226227
block_args.input_filters * block_args.se_ratio
227228
))
228229
se_tensor = layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
229-
se_tensor = layers.Reshape((1, 1, filters),
230-
name=prefix + 'se_reshape')(se_tensor)
230+
231+
target_shape = (1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1)
232+
se_tensor = layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor)
231233
se_tensor = layers.Conv2D(num_reduced_filters, 1,
232234
activation=relu_fn,
233235
padding='same',
@@ -243,8 +245,10 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
243245
if backend.backend() == 'theano':
244246
# For the Theano backend, we have to explicitly make
245247
# the excitation weights broadcastable.
248+
pattern = ([True, True, True, False] if backend.image_data_format() == 'channels_last'
249+
else [True, False, True, True])
246250
se_tensor = layers.Lambda(
247-
lambda x: backend.pattern_broadcast(x, [True, True, True, False]),
251+
lambda x: backend.pattern_broadcast(x, pattern),
248252
name=prefix + 'se_broadcast')(se_tensor)
249253
x = layers.multiply([x, se_tensor], name=prefix + 'se_excite')
250254

0 commit comments

Comments
 (0)