3838from keras_applications .imagenet_utils import decode_predictions
3939
4040from . import get_submodules_from_kwargs
41+ from .preprocessing import preprocess_input
4142
4243backend = None
4344layers = None
@@ -189,7 +190,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
189190 has_se = (block_args .se_ratio is not None ) and (0 < block_args .se_ratio <= 1 )
190191 bn_axis = 3 if backend .image_data_format () == 'channels_last' else 1
191192
192- # workaround over non working dropout in tf.keras
193+ # workaround over non working dropout with None in noise_shape in tf.keras
193194 Dropout = get_dropout (
194195 backend = backend ,
195196 layers = layers ,
@@ -226,8 +227,9 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
226227 block_args .input_filters * block_args .se_ratio
227228 ))
228229 se_tensor = layers .GlobalAveragePooling2D (name = prefix + 'se_squeeze' )(x )
229- se_tensor = layers .Reshape ((1 , 1 , filters ),
230- name = prefix + 'se_reshape' )(se_tensor )
230+
231+ target_shape = (1 , 1 , filters ) if backend .image_data_format () == 'channels_last' else (filters , 1 , 1 )
232+ se_tensor = layers .Reshape (target_shape , name = prefix + 'se_reshape' )(se_tensor )
231233 se_tensor = layers .Conv2D (num_reduced_filters , 1 ,
232234 activation = relu_fn ,
233235 padding = 'same' ,
@@ -243,8 +245,10 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
243245 if backend .backend () == 'theano' :
244246 # For the Theano backend, we have to explicitly make
245247 # the excitation weights broadcastable.
248+ pattern = ([True , True , True , False ] if backend .image_data_format () == 'channels_last'
249+ else [True , False , True , True ])
246250 se_tensor = layers .Lambda (
247- lambda x : backend .pattern_broadcast (x , [ True , True , True , False ] ),
251+ lambda x : backend .pattern_broadcast (x , pattern ),
248252 name = prefix + 'se_broadcast' )(se_tensor )
249253 x = layers .multiply ([x , se_tensor ], name = prefix + 'se_excite' )
250254
0 commit comments