
Commit faffe17

FIX: Use correct TF imports (#55)
* FIX: Use correct TF imports
1 parent 6d70ad4 commit faffe17
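
This commit replaces imports of private tensorflow.python submodules (constant_op, dtypes, array_ops, math_ops, linalg_ops, load_library, gfile, random_ops, input) with their public tf.* equivalents. A minimal sketch of the mapping, assuming a TensorFlow version where these public symbols exist; the tensor values below are made up for illustration:

import tensorflow as tf

# constant_op.constant          -> tf.constant
# dtypes.float32                -> tf.dtypes.float32
# array_ops.shape / reshape     -> tf.shape / tf.reshape
# math_ops.cos / matmul         -> tf.math.cos / tf.matmul
# linalg_ops.matrix_inverse     -> tf.linalg.inv
# load_library.load_op_library  -> tf.load_op_library

x = tf.constant([[0.0, 1.0]], dtype=tf.dtypes.float32)
y = tf.math.cos(x)                    # was math_ops.cos(x)
rank_two_shape = tf.shape(y)          # was array_ops.shape(y)
inverse = tf.linalg.inv(tf.eye(2))    # was linalg_ops.matrix_inverse(...)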


17 files changed (+381, -436 lines)


tensorflow_addons/custom_ops/image/python/transform.py

Lines changed: 32 additions & 38 deletions
@@ -19,20 +19,14 @@
 
 import tensorflow as tf
 from tensorflow.python.framework import common_shapes
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import load_library
 from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import resource_loader
 
-_image_ops_so = load_library.load_op_library(
+_image_ops_so = tf.load_op_library(
     resource_loader.get_path_to_datafile("_image_ops.so"))
 
-_IMAGE_DTYPES = set([dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float16,
-                     dtypes.float32, dtypes.float64])
+_IMAGE_DTYPES = set([tf.dtypes.uint8, tf.dtypes.int32, tf.dtypes.int64,
+                     tf.dtypes.float16, tf.dtypes.float32, tf.dtypes.float64])
 
 ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
 
@@ -78,7 +72,7 @@ def transform(images,
     image_or_images = ops.convert_to_tensor(images, name="images")
     transform_or_transforms = ops.convert_to_tensor(transforms,
                                                     name="transforms",
-                                                    dtype=dtypes.float32)
+                                                    dtype=tf.dtypes.float32)
     if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
       raise TypeError("Invalid dtype %s." % image_or_images.dtype)
     elif image_or_images.get_shape().ndims is None:
@@ -93,10 +87,10 @@ def transform(images,
       raise TypeError("Images should have rank between 2 and 4.")
 
     if output_shape is None:
-      output_shape = array_ops.shape(images)[1:3]
+      output_shape = tf.shape(images)[1:3]
 
     output_shape = ops.convert_to_tensor(output_shape,
-                                         dtypes.int32,
+                                         tf.dtypes.int32,
                                          name="output_shape")
 
     if not output_shape.get_shape().is_compatible_with([2]):
@@ -147,8 +141,8 @@ def compose_transforms(*transforms):
     composed = flat_transforms_to_matrices(transforms[0])
     for tr in transforms[1:]:
       # Multiply batches of matrices.
-      composed = math_ops.matmul(composed,
-                                 flat_transforms_to_matrices(tr))
+      composed = tf.matmul(composed,
+                           flat_transforms_to_matrices(tr))
     return matrices_to_flat_transforms(composed)
 
 
@@ -177,15 +171,15 @@ def flat_transforms_to_matrices(transforms):
       raise ValueError("Transforms should be 1D or 2D, got: %s" %
                        transforms)
     # Make the transform(s) 2D in case the input is a single transform.
-    transforms = array_ops.reshape(transforms,
-                                   constant_op.constant([-1, 8]))
-    num_transforms = array_ops.shape(transforms)[0]
+    transforms = tf.reshape(transforms,
+                            tf.constant([-1, 8]))
+    num_transforms = tf.shape(transforms)[0]
     # Add a column of ones for the implicit last entry in the matrix.
-    return array_ops.reshape(
-        array_ops.concat(
-            [transforms, array_ops.ones([num_transforms, 1])],
+    return tf.reshape(
+        tf.concat(
+            [transforms, tf.ones([num_transforms, 1])],
             axis=1),
-        constant_op.constant([-1, 3, 3]))
+        tf.constant([-1, 3, 3]))
 
 
 def matrices_to_flat_transforms(transform_matrices):
@@ -215,8 +209,8 @@ def matrices_to_flat_transforms(transform_matrices):
       raise ValueError("Matrices should be 2D or 3D, got: %s" %
                        transform_matrices)
     # Flatten each matrix.
-    transforms = array_ops.reshape(transform_matrices,
-                                   constant_op.constant([-1, 9]))
+    transforms = tf.reshape(transform_matrices,
+                            tf.constant([-1, 9]))
     # Divide each matrix by the last entry (normally 1).
     transforms /= transforms[:, 8:9]
     return transforms[:, :8]
@@ -243,28 +237,28 @@ def angles_to_projective_transforms(angles,
   with ops.name_scope(name, "angles_to_projective_transforms"):
     angle_or_angles = ops.convert_to_tensor(angles,
                                             name="angles",
-                                            dtype=dtypes.float32)
+                                            dtype=tf.dtypes.float32)
     if len(angle_or_angles.get_shape()) == 0:
       angles = angle_or_angles[None]
     elif len(angle_or_angles.get_shape()) == 1:
       angles = angle_or_angles
     else:
       raise TypeError("Angles should have rank 0 or 1.")
     x_offset = ((image_width - 1) -
-                (math_ops.cos(angles) * (image_width - 1) -
-                 math_ops.sin(angles) * (image_height - 1))) / 2.0
+                (tf.math.cos(angles) * (image_width - 1) -
+                 tf.math.sin(angles) * (image_height - 1))) / 2.0
     y_offset = ((image_height - 1) -
-                (math_ops.sin(angles) * (image_width - 1) +
-                 math_ops.cos(angles) * (image_height - 1))) / 2.0
-    num_angles = array_ops.shape(angles)[0]
-    return array_ops.concat(
-        values=[math_ops.cos(angles)[:, None],
-                -math_ops.sin(angles)[:, None],
+                (tf.math.sin(angles) * (image_width - 1) +
+                 tf.math.cos(angles) * (image_height - 1))) / 2.0
+    num_angles = tf.shape(angles)[0]
+    return tf.concat(
+        values=[tf.math.cos(angles)[:, None],
+                -tf.math.sin(angles)[:, None],
                 x_offset[:, None],
-                math_ops.sin(angles)[:, None],
-                math_ops.cos(angles)[:, None],
+                tf.math.sin(angles)[:, None],
+                tf.math.cos(angles)[:, None],
                 y_offset[:, None],
-                array_ops.zeros((num_angles, 2), dtypes.float32),],
+                tf.zeros((num_angles, 2), tf.dtypes.float32),],
                 axis=1)
 
 
@@ -278,7 +272,7 @@ def _image_projective_transform_grad(op, grad):
   image_or_images = ops.convert_to_tensor(images, name="images")
   transform_or_transforms = ops.convert_to_tensor(transforms,
                                                   name="transforms",
-                                                  dtype=dtypes.float32)
+                                                  dtype=tf.dtypes.float32)
 
   if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
     raise TypeError("Invalid dtype %s." % image_or_images.dtype)
@@ -291,11 +285,11 @@ def _image_projective_transform_grad(op, grad):
 
   # Invert transformations
   transforms = flat_transforms_to_matrices(transforms=transforms)
-  inverse = linalg_ops.matrix_inverse(transforms)
+  inverse = tf.linalg.inv(transforms)
   transforms = matrices_to_flat_transforms(inverse)
   output = _image_ops_so.image_projective_transform(
       images=grad,
       transforms=transforms,
-      output_shape=array_ops.shape(image_or_images)[1:3],
+      output_shape=tf.shape(image_or_images)[1:3],
       interpolation=interpolation)
   return [output, None, None]
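
For reference, a usage sketch of the updated module, mirroring the rotation-plus-translation case exercised in the test file below; it assumes the compiled _image_ops.so custom op can be loaded from the package's data files:

import numpy as np
import tensorflow as tf
from tensorflow_addons.custom_ops.image.python import transform as transform_ops

image = tf.constant(
    [[1, 1, 1, 0], [1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]],
    dtype=tf.dtypes.uint8)
# Rotate counter-clockwise by pi / 2, then translate right by 1
# (the transformation matrix is always inverted, hence the -1).
rotation = transform_ops.angles_to_projective_transforms(np.pi / 2, 4, 4)
translation = tf.constant([1, 0, -1, 0, 1, 0, 0, 0], dtype=tf.dtypes.float32)
composed = transform_ops.compose_transforms(rotation, translation)
transformed = transform_ops.transform(image, composed)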

tensorflow_addons/custom_ops/image/python/transform_test.py

Lines changed: 18 additions & 21 deletions
@@ -19,34 +19,31 @@
 from __future__ import print_function
 
 import numpy as np
+import tensorflow as tf
 
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util as tf_test_util
 from tensorflow.python.ops import gradient_checker
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
 from tensorflow_addons.custom_ops.image.python import transform as transform_ops
 
-_DTYPES = set([dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float16,
-               dtypes.float32, dtypes.float64])
+_DTYPES = set([tf.dtypes.uint8, tf.dtypes.int32, tf.dtypes.int64,
+               tf.dtypes.float16, tf.dtypes.float32, tf.dtypes.float64])
 
 
-class ImageOpsTest(test.TestCase):
+class ImageOpsTest(tf.test.TestCase):
   @tf_test_util.run_all_in_graph_and_eager_modes
   def test_compose(self):
     for dtype in _DTYPES:
-      image = constant_op.constant(
+      image = tf.constant(
           [[1, 1, 1, 0], [1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 0, 0]],
           dtype=dtype)
       # Rotate counter-clockwise by pi / 2.
       rotation = transform_ops.angles_to_projective_transforms(np.pi / 2,
                                                                4, 4)
       # Translate right by 1 (the transformation matrix is always inverted,
       # hence the -1).
-      translation = constant_op.constant(
+      translation = tf.constant(
           [1, 0, -1, 0, 1, 0, 0, 0],
-          dtype=dtypes.float32)
+          dtype=tf.dtypes.float32)
       composed = transform_ops.compose_transforms(rotation, translation)
       image_transformed = transform_ops.transform(image, composed)
       self.assertAllEqual(
@@ -56,31 +53,31 @@ def test_compose(self):
   @tf_test_util.run_all_in_graph_and_eager_modes
   def test_extreme_projective_transform(self):
     for dtype in _DTYPES:
-      image = constant_op.constant(
+      image = tf.constant(
          [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]],
          dtype=dtype)
-      transformation = constant_op.constant(
-          [1, 0, 0, 0, 1, 0, -1, 0], dtypes.float32)
+      transformation = tf.constant(
+          [1, 0, 0, 0, 1, 0, -1, 0], tf.dtypes.float32)
       image_transformed = transform_ops.transform(image, transformation)
       self.assertAllEqual(
           [[1, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0],
           [0, 0, 0, 0]], image_transformed)
 
   @tf_test_util.run_all_in_graph_and_eager_modes
   def test_transform_static_output_shape(self):
-    image = constant_op.constant([[1., 2.], [3., 4.]])
+    image = tf.constant([[1., 2.], [3., 4.]])
     result = transform_ops.transform(
         image,
-        random_ops.random_uniform(
+        tf.random.uniform(
            [8], -1, 1),
-        output_shape=constant_op.constant([3, 5]))
+        output_shape=tf.constant([3, 5]))
     self.assertAllEqual([3, 5], result.shape)
 
   def _test_grad(self, shape_to_test):
     with self.cached_session():
       test_image_shape = shape_to_test
       test_image = np.random.randn(*test_image_shape)
-      test_image_tensor = constant_op.constant(test_image,
+      test_image_tensor = tf.constant(test_image,
                                                shape=test_image_shape)
       test_transform = transform_ops.angles_to_projective_transforms(
           np.pi / 2, 4, 4)
@@ -99,7 +96,7 @@ def _test_grad_different_shape(self, input_shape, output_shape):
     with self.cached_session():
       test_image_shape = input_shape
       test_image = np.random.randn(*test_image_shape)
-      test_image_tensor = constant_op.constant(test_image,
+      test_image_tensor = tf.constant(test_image,
                                                shape=test_image_shape)
       test_transform = transform_ops.angles_to_projective_transforms(
          np.pi / 2, 4, 4)
@@ -134,19 +131,19 @@ def test_grad(self):
   @tf_test_util.run_all_in_graph_and_eager_modes
   def test_transform_data_types(self):
     for dtype in _DTYPES:
-      image = constant_op.constant([[1, 2], [3, 4]], dtype=dtype)
+      image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
       with self.test_session(use_gpu=True):
         self.assertAllEqual(
             np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype()),
             transform_ops.transform(image, [1] * 8))
 
   @tf_test_util.run_all_in_graph_and_eager_modes
   def test_transform_eager(self):
-    image = constant_op.constant([[1., 2.], [3., 4.]])
+    image = tf.constant([[1., 2.], [3., 4.]])
     self.assertAllEqual(
         np.array([[4, 4], [4, 4]]),
         transform_ops.transform(image, [1] * 8))
 
 
 if __name__ == "__main__":
-  test.main()
+  tf.test.main()

tensorflow_addons/custom_ops/text/python/skip_gram_ops.py

Lines changed: 22 additions & 28 deletions
@@ -18,20 +18,14 @@
 from __future__ import print_function
 
 import csv
+import tensorflow as tf
 
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import load_library
-from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import lookup_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import gfile
 from tensorflow.python.platform import resource_loader
-from tensorflow.python.training import input as input_ops
 
-skip_gram_ops = load_library.load_op_library(
+skip_gram_ops = tf.load_op_library(
     resource_loader.get_path_to_datafile("_skip_gram_ops.so"))
 
 ops.NotDifferentiable("SkipGramGenerateCandidates")
@@ -197,7 +191,7 @@ def skip_gram_sample(input_tensor,
     batch_capacity = (batch_capacity
                       if (batch_capacity is not None and batch_capacity > 0)
                       else 100 * batch_size)
-    return input_ops.batch(
+    return tf.train.batch(
        [tokens, labels],
        batch_size,
        capacity=batch_capacity,
@@ -209,9 +203,9 @@ def skip_gram_sample(input_tensor,
 def skip_gram_sample_with_text_vocab(input_tensor,
                                      vocab_freq_file,
                                      vocab_token_index=0,
-                                     vocab_token_dtype=dtypes.string,
+                                     vocab_token_dtype=tf.dtypes.string,
                                      vocab_freq_index=1,
-                                     vocab_freq_dtype=dtypes.float64,
+                                     vocab_freq_dtype=tf.dtypes.float64,
                                      vocab_delimiter=",",
                                      vocab_min_count=0,
                                      vocab_subsampling=None,
@@ -330,7 +324,7 @@ def skip_gram_sample_with_text_vocab(input_tensor,
   # vocab terms).
   calculated_corpus_size = 0.0
   vocab_size = 0
-  with gfile.GFile(vocab_freq_file, mode="r") as f:
+  with tf.io.gfile.GFile(vocab_freq_file, mode="r") as f:
     reader = csv.reader(f, delimiter=vocab_delimiter)
     for row in reader:
       if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
@@ -404,16 +398,16 @@ def _filter_input(input_tensor, vocab_freq_table, vocab_min_count,
   # Filters out elements in input_tensor that are not found in
   # vocab_freq_table (table returns a default value of -1 specified above when
   # an element is not found).
-  mask = math_ops.not_equal(freq, vocab_freq_table.default_value)
+  mask = tf.math.not_equal(freq, vocab_freq_table.default_value)
 
   # Filters out elements whose vocab frequencies are less than the threshold.
   if vocab_min_count is not None:
-    cast_threshold = math_ops.cast(vocab_min_count, freq.dtype)
-    mask = math_ops.logical_and(mask,
-                                math_ops.greater_equal(freq, cast_threshold))
+    cast_threshold = tf.cast(vocab_min_count, freq.dtype)
+    mask = tf.math.logical_and(mask,
+                               tf.math.greater_equal(freq, cast_threshold))
 
-  input_tensor = array_ops.boolean_mask(input_tensor, mask)
-  freq = array_ops.boolean_mask(freq, mask)
+  input_tensor = tf.boolean_mask(input_tensor, mask)
+  freq = tf.boolean_mask(freq, mask)
 
   if not vocab_subsampling:
     return input_tensor
@@ -428,21 +422,21 @@ def _filter_input(input_tensor, vocab_freq_table, vocab_min_count,
   # tokens).
   with ops.name_scope(
       "subsample_vocab", values=[input_tensor, freq, vocab_subsampling]):
-    corpus_size = math_ops.cast(corpus_size, dtypes.float64)
-    freq = math_ops.cast(freq, dtypes.float64)
-    vocab_subsampling = math_ops.cast(vocab_subsampling, dtypes.float64)
+    corpus_size = tf.cast(corpus_size, tf.dtypes.float64)
+    freq = tf.cast(freq, tf.dtypes.float64)
+    vocab_subsampling = tf.cast(vocab_subsampling, tf.dtypes.float64)
 
     # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
    # suppose to correlate with Eq. 5 in http://arxiv.org/abs/1310.4546.
-    keep_prob = ((math_ops.sqrt(freq /
+    keep_prob = ((tf.math.sqrt(freq /
                  (vocab_subsampling * corpus_size)) + 1.0) *
                 (vocab_subsampling * corpus_size / freq))
-    random_prob = random_ops.random_uniform(
-        array_ops.shape(freq),
+    random_prob = tf.random.uniform(
+        tf.shape(freq),
        minval=0,
        maxval=1,
-        dtype=dtypes.float64,
+        dtype=tf.dtypes.float64,
        seed=seed)
 
-    mask = math_ops.less_equal(random_prob, keep_prob)
-    return array_ops.boolean_mask(input_tensor, mask)
+    mask = tf.math.less_equal(random_prob, keep_prob)
+    return tf.boolean_mask(input_tensor, mask)
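
A minimal sketch of the public-API masking pattern that _filter_input now uses; the tokens, frequencies, and the threshold of 5 below are made-up illustrations, not values from the library:

import tensorflow as tf

tokens = tf.constant(["the", "quick", "brown", "fox"])
freq = tf.constant([50.0, -1.0, 3.0, 7.0], dtype=tf.dtypes.float64)  # -1 marks out-of-vocab
mask = tf.math.not_equal(freq, -1.0)                # drop tokens missing from the vocab table
mask = tf.math.logical_and(
    mask, tf.math.greater_equal(freq, tf.cast(5, freq.dtype)))  # enforce vocab_min_count
kept = tf.boolean_mask(tokens, mask)                # -> ["the", "fox"]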
