Skip to content

Commit 7133d30

Browse files
authored
Refactor preprocess_input (#49)
* Use keras-applications preprocessing
1 parent 293f19a commit 7133d30

File tree

4 files changed

+15
-23
lines changed

4 files changed

+15
-23
lines changed

efficientnet/keras.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from . import inject_keras_modules, init_keras_custom_objects
22
from . import model
33

4-
from .preprocessing import center_crop_and_resize, preprocess_input
4+
from .preprocessing import center_crop_and_resize
55

66
EfficientNetB0 = inject_keras_modules(model.EfficientNetB0)
77
EfficientNetB1 = inject_keras_modules(model.EfficientNetB1)
@@ -12,4 +12,6 @@
1212
EfficientNetB6 = inject_keras_modules(model.EfficientNetB6)
1313
EfficientNetB7 = inject_keras_modules(model.EfficientNetB7)
1414

15+
preprocess_input = inject_keras_modules(model.preprocess_input)
16+
1517
init_keras_custom_objects()

efficientnet/model.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
from six.moves import xrange
3737
from keras_applications.imagenet_utils import _obtain_input_shape
3838
from keras_applications.imagenet_utils import decode_predictions
39+
from keras_applications.imagenet_utils import preprocess_input as _preprocess_input
3940

4041
from . import get_submodules_from_kwargs
41-
from .preprocessing import preprocess_input
4242

4343
backend = None
4444
layers = None
@@ -76,9 +76,6 @@
7676
'da76b3a29e1c011635376e191c2c2d54')
7777
}
7878

79-
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
80-
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
81-
8279
BlockArgs = collections.namedtuple('BlockArgs', [
8380
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
8481
'expand_ratio', 'id_skip', 'strides', 'se_ratio'
@@ -127,6 +124,10 @@
127124
}
128125

129126

127+
def preprocess_input(x, **kwargs):
128+
return _preprocess_input(x, mode='torch', **kwargs)
129+
130+
130131
def swish(x):
131132
"""Swish activation function: x * sigmoid(x).
132133
Reference: [Searching for Activation Functions](https://arxiv.org/abs/1710.05941)
@@ -227,7 +228,7 @@ def mb_conv_block(inputs, block_args, drop_rate=None, relu_fn=swish, prefix='',
227228
block_args.input_filters * block_args.se_ratio
228229
))
229230
se_tensor = layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
230-
231+
231232
target_shape = (1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1)
232233
se_tensor = layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor)
233234
se_tensor = layers.Conv2D(num_reduced_filters, 1,

efficientnet/preprocessing.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
import numpy as np
1616
from skimage.transform import resize
1717

18-
MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]
19-
STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]
20-
2118
MAP_INTERPOLATION_TO_ORDER = {
2219
"nearest": 0,
2320
"bilinear": 1,
@@ -39,9 +36,9 @@ def center_crop_and_resize(image, image_size, crop_padding=32, interpolation="bi
3936
offset_width = ((w - padded_center_crop_size) + 1) // 2
4037

4138
image_crop = image[
42-
offset_height : padded_center_crop_size + offset_height,
43-
offset_width : padded_center_crop_size + offset_width,
44-
]
39+
offset_height: padded_center_crop_size + offset_height,
40+
offset_width: padded_center_crop_size + offset_width,
41+
]
4542
resized_image = resize(
4643
image_crop,
4744
(image_size, image_size),
@@ -50,13 +47,3 @@ def center_crop_and_resize(image, image_size, crop_padding=32, interpolation="bi
5047
)
5148

5249
return resized_image
53-
54-
55-
def preprocess_input(x, **kwargs):
56-
assert x.ndim in (3, 4)
57-
assert x.shape[-1] == 3
58-
59-
x = x - np.array(MEAN_RGB)
60-
x = x / np.array(STDDEV_RGB)
61-
62-
return x

efficientnet/tfkeras.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from . import inject_tfkeras_modules, init_tfkeras_custom_objects
22
from . import model
33

4-
from .preprocessing import center_crop_and_resize, preprocess_input
4+
from .preprocessing import center_crop_and_resize
55

66
EfficientNetB0 = inject_tfkeras_modules(model.EfficientNetB0)
77
EfficientNetB1 = inject_tfkeras_modules(model.EfficientNetB1)
@@ -12,4 +12,6 @@
1212
EfficientNetB6 = inject_tfkeras_modules(model.EfficientNetB6)
1313
EfficientNetB7 = inject_tfkeras_modules(model.EfficientNetB7)
1414

15+
preprocess_input = inject_tfkeras_modules(model.preprocess_input)
16+
1517
init_tfkeras_custom_objects()

0 commit comments

Comments
 (0)