Skip to content

Commit 2d977c8

Browse files
sayoojbkWindQAQ
authored andcommitted
Connected components (#409)
* ported to tf2.0 ,but need to implement gpu kernel * ported the connected_comp to tf2.0 - GPU kernel to be done * removed some pylin errors * Update connected_components.h * Update connected_components.h * ops name changed to connected_component from segmentation * added connected_components to build file * removed non-whitelited pylint errors * Revert "removed non-whitelited pylint errors" This reverts commit 57f14f5. * removed non whitelisted pylint errors * made changes to the kernel files of image BUILD * connected components ported successfully only unknown dim test is remaining * removed the buildifier.* files accidently committed to the PR * made necessary naming convention changes and assigned as submodule maintainer * Updated contributor list * added pylint disable
1 parent 0811550 commit 2d977c8

File tree

9 files changed

+746
-1
lines changed

9 files changed

+746
-1
lines changed

tensorflow_addons/custom_ops/image/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ cc_library(
9393
cc_binary(
9494
name = "_image_ops.so",
9595
srcs = [
96+
"cc/kernels/connected_components.cc",
97+
"cc/kernels/connected_components.h",
9698
"cc/kernels/euclidean_distance_transform_op.cc",
9799
"cc/kernels/euclidean_distance_transform_op.h",
98100
"cc/kernels/image_projective_transform_op.cc",
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
// See docs for ImageConnectedComponents in ../ops/image_ops.cc, and description
16+
// of the algorithm in connected_components.h.
17+
18+
#define EIGEN_USE_THREADS
19+
20+
#include "tensorflow_addons/custom_ops/image/cc/kernels/connected_components.h"
21+
#include "tensorflow/core/framework/op_kernel.h"
22+
#include "tensorflow/core/framework/register_types.h"
23+
#include "tensorflow/core/framework/types.h"
24+
#include "tensorflow/core/platform/types.h"
25+
26+
namespace tensorflow {
27+
28+
using tensorflow::functor::BlockedImageUnionFindFunctor;
29+
using tensorflow::functor::FindRootFunctor;
30+
using tensorflow::functor::ImageConnectedComponentsFunctor;
31+
using tensorflow::functor::TensorRangeFunctor;
32+
33+
using OutputType = typename BlockedImageUnionFindFunctor<bool>::OutputType;
34+
35+
// Computes connected components on batches of 2D images.
36+
template <typename Device, typename T>
37+
class ImageConnectedComponents : public OpKernel {
38+
public:
39+
explicit ImageConnectedComponents(OpKernelConstruction* ctx)
40+
: OpKernel(ctx) {}
41+
42+
void Compute(OpKernelContext* ctx) override {
43+
const Tensor& images_t = ctx->input(0);
44+
OP_REQUIRES(ctx, images_t.shape().dims() == 3,
45+
errors::InvalidArgument("Input images must have rank 3"));
46+
Tensor forest_t, rank_t;
47+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64,
48+
images_t.shape(), &forest_t));
49+
OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64,
50+
images_t.shape(), &rank_t));
51+
Tensor* output_t;
52+
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
53+
54+
// Fill forest with values from 0 to n - 1, so that each node points to
55+
// itself.
56+
TensorRangeFunctor<Device>()(ctx->eigen_device<Device>(),
57+
forest_t.flat<OutputType>());
58+
auto rank = rank_t.tensor<OutputType, 3>();
59+
rank.device(ctx->eigen_device<Device>()) = rank.constant(OutputType(0));
60+
61+
const auto images = images_t.tensor<T, 3>();
62+
auto forest = forest_t.tensor<OutputType, 3>();
63+
ImageConnectedComponentsFunctor<Device, T>()(
64+
ctx, output_t->flat<OutputType>(), images, forest, rank);
65+
}
66+
};
67+
68+
using CPUDevice = Eigen::ThreadPoolDevice;
69+
70+
namespace functor {
71+
72+
// Connected components CPU implementation. See `connected_components.h` for a
73+
// description of the algorithm.
74+
template <typename T>
75+
struct ImageConnectedComponentsFunctor<CPUDevice, T> {
76+
void operator()(OpKernelContext* ctx,
77+
typename TTypes<OutputType>::Flat output,
78+
typename TTypes<T, 3>::ConstTensor images,
79+
typename TTypes<OutputType, 3>::Tensor forest,
80+
typename TTypes<OutputType, 3>::Tensor rank) {
81+
const int64 num_images = images.dimension(0),
82+
num_rows = images.dimension(1), num_cols = images.dimension(2),
83+
num_elements = images.size();
84+
// Bail out early for an empty image--no work to do.
85+
if (num_elements == 0) {
86+
return;
87+
}
88+
auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
89+
BlockedImageUnionFindFunctor<T> union_find(
90+
images.data(), num_rows, num_cols, forest.data(), rank.data());
91+
while (union_find.can_merge()) {
92+
union_find.merge_blocks();
93+
int64 num_blocks_vertically = union_find.num_blocks_vertically();
94+
int64 num_blocks_horizontally = union_find.num_blocks_horizontally();
95+
// Merging each block calls union_down for each pixel in a row of the
96+
// block, and union_right for each pixel in a column of the block. Assume
97+
// 20 instructions for each call to union_down or union_right. find() may
98+
// loop more while searching for the root, but this should not be very
99+
// significant.
100+
int cost = (union_find.block_height() + union_find.block_width()) * 20;
101+
Shard(worker_threads->num_threads, worker_threads->workers,
102+
num_images * num_blocks_vertically * num_blocks_horizontally, cost,
103+
[&union_find, num_blocks_vertically, num_blocks_horizontally](
104+
int64 start_block, int64 limit_block) {
105+
for (int64 i = start_block; i < limit_block; i++) {
106+
int64 block_x = i % num_blocks_horizontally;
107+
int64 block_y =
108+
(i / num_blocks_horizontally) % num_blocks_vertically;
109+
int64 image =
110+
i / (num_blocks_horizontally * num_blocks_vertically);
111+
union_find.merge_internal_block_edges(image, block_y, block_x);
112+
}
113+
});
114+
}
115+
FindRootFunctor<CPUDevice, T>()(ctx->eigen_device<CPUDevice>(), output,
116+
images.data(), union_find);
117+
}
118+
};
119+
120+
} // end namespace functor
121+
122+
#define REGISTER_IMAGE_CONNECTED_COMPONENTS(TYPE) \
123+
REGISTER_KERNEL_BUILDER(Name("ImageConnectedComponents") \
124+
.Device(DEVICE_CPU) \
125+
.TypeConstraint<TYPE>("dtype"), \
126+
ImageConnectedComponents<CPUDevice, TYPE>)
127+
// Connected components (arguably) make sense for number, bool, and string types
128+
TF_CALL_NUMBER_TYPES(REGISTER_IMAGE_CONNECTED_COMPONENTS);
129+
TF_CALL_bool(REGISTER_IMAGE_CONNECTED_COMPONENTS);
130+
TF_CALL_string(REGISTER_IMAGE_CONNECTED_COMPONENTS);
131+
#undef REGISTER_IMAGE_CONNECTED_COMPONENTS
132+
133+
// TODO(ringwalt): Implement on GPU. We probably want to stick to the original
134+
// algorithm by Stava and Benes there for efficiency (computing small blocks in
135+
// shared memory in CUDA thread blocks, instead of starting with single-pixel
136+
// blocks).
137+
138+
} // end namespace tensorflow

0 commit comments

Comments
 (0)