
Commit 167310d

Migrate tf.contrib.seq2seq to addons. (#72)
* Initial commit for seq2seq. Only the V2 implementation is copied from tf.contrib.seq2seq and exported. Currently the beam_search_decoder is broken due to an unknown issue in the C kernel, which needs more time to debug.
* Fix the unit test and export for beam_search. Only export the V1 implementation of beam_search_decoder.
* Fix unit test failures and format issues.
* Fix dependency issue for AttentionWrapper. It was depending on contrib/framework/tensor_util, which is now ported to tf_addon/util/.
* Fix unit test for AttentionWrapper. Previously the output projection dense layer was initialized to all ones, which caused the data fed to sampler.sample() to be identical within a batch; that is why the expected_sample_id mean() was 0 (it always picks the first index, since all the data across the indices are the same). Due to an unknown dependency difference between the OSS and internal builds, the data sometimes loses some precision, causing the sample_id to change randomly. This change updates the output dense layer to have random weights, which leads to much more diverse but deterministic results.
* Remove the "v2" suffix from all the class names. Since the V1 implementation is completely removed, there is no point in distinguishing between v1 and v2.
* Update README.md with generic guidelines and sample code.
* Fix some typos and make slight style improvements.
* Fix the unit test broken by a rename of the class.
* Reorganize all the Python and C libraries for seq2seq. The C kernel has been moved to custom_ops and is imported from there. The Python library has been moved out of the /python package and now lives directly under seq2seq.
* Reformat all the Python modules per PEP 8.
* Update README with Maintainers.
* Documentation cleanup.
* Remove the internal bug link since it is not available externally.
* Update the shape check for AttentionWrapper. Remove the existing usage of tensor_util and replace it with control deps and assert_equal, which should achieve the same behavior for shape verification.
* Fix lint and format issues.
* More lint fixes.
* More and more lint fixes.
1 parent 2f9c91f commit 167310d

20 files changed: +8307 −0 lines
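For context, the op at the heart of this migration is GatherTree, which backtracks beam-search parent pointers into final beams. A minimal Python usage sketch follows; the tfa.seq2seq.gather_tree export name is an assumption based on this commit's exports, and the example values are illustrative:

# Hedged sketch: calling the migrated gather_tree op through the
# tensorflow_addons Python surface (export name assumed).
import numpy as np
import tensorflow_addons as tfa

# Shapes are [max_time, batch_size, beam_width] = [3, 1, 3].
step_ids = np.array([[[1, 2, 3]], [[4, 5, 6]], [[7, 8, 9]]], dtype=np.int32)
parent_ids = np.array([[[0, 0, 0]], [[0, 1, 1]], [[2, 1, 2]]], dtype=np.int32)

beams = tfa.seq2seq.gather_tree(
    step_ids=step_ids,
    parent_ids=parent_ids,
    max_sequence_lengths=np.array([3], dtype=np.int32),
    end_token=10)
# Backtracking the parent pointers yields:
# [[[2, 2, 2]], [[6, 5, 6]], [[7, 8, 9]]]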
tensorflow_addons/custom_ops/seq2seq/BUILD
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
licenses(["notice"])  # Apache 2.0

package(default_visibility = ["//visibility:public"])


cc_binary(
    name = "_beam_search_ops.so",
    srcs = [
        "cc/kernels/beam_search_ops.cc",
        "cc/kernels/beam_search_ops.h",
        # "cc/kernels/beam_search_ops_gpu.cu.cc",
        "cc/ops/beam_search_ops.cc",
    ],
    copts = [
        "-pthread",
        "-std=c++11",
        "-D_GLIBCXX_USE_CXX11_ABI=0",
    ],
    linkshared = 1,
    deps = [
        "@local_config_tf//:libtensorflow_framework",
        "@local_config_tf//:tf_header_lib",
    ],
)
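The cc_binary above produces a standalone shared object rather than linking the kernel into TensorFlow itself. On the Python side such a library is typically loaded along these lines; this is a sketch only, the .so path is an assumption, and tensorflow_addons wraps the call in its own loader utility:

# Hedged sketch of loading the custom op library built by the target above.
import tensorflow as tf

gen_beam_search_ops = tf.load_op_library(
    "tensorflow_addons/custom_ops/seq2seq/_beam_search_ops.so")

# The GatherTree op registered in cc/ops/beam_search_ops.cc is exposed as a
# generated snake_case function on the loaded module:
# beams = gen_beam_search_ops.gather_tree(step_ids, parent_ids,
#                                         max_sequence_lengths, end_token)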
tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.cc
Lines changed: 194 additions & 0 deletions
@@ -0,0 +1,194 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <memory>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/work_sharder.h"
#include "tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class GatherTreeOp : public OpKernel {
 public:
  explicit GatherTreeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Device& device = ctx->eigen_device<Device>();
    const Tensor& step_ids = ctx->input(0);
    const Tensor& parent_ids = ctx->input(1);
    const Tensor& max_sequence_lengths = ctx->input(2);
    const Tensor& end_token = ctx->input(3);
    const TensorShape& step_ids_shape = step_ids.shape();
    OP_REQUIRES(
        ctx, step_ids_shape.dims() == 3,
        errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
                                step_ids_shape.DebugString()));
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()),
                errors::InvalidArgument(
                    "max_sequence_lengths must be a vector, saw shape: ",
                    max_sequence_lengths.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(end_token.shape()),
        errors::InvalidArgument("end_token must be a scalar, saw shape: ",
                                end_token.shape().DebugString()));
    OP_REQUIRES(
        ctx, step_ids_shape == parent_ids.shape(),
        errors::InvalidArgument(
            "step_ids.shape must match parent_ids.shape. but shapes are: ",
            step_ids_shape.DebugString(), " and ",
            parent_ids.shape().DebugString()));
    OP_REQUIRES(
        ctx,
        step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0),
        errors::InvalidArgument("batch size dimensions step_ids.shape[1] and "
                                "max_sequence_lengths.shape[0] must match. "
                                "but shapes are: ",
                                step_ids_shape.DebugString(), " and ",
                                max_sequence_lengths.shape().DebugString()));
    Tensor* beams;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
    typename TTypes<T, 3>::ConstTensor step_ids_t(step_ids.tensor<T, 3>());
    typename TTypes<T, 3>::ConstTensor parent_ids_t(parent_ids.tensor<T, 3>());
    typename TTypes<int32>::ConstVec max_seq_lens_t =
        max_sequence_lengths.vec<int32>();
    typename TTypes<T>::ConstScalar end_token_t(end_token.scalar<T>());
    typename TTypes<T, 3>::Tensor beams_t(beams->tensor<T, 3>());
    const T end_token_value = end_token_t();
    functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
                                     max_seq_lens_t, end_token_value, beams_t);
  }
};

#define REGISTER_KERNEL(T)                                          \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("GatherTree").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      GatherTreeOp<CPUDevice, T>);
REGISTER_KERNEL(int32);
#undef REGISTER_KERNEL

namespace functor {

// CPU specialization
template <>
struct GatherTree<CPUDevice, int32> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  TTypes<int32, 3>::ConstTensor step_ids,
                  TTypes<int32, 3>::ConstTensor parent_ids,
                  TTypes<int32>::ConstVec max_sequence_lengths,
                  const int32 end_token, TTypes<int32, 3>::Tensor beams) {
    const int32 max_time = parent_ids.dimension(0);
    const int32 batch_size = parent_ids.dimension(1);
    const int32 beam_width = parent_ids.dimension(2);
    beams.setConstant(end_token);

    auto DoWork = [&, ctx, end_token](int start_batch_beam,
                                      int limit_batch_beam) {
      for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
        const int32 batch = i / beam_width;
        const int32 beam = i % beam_width;
        const int32 max_seq_len_b =
            Eigen::numext::mini(max_time, max_sequence_lengths(batch));
        if (max_seq_len_b <= 0) {
          continue;
        }
        beams(max_seq_len_b - 1, batch, beam) =
            step_ids(max_seq_len_b - 1, batch, beam);
        int32 parent = parent_ids(max_seq_len_b - 1, batch, beam);
        for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
          if (parent < 0 || parent > beam_width) {
            ctx->SetStatus(
                errors::InvalidArgument("Saw invalid parent id ", parent,
                                        " at (batch, time, beam) == (", batch,
                                        ", ", level, ", ", beam, ")"));
            return;
          }
          beams(level, batch, beam) = step_ids(level, batch, parent);
          parent = parent_ids(level, batch, parent);
        }
        // Not necessary when using a BeamSearchDecoder, but necessary
        // when a user feeds in possibly broken trajectory (i.e., non-eos
        // entries in a beam following eos entries).
        bool finished = false;
        for (int32 time = 0; time < max_seq_len_b; ++time) {
          if (finished) {
            beams(time, batch, beam) = end_token;
          } else if (beams(time, batch, beam) == end_token) {
            finished = true;
          }
        }
      }
    };
    // Guesstimate of cost; ~5 lookup/store/compare per inner beam
    // traversal time step.
    const int64 batch_beam_cost =
        Eigen::TensorOpCost::DivCost<int32>() +
        6 * Eigen::TensorOpCost::AddCost<int32>() +
        2 * max_time * (5 * Eigen::TensorOpCost::AddCost<int32>());
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers,
          batch_size * beam_width, batch_beam_cost, DoWork);
  }
};

}  // namespace functor

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                                            \
  template <>                                                          \
  void GatherTree<GPUDevice, T>::operator()(                           \
      OpKernelContext* ctx, const GPUDevice& d,                        \
      typename TTypes<T, 3>::ConstTensor step_ids,                     \
      typename TTypes<T, 3>::ConstTensor parent_ids,                   \
      TTypes<int32>::ConstVec max_sequence_lengths, const T end_token, \
      typename TTypes<T, 3>::Tensor beams);                            \
  extern template struct GatherTree<GPUDevice, T>;

DECLARE_GPU_SPEC(int32);
#undef DECLARE_GPU_SPEC
}  // end namespace functor

#define REGISTER_GPU_KERNEL(T)                          \
  REGISTER_KERNEL_BUILDER(Name("GatherTree")            \
                              .Device(DEVICE_GPU)       \
                              .TypeConstraint<T>("T")   \
                              .HostMemory("end_token"), \
                          GatherTreeOp<GPUDevice, T>);

REGISTER_GPU_KERNEL(int32);
#undef REGISTER_GPU_KERNEL
#endif  // GOOGLE_CUDA

}  // end namespace tensorflow
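To make the backtracking in the CPU specialization easier to follow, here is a plain NumPy sketch of the same algorithm. It is illustrative only and not part of the commit; the invalid-parent error path and the sharded threading are omitted.

# Illustrative NumPy sketch of the gather_tree backtracking above.
import numpy as np

def gather_tree(step_ids, parent_ids, max_sequence_lengths, end_token):
    """step_ids, parent_ids: [max_time, batch_size, beam_width] int arrays."""
    max_time, batch_size, beam_width = step_ids.shape
    beams = np.full_like(step_ids, end_token)
    for batch in range(batch_size):
        max_seq_len_b = min(max_time, max_sequence_lengths[batch])
        if max_seq_len_b <= 0:
            continue
        for beam in range(beam_width):
            # Start from the last valid step and follow parent pointers back.
            beams[max_seq_len_b - 1, batch, beam] = \
                step_ids[max_seq_len_b - 1, batch, beam]
            parent = parent_ids[max_seq_len_b - 1, batch, beam]
            for level in range(max_seq_len_b - 2, -1, -1):
                # (The C++ kernel also validates parent here; omitted.)
                beams[level, batch, beam] = step_ids[level, batch, parent]
                parent = parent_ids[level, batch, parent]
            # Force everything after the first end_token to end_token.
            finished = False
            for time in range(max_seq_len_b):
                if finished:
                    beams[time, batch, beam] = end_token
                elif beams[time, batch, beam] == end_token:
                    finished = True
    return beams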
tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
#define TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_

#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {
class OpKernelContext;

namespace functor {

template <typename Device, typename T>
struct GatherTree {
  void operator()(OpKernelContext* ctx, const Device& d,
                  typename TTypes<T, 3>::ConstTensor step_ids,
                  typename TTypes<T, 3>::ConstTensor parent_ids,
                  TTypes<int32>::ConstVec max_sequence_lengths,
                  const T end_token, typename TTypes<T, 3>::Tensor beams);
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CONTRIB_SEQ2SEQ_KERNELS_BEAM_SEARCH_OPS_H_
tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops_gpu.cu.cc
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow_addons/custom_ops/seq2seq/cc/kernels/beam_search_ops.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {
namespace functor {

typedef Eigen::GpuDevice GPUDevice;

template <typename T>
__global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
                                   const int32 beam_width, const T* step_ids,
                                   const T* parent_ids,
                                   const int32* max_sequence_lengths,
                                   const T end_token, T* beams) {
  CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
    const int32 batch = i / beam_width;
    const int32 beam = i % beam_width;

    const int32 max_seq_len_b =
        Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch));
    if (max_seq_len_b <= 0) {
      continue;
    }

#define GET_IX(time_ix, beam_ix) \
  (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
    const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
    beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
    int32 parent = ldg(parent_ids + initial_beam_ix);
    bool found_bad = false;
    for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
      const int32 level_beam_ix = GET_IX(level, beam);
      const int32 level_parent_ix = GET_IX(level, parent);
      if (parent < 0 || parent > beam_width) {
        beams[level_beam_ix] = -1;
        parent = -1;
        found_bad = true;
      } else {
        beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
        parent = ldg(parent_ids + level_parent_ix);
      }
    }
    // Not necessary when using a BeamSearchDecoder, but necessary
    // when a user feeds in possibly broken trajectory (i.e., non-eos
    // entries in a beam following eos entries).
    if (!found_bad) {
      bool finished = false;
      for (int32 time = 0; time < max_seq_len_b; ++time) {
        const int32 level_beam_ix = GET_IX(time, beam);
        if (finished) {
          beams[level_beam_ix] = end_token;
        } else if (beams[level_beam_ix] == end_token) {
          finished = true;
        }
      }
    }
#undef GET_IX
  }
}

template <typename T>
struct GatherTree<GPUDevice, T> {
  void operator()(OpKernelContext* ctx, const GPUDevice& d,
                  typename TTypes<T, 3>::ConstTensor step_ids,
                  typename TTypes<T, 3>::ConstTensor parent_ids,
                  TTypes<int32>::ConstVec max_sequence_length,
                  const T end_token, typename TTypes<T, 3>::Tensor beams) {
    const int32 max_time = parent_ids.dimension(0);
    const int32 batch_size = parent_ids.dimension(1);
    const int32 beam_width = parent_ids.dimension(2);
    // First kernel launch to "zero" things out
    beams.device(d) = beams.constant(end_token);

    CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
    TF_CHECK_OK(CudaLaunchKernel(
        GatherTreeOpKernel<T>, config.block_count, config.thread_per_block, 0,
        d.stream(), batch_size, max_time, beam_width, step_ids.data(),
        parent_ids.data(), max_sequence_length.data(), end_token,
        beams.data()));
  }
};

#define DEFINE_GPU_SPECS(T) template struct GatherTree<GPUDevice, T>;

DEFINE_GPU_SPECS(int32);
#undef DEFINE_GPU_SPECS

}  // end namespace functor
}  // end namespace tensorflow
#endif  // GOOGLE_CUDA
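The GET_IX macro in the CUDA kernel flattens a [max_time, batch_size, beam_width] tensor in row-major order. A quick NumPy check of that index math (illustrative only; the shape values are arbitrary):

# Verifies GET_IX against NumPy's row-major layout.
import numpy as np

max_time, batch_size, beam_width = 4, 3, 2
x = np.arange(max_time * batch_size * beam_width).reshape(
    max_time, batch_size, beam_width)

def get_ix(time_ix, batch, beam_ix):
    # Mirrors: batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix)
    return batch_size * beam_width * time_ix + beam_width * batch + beam_ix

flat = x.ravel()
assert all(
    flat[get_ix(t, b, k)] == x[t, b, k]
    for t in range(max_time)
    for b in range(batch_size)
    for k in range(beam_width))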
