|
36 | 36 | from six.moves import xrange |
37 | 37 | from keras_applications.imagenet_utils import _obtain_input_shape |
38 | 38 | from keras_applications.imagenet_utils import decode_predictions |
| 39 | +from keras_applications.imagenet_utils import preprocess_input as _preprocess_input |
39 | 40 |
|
40 | 41 | from . import get_submodules_from_kwargs |
41 | | -from .preprocessing import preprocess_input |
42 | 42 |
|
43 | 43 | backend = None |
44 | 44 | layers = None |
|
76 | 76 | 'da76b3a29e1c011635376e191c2c2d54') |
77 | 77 | } |
78 | 78 |
|
79 | | -MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] |
80 | | -STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] |
81 | | - |
82 | 79 | BlockArgs = collections.namedtuple('BlockArgs', [ |
83 | 80 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', |
84 | 81 | 'expand_ratio', 'id_skip', 'strides', 'se_ratio' |
|
127 | 124 | } |
128 | 125 |
|
129 | 126 |
|
| 127 | +def preprocess_input(x, **kwargs): |
| 128 | + return _preprocess_input(x, mode='torch', **kwargs) |
| 129 | + |
| 130 | + |
130 | 131 | def swish(x): |
131 | 132 | """Swish activation function: x * sigmoid(x). |
132 | 133 | Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941) |
@@ -227,7 +228,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='', |
227 | 228 | block_args.input_filters * block_args.se_ratio |
228 | 229 | )) |
229 | 230 | se_tensor = layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x) |
230 | | - |
| 231 | + |
231 | 232 | target_shape = (1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1) |
232 | 233 | se_tensor = layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor) |
233 | 234 | se_tensor = layers.Conv2D(num_reduced_filters, 1, |
|
0 commit comments