Skip to content

Commit a4882be

Browse files
WindQAQfacaiy
authored andcommitted
Migrate distort image ops (#67)
* migrate distort_image_ops * add test for distort_image_ops * modify BUILD file for distort_image_ops * add tf.fucntion decorator * fix assert regex * clean up internal api * update README * fix copyright * import *_hsv_in_yiq * fix name scope error * code format * remove sessions * fix wrong decorators * add TODO comments * remove tf_test_util * clean up messy stuff
1 parent f29c204 commit a4882be

File tree

9 files changed

+920
-0
lines changed

9 files changed

+920
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ developments that cannot be integrated into core TensorFlow
1616
| Sub-Package | Addon | Reference |
1717
|:----------------------- |:----------- |:---------------------------- |
1818
| tfa.activations | Sparsemax | https://arxiv.org/abs/1602.02068 |
19+
| tfa.image | adjust_hsv_in_yiq | |
20+
| tfa.image | random_hsv_in_yiq | |
1921
| tfa.image | transform | |
2022
| tfa.layers | GroupNormalization | https://arxiv.org/abs/1803.08494 |
2123
| tfa.layers | InstanceNormalization | https://arxiv.org/abs/1607.08022 |

tensorflow_addons/custom_ops/image/BUILD

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,22 @@ licenses(["notice"]) # Apache 2.0
22

33
package(default_visibility = ["//visibility:public"])
44

5+
cc_binary(
6+
name = "python/_distort_image_ops.so",
7+
srcs = [
8+
"cc/kernels/adjust_hsv_in_yiq_op.cc",
9+
"cc/kernels/adjust_hsv_in_yiq_op.h",
10+
"cc/ops/distort_image_ops.cc",
11+
],
12+
linkshared = 1,
13+
deps = [
14+
"@local_config_tf//:libtensorflow_framework",
15+
"@local_config_tf//:tf_header_lib",
16+
],
17+
copts = ["-pthread", "-std=c++11", "-D_GLIBCXX_USE_CXX11_ABI=0"]
18+
)
19+
20+
521
cc_binary(
622
name = "python/_image_ops.so",
723
srcs = [
@@ -26,15 +42,30 @@ py_library(
2642
srcs = ([
2743
"__init__.py",
2844
"python/__init__.py",
45+
"python/distort_image_ops.py",
2946
"python/transform.py",
3047
]),
3148
data = [
49+
":python/_distort_image_ops.so",
3250
":python/_image_ops.so",
3351
"//tensorflow_addons/utils:utils_py",
3452
],
3553
srcs_version = "PY2AND3",
3654
)
3755

56+
py_test(
57+
name = "distort_image_ops_test",
58+
size = "small",
59+
srcs = [
60+
"python/distort_image_ops_test.py",
61+
],
62+
main = "python/distort_image_ops_test.py",
63+
deps = [
64+
":images_ops_py",
65+
],
66+
srcs_version = "PY2AND3"
67+
)
68+
3869
# TODO: use cuda_py_test later.
3970
py_test(
4071
name = "transform_ops_test",

tensorflow_addons/custom_ops/image/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20+
from tensorflow_addons.custom_ops.image.python.distort_image_ops import adjust_hsv_in_yiq
21+
from tensorflow_addons.custom_ops.image.python.distort_image_ops import random_hsv_in_yiq
2022
# Transforms
2123
from tensorflow_addons.custom_ops.image.python.transform import transform
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#if GOOGLE_CUDA
17+
#define EIGEN_USE_GPU
18+
#endif // GOOGLE_CUDA
19+
20+
#include <memory>
21+
22+
#include "tensorflow/core/framework/register_types.h"
23+
#include "tensorflow/core/framework/tensor.h"
24+
#include "tensorflow/core/framework/tensor_shape.h"
25+
#include "tensorflow/core/lib/core/status.h"
26+
#include "tensorflow/core/platform/logging.h"
27+
#include "tensorflow/core/util/work_sharder.h"
28+
#include "tensorflow_addons/custom_ops/image/cc/kernels/adjust_hsv_in_yiq_op.h"
29+
30+
namespace tensorflow {
31+
32+
typedef Eigen::ThreadPoolDevice CPUDevice;
33+
typedef Eigen::GpuDevice GPUDevice;
34+
35+
class AdjustHsvInYiqOpBase : public OpKernel {
36+
protected:
37+
explicit AdjustHsvInYiqOpBase(OpKernelConstruction* context)
38+
: OpKernel(context) {}
39+
40+
struct ComputeOptions {
41+
const Tensor* input = nullptr;
42+
Tensor* output = nullptr;
43+
const Tensor* delta_h = nullptr;
44+
const Tensor* scale_s = nullptr;
45+
const Tensor* scale_v = nullptr;
46+
int64 channel_count = 0;
47+
};
48+
49+
virtual void DoCompute(OpKernelContext* context,
50+
const ComputeOptions& options) = 0;
51+
52+
void Compute(OpKernelContext* context) override {
53+
const Tensor& input = context->input(0);
54+
const Tensor& delta_h = context->input(1);
55+
const Tensor& scale_s = context->input(2);
56+
const Tensor& scale_v = context->input(3);
57+
OP_REQUIRES(context, input.dims() >= 3,
58+
errors::InvalidArgument("input must be at least 3-D, got shape",
59+
input.shape().DebugString()));
60+
OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_h.shape()),
61+
errors::InvalidArgument("delta_h must be scalar: ",
62+
delta_h.shape().DebugString()));
63+
OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_s.shape()),
64+
errors::InvalidArgument("scale_s must be scalar: ",
65+
scale_s.shape().DebugString()));
66+
OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_v.shape()),
67+
errors::InvalidArgument("scale_v must be scalar: ",
68+
scale_v.shape().DebugString()));
69+
auto channels = input.dim_size(input.dims() - 1);
70+
OP_REQUIRES(
71+
context, channels == kChannelSize,
72+
errors::InvalidArgument("input must have 3 channels but instead has ",
73+
channels, " channels."));
74+
75+
Tensor* output = nullptr;
76+
OP_REQUIRES_OK(context,
77+
context->allocate_output(0, input.shape(), &output));
78+
79+
if (input.NumElements() > 0) {
80+
const int64 channel_count = input.NumElements() / channels;
81+
ComputeOptions options;
82+
options.input = &input;
83+
options.delta_h = &delta_h;
84+
options.scale_s = &scale_s;
85+
options.scale_v = &scale_v;
86+
options.output = output;
87+
options.channel_count = channel_count;
88+
DoCompute(context, options);
89+
}
90+
}
91+
};
92+
93+
template <class Device>
94+
class AdjustHsvInYiqOp;
95+
96+
template <>
97+
class AdjustHsvInYiqOp<CPUDevice> : public AdjustHsvInYiqOpBase {
98+
public:
99+
explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
100+
: AdjustHsvInYiqOpBase(context) {}
101+
102+
void DoCompute(OpKernelContext* context,
103+
const ComputeOptions& options) override {
104+
const Tensor* input = options.input;
105+
Tensor* output = options.output;
106+
const int64 channel_count = options.channel_count;
107+
auto input_data = input->shaped<float, 2>({channel_count, kChannelSize});
108+
const float delta_h = options.delta_h->scalar<float>()();
109+
const float scale_s = options.scale_s->scalar<float>()();
110+
const float scale_v = options.scale_v->scalar<float>()();
111+
auto output_data = output->shaped<float, 2>({channel_count, kChannelSize});
112+
float tranformation_matrix[kChannelSize * kChannelSize] = {0};
113+
internal::compute_tranformation_matrix<kChannelSize * kChannelSize>(
114+
delta_h, scale_s, scale_v, tranformation_matrix);
115+
const int kCostPerChannel = 10;
116+
const DeviceBase::CpuWorkerThreads& worker_threads =
117+
*context->device()->tensorflow_cpu_worker_threads();
118+
Shard(worker_threads.num_threads, worker_threads.workers, channel_count,
119+
kCostPerChannel, [&input_data, &output_data, &tranformation_matrix](
120+
int64 start_channel, int64 end_channel) {
121+
// Applying projection matrix to input RGB vectors.
122+
const float* p = input_data.data() + start_channel * kChannelSize;
123+
float* q = output_data.data() + start_channel * kChannelSize;
124+
for (int i = start_channel; i < end_channel; i++) {
125+
for (int q_index = 0; q_index < kChannelSize; q_index++) {
126+
q[q_index] = 0;
127+
for (int p_index = 0; p_index < kChannelSize; p_index++) {
128+
q[q_index] +=
129+
p[p_index] *
130+
tranformation_matrix[q_index + kChannelSize * p_index];
131+
}
132+
}
133+
p += kChannelSize;
134+
q += kChannelSize;
135+
}
136+
});
137+
}
138+
};
139+
140+
REGISTER_KERNEL_BUILDER(
141+
Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint<float>("T"),
142+
AdjustHsvInYiqOp<CPUDevice>);
143+
144+
#if GOOGLE_CUDA
145+
template <>
146+
class AdjustHsvInYiqOp<GPUDevice> : public AdjustHsvInYiqOpBase {
147+
public:
148+
explicit AdjustHsvInYiqOp(OpKernelConstruction* context)
149+
: AdjustHsvInYiqOpBase(context) {}
150+
151+
void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override {
152+
const int64 number_of_elements = options.input->NumElements();
153+
if (number_of_elements <= 0) {
154+
return;
155+
}
156+
const float* delta_h = options.delta_h->flat<float>().data();
157+
const float* scale_s = options.scale_s->flat<float>().data();
158+
const float* scale_v = options.scale_v->flat<float>().data();
159+
functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input,
160+
delta_h, scale_s, scale_v, options.output);
161+
}
162+
};
163+
164+
REGISTER_KERNEL_BUILDER(
165+
Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint<float>("T"),
166+
AdjustHsvInYiqOp<GPUDevice>);
167+
#endif
168+
169+
} // namespace tensorflow
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
8+
Unless required by applicable law or agreed to in writing, software
9+
distributed under the License is distributed on an "AS IS" BASIS,
10+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
See the License for the specific language governing permissions and
12+
limitations under the License.
13+
==============================================================================*/
14+
#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
15+
#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_
16+
17+
#if GOOGLE_CUDA
18+
#define EIGEN_USE_GPU
19+
#endif // GOOGLE_CUDA
20+
21+
#include <cmath>
22+
#include "third_party/eigen3/Eigen/Core"
23+
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24+
25+
#include "tensorflow/core/framework/op_kernel.h"
26+
#include "tensorflow/core/framework/register_types.h"
27+
#include "tensorflow/core/framework/types.h"
28+
29+
namespace tensorflow {
30+
31+
static constexpr int kChannelSize = 3;
32+
33+
namespace internal {
34+
35+
template <int MATRIX_SIZE>
36+
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix(
37+
const float delta_h, const float scale_s, const float scale_v,
38+
float* matrix) {
39+
static_assert(MATRIX_SIZE == kChannelSize * kChannelSize,
40+
"Size of matrix should be 9.");
41+
// Projection matrix from RGB to YIQ. Numbers from wikipedia
42+
// https://en.wikipedia.org/wiki/YIQ
43+
Eigen::Matrix3f yiq;
44+
/* clang-format off */
45+
yiq << 0.299, 0.587, 0.114,
46+
0.596, -0.274, -0.322,
47+
0.211, -0.523, 0.312;
48+
Eigen::Matrix3f yiq_inverse;
49+
yiq_inverse << 1, 0.95617069, 0.62143257,
50+
1, -0.2726886, -0.64681324,
51+
1, -1.103744, 1.70062309;
52+
/* clang-format on */
53+
// Construct hsv linear transformation matrix in YIQ space.
54+
// https://beesbuzz.biz/code/hsv_color_transforms.php
55+
float vsu = scale_v * scale_s * std::cos(delta_h);
56+
float vsw = scale_v * scale_s * std::sin(delta_h);
57+
Eigen::Matrix3f hsv_transform;
58+
/* clang-format off */
59+
hsv_transform << scale_v, 0, 0,
60+
0, vsu, -vsw,
61+
0, vsw, vsu;
62+
/* clang-format on */
63+
// Compute final transformation matrix = inverse_yiq * hsv_transform * yiq
64+
Eigen::Map<Eigen::Matrix<float, 3, 3, Eigen::ColMajor>> eigen_matrix(matrix);
65+
eigen_matrix = yiq_inverse * hsv_transform * yiq;
66+
}
67+
} // namespace internal
68+
69+
#if GOOGLE_CUDA
70+
typedef Eigen::GpuDevice GPUDevice;
71+
72+
namespace functor {
73+
74+
struct AdjustHsvInYiqGPU {
75+
void operator()(OpKernelContext* ctx, int channel_count,
76+
const Tensor* const input, const float* const delta_h,
77+
const float* const scale_s, const float* const scale_v,
78+
Tensor* const output);
79+
};
80+
81+
} // namespace functor
82+
83+
#endif // GOOGLE_CUDA
84+
85+
} // namespace tensorflow
86+
87+
#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_

0 commit comments

Comments
 (0)