Commit 3931a9b

Add encode_text_for_generation in XLNetTokenizer (#278)
* Add encode_text_for_generation in XLNetTokenizer
* Add unittest for encode_text_for_generation
1 parent ccd226c commit 3931a9b

2 files changed: +69 −0 lines changed


texar/torch/data/tokenizers/xlnet_tokenizer.py

Lines changed: 51 additions & 0 deletions
@@ -282,6 +282,57 @@ def encode_text(self,
 
         return input_ids, segment_ids, input_mask
 
+    def encode_text_for_generation(
+            self,
+            text: str,
+            max_seq_length: Optional[int] = None,
+            append_eos_token: bool = True) -> Tuple[List[int], int]:
+        r"""Adds special tokens to a sequence and computes the corresponding
+        sequence length for XLNet-specific tasks. The sequence will be
+        truncated if its length is larger than ``max_seq_length``.
+
+        An XLNet sequence has the following format:
+        `[bos_token]` X `[eos_token]` `[pad_token]`
+
+        Args:
+            text: Input text.
+            max_seq_length: Maximum sequence length.
+            append_eos_token: Whether to append ``eos_token`` after the
+                sequence.
+
+        Returns:
+            A tuple of `(input_ids, seq_len)`, where
+
+            - ``input_ids``: A list of input token ids with added
+              special tokens.
+            - ``seq_len``: The sequence length.
+        """
+        if max_seq_length is None:
+            max_seq_length = self.max_len
+
+        token_ids = self.map_text_to_id(text)
+        assert isinstance(token_ids, list)
+
+        bos_token_id = self._map_token_to_id(self.bos_token)
+        eos_token_id = self._map_token_to_id(self.eos_token)
+        pad_token_id = self._map_token_to_id(self.pad_token)
+
+        if append_eos_token:
+            input_ids = token_ids[:max_seq_length - 2]
+            input_ids = [bos_token_id] + input_ids + [eos_token_id]
+        else:
+            input_ids = token_ids[:max_seq_length - 1]
+            input_ids = [bos_token_id] + input_ids
+
+        seq_len = len(input_ids)
+
+        # Pad up to the maximum sequence length.
+        input_ids = input_ids + [pad_token_id] * (max_seq_length - seq_len)
+
+        assert len(input_ids) == max_seq_length
+
+        return input_ids, seq_len
+
     @staticmethod
     def default_hparams() -> Dict[str, Any]:
         r"""Returns a dictionary of hyperparameters with default values.

texar/torch/data/tokenizers/xlnet_tokenizer_test.py

Lines changed: 18 additions & 0 deletions
@@ -250,6 +250,24 @@ def test_encode_text(self):
         self.assertListEqual(segment_ids, [0, 0, 0, 1, 1, 1, 2])
         self.assertListEqual(input_mask, [0, 0, 0, 0, 0, 0, 0])
 
+    def test_encode_text_for_generation(self):
+        text_1 = u"lower newer"
+
+        text_1_ids = self.tokenizer.map_text_to_id(text_1)
+
+        input_ids, seq_len = \
+            self.tokenizer.encode_text_for_generation(text=text_1,
+                                                      max_seq_length=10)
+
+        bos_token_id = self.tokenizer.map_token_to_id(self.tokenizer.bos_token)
+        eos_token_id = self.tokenizer.map_token_to_id(self.tokenizer.eos_token)
+        pad_token_id = self.tokenizer.map_token_to_id(self.tokenizer.pad_token)
+
+        self.assertListEqual(input_ids,
+                             [bos_token_id] + text_1_ids + [eos_token_id] +
+                             [pad_token_id, pad_token_id, pad_token_id])
+        self.assertEqual(seq_len, 7)
 
 
 if __name__ == "__main__":
     unittest.main()
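The test exercises the padding path (a short input padded to length 10). A sketch of the truncation path implied by the implementation, using the same assumed tokenizer as in the earlier sketch:

    # When the input is longer than max_seq_length, the token ids are
    # truncated to max_seq_length - 2 so that bos and eos still fit,
    # and seq_len then equals max_seq_length.
    input_ids, seq_len = tokenizer.encode_text_for_generation(
        text="lower newer " * 50, max_seq_length=10)
    assert seq_len == 10 and len(input_ids) == 10

    # With append_eos_token=False, only bos is prepended and the ids
    # are truncated to max_seq_length - 1 instead.
    input_ids, seq_len = tokenizer.encode_text_for_generation(
        text="lower newer " * 50, max_seq_length=10, append_eos_token=False)
    assert input_ids[0] == tokenizer.map_token_to_id(tokenizer.bos_token)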
