Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit e05f2cb

Browse files
authored
Convert models from Pytorch to Keras (#99)
* Convert models from PyTorch to Keras * Convert ResNet 50, 101, 152 models * Rebase; unify layer names between frameworks; produce samples * Verify accuracy of models using both the Keras and PyTorch pipelines * Add args for arch key and model type; make style
1 parent 5bae8d5 commit e05f2cb

File tree

13 files changed

+1187
-23
lines changed

13 files changed

+1187
-23
lines changed

scripts/keras_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@
304304
from sparseml.keras.utils import (
305305
LossesAndMetricsLoggingCallback,
306306
ModelExporter,
307-
keras,
308307
TensorBoardLogger,
308+
keras,
309309
)
310310
from sparseml.utils import create_dirs
311311

src/sparseml/keras/datasets/classification/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,5 @@
1919
# flake8: noqa
2020

2121
from .imagefolder import *
22+
from .imagenet import *
2223
from .imagenette import *

src/sparseml/keras/datasets/classification/imagefolder.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
def imagenet_normalizer(img: tensorflow.Tensor, mode: str):
4949
"""
5050
Normalize an image using mean and std of the imagenet dataset
51-
5251
:param img: The input image to normalize
5352
:param mode: either "tf", "caffe", "torch"
5453
:return: The normalized image
@@ -73,7 +72,7 @@ def imagenet_normalizer(img: tensorflow.Tensor, mode: str):
7372

7473
def default_imagenet_normalizer():
7574
def normalizer(img: tensorflow.Tensor):
76-
# Default to the same preprocessing used by ResNet
75+
# Default to the same preprocessing used by Keras Applications ResNet
7776
return imagenet_normalizer(img, "caffe")
7877

7978
return normalizer
@@ -109,7 +108,7 @@ def __init__(
109108
self,
110109
root: str,
111110
train: bool,
112-
image_size: Union[int, Tuple[int, int]] = 224,
111+
image_size: Union[None, int, Tuple[int, int]] = 224,
113112
pre_resize_transforms: Union[SplitsTransforms, None] = SplitsTransforms(
114113
train=(
115114
random_scaling_crop(),
@@ -126,9 +125,14 @@ def __init__(
126125
if not os.path.exists(self._root):
127126
raise ValueError("Data set folder {} must exist".format(self._root))
128127
self._train = train
129-
self._image_size = (
130-
image_size if isinstance(image_size, tuple) else (image_size, image_size)
131-
)
128+
if image_size is not None:
129+
self._image_size = (
130+
image_size
131+
if isinstance(image_size, tuple)
132+
else (image_size, image_size)
133+
)
134+
else:
135+
self._image_size = None
132136
self._pre_resize_transforms = pre_resize_transforms
133137
self._post_resize_transforms = post_resize_transforms
134138

@@ -199,7 +203,6 @@ def processor(self, file_path: tensorflow.Tensor, label: tensorflow.Tensor):
199203
"""
200204
img = tensorflow.io.read_file(file_path)
201205
img = tensorflow.image.decode_jpeg(img, channels=3)
202-
203206
if self.pre_resize_transforms:
204207
transforms = (
205208
self.pre_resize_transforms.train
@@ -209,7 +212,7 @@ def processor(self, file_path: tensorflow.Tensor, label: tensorflow.Tensor):
209212
if transforms:
210213
for trans in transforms:
211214
img = trans(img)
212-
if self._image_size:
215+
if self._image_size is not None:
213216
img = tensorflow.image.resize(img, self.image_size)
214217

215218
if self.post_resize_transforms:
@@ -221,7 +224,6 @@ def processor(self, file_path: tensorflow.Tensor, label: tensorflow.Tensor):
221224
if transforms:
222225
for trans in transforms:
223226
img = trans(img)
224-
225227
return img, label
226228

227229
def creator(self):
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
Imagenet dataset implementations for the image classification field in computer vision.
17+
More info for the dataset can be found `here <http://www.image-net.org/>`__.
18+
"""
19+
20+
import random
21+
from typing import Tuple, Union
22+
23+
import tensorflow as tf
24+
25+
from sparseml.keras.datasets.classification import (
26+
ImageFolderDataset,
27+
SplitsTransforms,
28+
imagenet_normalizer,
29+
)
30+
from sparseml.keras.datasets.helpers import random_scaling_crop
31+
from sparseml.keras.datasets.registry import DatasetRegistry
32+
from sparseml.keras.utils import keras
33+
from sparseml.utils import clean_path
34+
from sparseml.utils.datasets import (
35+
IMAGENET_RGB_MEANS,
36+
IMAGENET_RGB_STDS,
37+
default_dataset_path,
38+
)
39+
40+
41+
__all__ = ["ImageNetDataset"]
42+
43+
44+
def torch_imagenet_normalizer():
    """
    Build a single-argument transform that normalizes an image with the
    "torch" mode of ``imagenet_normalizer``.

    :return: a callable that takes one image tensor and returns the result of
        ``imagenet_normalizer(image, "torch")``
    """

    def _normalize(image: tf.Tensor):
        # Fix the mode so the callable fits the one-argument transform
        # signature used in the dataset transform tuples.
        normalized = imagenet_normalizer(image, "torch")
        return normalized

    return _normalize
49+
50+
51+
def imagenet_pre_resize_processor(resize_size: int = 256, crop_size: int = 224):
    """
    Create the transform that resizes an image by its shorter side and then
    takes a square center crop, mirroring the torchvision
    ``Resize(resize_size)`` + ``CenterCrop(crop_size)`` combination used by the
    PyTorch code path for ImageNet.

    :param resize_size: length in pixels that the shorter side of the image is
        resized to; the longer side is scaled by the same factor so the aspect
        ratio is kept. Defaults to 256
    :param crop_size: side length in pixels of the square center crop taken
        after the resize. Defaults to 224
    :return: a callable that takes a single image tensor and returns the
        resized, center cropped image cast to uint8
    """

    def processor(image: tf.Tensor):
        # The Keras preprocessing layers below are applied to a batch,
        # so wrap the single image in a batch of size one.
        image_batch = tf.expand_dims(image, axis=0)

        # Resize the image the following way to match torchvision's Resize
        # transform used by the PyTorch code path for ImageNet:
        #     torchvision.transforms.Resize(resize_size)
        # which resizes the shorter side of the image to resize_size and
        # scales the other side according to the aspect ratio.
        # Axes 0 and 1 of the un-batched image are treated as height and width.
        shape = tf.shape(image)
        h, w = shape[0], shape[1]
        if h > w:
            # Width is the shorter side: pin it and scale the height.
            new_h = tf.cast(resize_size * h / w, dtype=tf.uint16)
            new_w = tf.constant(resize_size, dtype=tf.uint16)
        else:
            # Height is the shorter (or equal) side: pin it and scale the width.
            new_h = tf.constant(resize_size, dtype=tf.uint16)
            new_w = tf.cast(resize_size * w / h, dtype=tf.uint16)
        resizer = keras.layers.experimental.preprocessing.Resizing(new_h, new_w)
        image_batch = tf.cast(resizer(image_batch), dtype=tf.uint8)

        # Center crop
        center_cropper = keras.layers.experimental.preprocessing.CenterCrop(
            crop_size, crop_size
        )
        image_batch = tf.cast(center_cropper(image_batch), dtype=tf.uint8)

        # Drop the batch dimension that was added at the top.
        return image_batch[0, :]

    return processor
80+
81+
82+
@DatasetRegistry.register(
    key=["imagenet"],
    attributes={
        "num_classes": 1000,
        "transform_means": IMAGENET_RGB_MEANS,
        "transform_stds": IMAGENET_RGB_STDS,
    },
)
class ImageNetDataset(ImageFolderDataset):
    """
    ImageFolderDataset specialization registered under the "imagenet" key,
    with default transforms set up for the ImageNet dataset.

    :param root: The root folder to find the dataset at; it is passed through
        clean_path before being handed to ImageFolderDataset
    :param train: True if this is for the training distribution,
        False for the validation
    :param rand_trans: currently not read anywhere in this constructor;
        the transforms that are applied are controlled by
        pre_resize_transforms and post_resize_transforms
    :param image_size: the size of the image to output from the dataset,
        forwarded unchanged to ImageFolderDataset
    :param pre_resize_transforms: the per-split transforms forwarded to
        ImageFolderDataset as pre_resize_transforms; by default the train split
        uses random_scaling_crop and a random left-right flip, and the val split
        uses imagenet_pre_resize_processor
    :param post_resize_transforms: the per-split transforms forwarded to
        ImageFolderDataset as post_resize_transforms; by default both splits
        use torch_imagenet_normalizer
    """

    def __init__(
        self,
        root: str = default_dataset_path("imagenet"),
        train: bool = True,
        rand_trans: bool = False,
        image_size: Union[None, int, Tuple[int, int]] = 224,
        pre_resize_transforms=SplitsTransforms(
            train=(
                random_scaling_crop(),
                tf.image.random_flip_left_right,
            ),
            val=(imagenet_pre_resize_processor(),),
        ),
        post_resize_transforms=SplitsTransforms(
            train=(torch_imagenet_normalizer(),), val=(torch_imagenet_normalizer(),)
        ),
    ):
        super().__init__(
            clean_path(root),
            train,
            image_size=image_size,
            pre_resize_transforms=pre_resize_transforms,
            post_resize_transforms=post_resize_transforms,
        )

        if train:
            # Shuffle so the samples are not left in the class order implied
            # by the folder structure.
            random.shuffle(self.samples)

src/sparseml/keras/datasets/dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def build(
4646
) -> tensorflow.data.Dataset:
4747
"""
4848
Create the dataset in the current graph using tensorflow.data APIs
49-
5049
:param batch_size: the batch size to create the dataset for
5150
:param repeat_count: the number of times to repeat the dataset,
5251
if unset or None, will repeat indefinitely

src/sparseml/keras/datasets/helpers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def random_scaling_crop(
3333
"""
3434
Random crop implementation which also randomly scales the crop taken
3535
as well as the aspect ratio of the crop.
36-
3736
:param scale_range: the (min, max) of the crop scales to take from the orig image
3837
:param ratio_range: the (min, max) of the aspect ratios to take from the orig image
3938
:return: the callable function for random scaling crop op,

src/sparseml/keras/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818

1919
# flake8: noqa
2020

21+
from .classification import *
2122
from .external import *
2223
from .registry import *
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing,
10+
# software distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .resnet import *

0 commit comments

Comments
 (0)