1+ /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ http://www.apache.org/licenses/LICENSE-2.0
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ ==============================================================================*/
15+ // See docs for ImageConnectedComponents in ../ops/image_ops.cc, and description
16+ // of the algorithm in connected_components.h.
17+
18+ #define EIGEN_USE_THREADS
19+
20+ #include " tensorflow_addons/custom_ops/image/cc/kernels/connected_components.h"
21+ #include " tensorflow/core/framework/op_kernel.h"
22+ #include " tensorflow/core/framework/register_types.h"
23+ #include " tensorflow/core/framework/types.h"
24+ #include " tensorflow/core/platform/types.h"
25+
26+ namespace tensorflow {
27+
28+ using tensorflow::functor::BlockedImageUnionFindFunctor;
29+ using tensorflow::functor::FindRootFunctor;
30+ using tensorflow::functor::ImageConnectedComponentsFunctor;
31+ using tensorflow::functor::TensorRangeFunctor;
32+
33+ using OutputType = typename BlockedImageUnionFindFunctor<bool >::OutputType;
34+
35+ // Computes connected components on batches of 2D images.
36+ template <typename Device, typename T>
37+ class ImageConnectedComponents : public OpKernel {
38+ public:
39+ explicit ImageConnectedComponents (OpKernelConstruction* ctx)
40+ : OpKernel(ctx) {}
41+
42+ void Compute (OpKernelContext* ctx) override {
43+ const Tensor& images_t = ctx->input (0 );
44+ OP_REQUIRES (ctx, images_t .shape ().dims () == 3 ,
45+ errors::InvalidArgument (" Input images must have rank 3" ));
46+ Tensor forest_t , rank_t ;
47+ OP_REQUIRES_OK (ctx, ctx->allocate_temp (tensorflow::DT_INT64,
48+ images_t .shape (), &forest_t ));
49+ OP_REQUIRES_OK (ctx, ctx->allocate_temp (tensorflow::DT_INT64,
50+ images_t .shape (), &rank_t ));
51+ Tensor* output_t ;
52+ OP_REQUIRES_OK (ctx, ctx->allocate_output (0 , images_t .shape (), &output_t ));
53+
54+ // Fill forest with values from 0 to n - 1, so that each node points to
55+ // itself.
56+ TensorRangeFunctor<Device>()(ctx->eigen_device <Device>(),
57+ forest_t .flat <OutputType>());
58+ auto rank = rank_t .tensor <OutputType, 3 >();
59+ rank.device (ctx->eigen_device <Device>()) = rank.constant (OutputType (0 ));
60+
61+ const auto images = images_t .tensor <T, 3 >();
62+ auto forest = forest_t .tensor <OutputType, 3 >();
63+ ImageConnectedComponentsFunctor<Device, T>()(
64+ ctx, output_t ->flat <OutputType>(), images, forest, rank);
65+ }
66+ };
67+
68+ using CPUDevice = Eigen::ThreadPoolDevice;
69+
70+ namespace functor {
71+
72+ // Connected components CPU implementation. See `connected_components.h` for a
73+ // description of the algorithm.
74+ template <typename T>
75+ struct ImageConnectedComponentsFunctor <CPUDevice, T> {
76+ void operator ()(OpKernelContext* ctx,
77+ typename TTypes<OutputType>::Flat output,
78+ typename TTypes<T, 3 >::ConstTensor images,
79+ typename TTypes<OutputType, 3 >::Tensor forest,
80+ typename TTypes<OutputType, 3 >::Tensor rank) {
81+ const int64 num_images = images.dimension (0 ),
82+ num_rows = images.dimension (1 ), num_cols = images.dimension (2 ),
83+ num_elements = images.size ();
84+ // Bail out early for an empty image--no work to do.
85+ if (num_elements == 0 ) {
86+ return ;
87+ }
88+ auto worker_threads = ctx->device ()->tensorflow_cpu_worker_threads ();
89+ BlockedImageUnionFindFunctor<T> union_find (
90+ images.data (), num_rows, num_cols, forest.data (), rank.data ());
91+ while (union_find.can_merge ()) {
92+ union_find.merge_blocks ();
93+ int64 num_blocks_vertically = union_find.num_blocks_vertically ();
94+ int64 num_blocks_horizontally = union_find.num_blocks_horizontally ();
95+ // Merging each block calls union_down for each pixel in a row of the
96+ // block, and union_right for each pixel in a column of the block. Assume
97+ // 20 instructions for each call to union_down or union_right. find() may
98+ // loop more while searching for the root, but this should not be very
99+ // significant.
100+ int cost = (union_find.block_height () + union_find.block_width ()) * 20 ;
101+ Shard (worker_threads->num_threads , worker_threads->workers ,
102+ num_images * num_blocks_vertically * num_blocks_horizontally, cost,
103+ [&union_find, num_blocks_vertically, num_blocks_horizontally](
104+ int64 start_block, int64 limit_block) {
105+ for (int64 i = start_block; i < limit_block; i++) {
106+ int64 block_x = i % num_blocks_horizontally;
107+ int64 block_y =
108+ (i / num_blocks_horizontally) % num_blocks_vertically;
109+ int64 image =
110+ i / (num_blocks_horizontally * num_blocks_vertically);
111+ union_find.merge_internal_block_edges (image, block_y, block_x);
112+ }
113+ });
114+ }
115+ FindRootFunctor<CPUDevice, T>()(ctx->eigen_device <CPUDevice>(), output,
116+ images.data (), union_find);
117+ }
118+ };
119+
120+ } // end namespace functor
121+
122+ #define REGISTER_IMAGE_CONNECTED_COMPONENTS (TYPE ) \
123+ REGISTER_KERNEL_BUILDER (Name(" ImageConnectedComponents" ) \
124+ .Device(DEVICE_CPU) \
125+ .TypeConstraint<TYPE>(" dtype" ), \
126+ ImageConnectedComponents<CPUDevice, TYPE>)
127+ // Connected components (arguably) make sense for number, bool, and string types
128+ TF_CALL_NUMBER_TYPES (REGISTER_IMAGE_CONNECTED_COMPONENTS);
129+ TF_CALL_bool (REGISTER_IMAGE_CONNECTED_COMPONENTS);
130+ TF_CALL_string (REGISTER_IMAGE_CONNECTED_COMPONENTS);
131+ #undef REGISTER_IMAGE_CONNECTED_COMPONENTS
132+
133+ // TODO(ringwalt): Implement on GPU. We probably want to stick to the original
134+ // algorithm by Stava and Benes there for efficiency (computing small blocks in
135+ // shared memory in CUDA thread blocks, instead of starting with single-pixel
136+ // blocks).
137+
138+ } // end namespace tensorflow
0 commit comments