1919
2020import numpy as np
2121import tensorflow as tf
22- from tensorflow_addons .layers .optical_flow import correlation_cost , CorrelationCost
22+ from tensorflow_addons .layers .optical_flow import CorrelationCost
2323from tensorflow_addons .utils import test_utils
2424
2525
@@ -31,15 +31,13 @@ def _forward(self, input_a, input_b, kernel_size, max_displacement,
3131 input_a_op = tf .convert_to_tensor (input_a , dtype = tf .float32 )
3232 input_b_op = tf .convert_to_tensor (input_b , dtype = tf .float32 )
3333
34- output = correlation_cost (
35- input_a_op ,
36- input_b_op ,
34+ output = CorrelationCost (
3735 kernel_size = kernel_size ,
3836 max_displacement = max_displacement ,
3937 stride_1 = stride_1 ,
4038 stride_2 = stride_2 ,
4139 pad = pad ,
42- data_format = data_format )
40+ data_format = data_format )([ input_a_op , input_b_op ])
4341
4442 return output
4543
@@ -117,15 +115,13 @@ def _gradients(self, data_format):
117115 input_b_op = tf .convert_to_tensor (input_b )
118116
119117 def correlation_fn (input_a , input_b ):
120- return correlation_cost (
121- input_a ,
122- input_b ,
118+ return CorrelationCost (
123119 kernel_size = kernel_size ,
124120 max_displacement = max_displacement ,
125121 stride_1 = stride_1 ,
126122 stride_2 = stride_2 ,
127123 pad = pad ,
128- data_format = data_format )
124+ data_format = data_format )([ input_a , input_b ])
129125
130126 theoretical , numerical = tf .test .compute_gradient (
131127 correlation_fn , [input_a_op , input_b_op ])
0 commit comments