Skip to content

Commit 8064035

Browse files
fsx950223seanpmorgan
authored andcommitted
Fix optical_flow test case (#527)
1 parent d92e084 commit 8064035

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

tensorflow_addons/layers/optical_flow_test.py

100644100755
Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy as np
2121
import tensorflow as tf
22-
from tensorflow_addons.layers.optical_flow import correlation_cost, CorrelationCost
22+
from tensorflow_addons.layers.optical_flow import CorrelationCost
2323
from tensorflow_addons.utils import test_utils
2424

2525

@@ -31,15 +31,13 @@ def _forward(self, input_a, input_b, kernel_size, max_displacement,
3131
input_a_op = tf.convert_to_tensor(input_a, dtype=tf.float32)
3232
input_b_op = tf.convert_to_tensor(input_b, dtype=tf.float32)
3333

34-
output = correlation_cost(
35-
input_a_op,
36-
input_b_op,
34+
output = CorrelationCost(
3735
kernel_size=kernel_size,
3836
max_displacement=max_displacement,
3937
stride_1=stride_1,
4038
stride_2=stride_2,
4139
pad=pad,
42-
data_format=data_format)
40+
data_format=data_format)([input_a_op, input_b_op])
4341

4442
return output
4543

@@ -117,15 +115,13 @@ def _gradients(self, data_format):
117115
input_b_op = tf.convert_to_tensor(input_b)
118116

119117
def correlation_fn(input_a, input_b):
120-
return correlation_cost(
121-
input_a,
122-
input_b,
118+
return CorrelationCost(
123119
kernel_size=kernel_size,
124120
max_displacement=max_displacement,
125121
stride_1=stride_1,
126122
stride_2=stride_2,
127123
pad=pad,
128-
data_format=data_format)
124+
data_format=data_format)([input_a, input_b])
129125

130126
theoretical, numerical = tf.test.compute_gradient(
131127
correlation_fn, [input_a_op, input_b_op])

0 commit comments

Comments
 (0)