Commit 4bb70ec

swapnull7 authored and gpengzhi committed
Add T5Tokenizer based on SentencePieceTokenizer (#283)
* Add text style transfer (#1)
* initial commit
* bug fixes and adjusting conv inputs
* separate forward function for Discriminator and Generator and disable Gen training for debugging
* remove debugger statement
* bug fix
* detaching stuff before accumulating
* refactor and add component as optional parameter
* Add optimizer for and backprop against encoder
* Add in README
* Add text style transfer with improvements (#2)
* initial commit
* bug fixes and adjusting conv inputs
* separate forward function for Discriminator and Generator and disable Gen training for debugging
* remove debugger statement
* bug fix
* detaching stuff before accumulating
* refactor and add component as optional parameter
* Add optimizer for and backprop against encoder
* Add in README
* more fixes to eval mode
* create optimizers so that they can be saved
* fix typo
* restore optimizers
* Update ctrl_gen_model.py
* remove tensorflow import
* Add text style transfer (#3)
* Add text style transfer (#4)
* initial commit
* bug fixes and adjusting conv inputs
* separate forward function for Discriminator and Generator and disable Gen training for debugging
* remove debugger statement
* bug fix
* detaching stuff before accumulating
* refactor and add component as optional parameter
* Add optimizer for and backprop against encoder
* Add in README
* more fixes to eval mode
* create optimizers so that they can be saved
* fix typo
* linting issues
* add type annotation for encoder
* fix linting
* Isolate AE in training
* works after changing the learning rate
* remove debugger
* Add text style transfer (#5)
* Reviewed changes
* linting
* Add text style transfer (#6)
* initial commit
* linting
* Fix docs build issue
* Fix typo
* init_commit
* modularize t5 and comment out debugging statements
* Add decorators for pretrained_tests
* remove changes from text-style-transfer
* remove collect variable changes
* remove text-style-transfer from docs
* more clean up and removing debugger statements
* more clean up
* fix linting
* more linting
* Update utils.rst
* linting and fixing minor bugs in gpt2-tests
* skipping pretrained tests
* fix documentation error
* fix linting
* Update gpt2_test.py
* refactor gin reading function
* revert using identity, use nn.Module instead
* Update decoder_base.py
* fix type for T5Decoder
* fix linting
* add a standalone test for T5
* Adding T5 Tokenizer
* adding import to __init__.py
* linting fix
* making review changes
* reviewed changes
1 parent 3931a9b commit 4bb70ec
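
For reference, below is a minimal usage sketch of the tokenizer this commit adds. It is not part of the commit itself; it assumes the T5-Small SentencePiece files can be downloaded into the default texar_data cache, and it only uses the constructor and mapping methods shown in the diff and the test file further down.

from texar.torch.data import T5Tokenizer

# Minimal sketch (not part of the commit): load the pre-trained T5-Small
# vocabulary and round-trip a sentence through SentencePiece ids.
tokenizer = T5Tokenizer(pretrained_model_name="T5-Small")

text = "I saw a girl with a telescope."
ids = tokenizer.map_text_to_id(text)      # SentencePiece ids, sentinels included
print(tokenizer.map_id_to_text(ids))      # recovers `text`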

File tree

14 files changed: +311 -25 lines changed


docs/code/data.rst

Lines changed: 5 additions & 0 deletions
@@ -37,6 +37,11 @@ Tokenizer
 .. autoclass:: texar.torch.data.XLNetTokenizer
     :members:
 
+:hidden:`T5Tokenizer`
+~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.torch.data.T5Tokenizer
+    :members:
+
 Vocabulary
 ==========

docs/code/modules.rst

Lines changed: 13 additions & 0 deletions
@@ -252,6 +252,14 @@ Regressors
 .. autoclass:: texar.torch.modules.XLNetRegressor
     :members:
 
+EncoderDecoders
+================
+
+:hidden:`T5EncoderDecoder`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.torch.modules.T5EncoderDecoder
+    :members:
+
 Pre-trained
 ===========

@@ -285,6 +293,11 @@ Pre-trained
 .. autoclass:: texar.torch.modules.PretrainedXLNetMixin
     :members:
 
+:hidden:`PretrainedT5Mixin`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: texar.torch.modules.PretrainedT5Mixin
+    :members:
+
 Connectors
 ==========
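
As a point of reference (not part of this commit's diff), a construction sketch for the newly documented module. It assumes T5EncoderDecoder follows the same pretrained_model_name / cache_dir / hparams constructor convention as the other pre-trained texar.torch modules; that signature is not shown in this diff.

from texar.torch.modules import T5EncoderDecoder

# Sketch only: assumes the standard texar pre-trained constructor signature.
model = T5EncoderDecoder(pretrained_model_name="T5-Small")
print(model.hparams.pretrained_model_name)  # "T5-Small"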

texar/torch/data/tokenizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -21,3 +21,4 @@
 from texar.torch.data.tokenizers.tokenizer_base import *
 from texar.torch.data.tokenizers.xlnet_tokenizer import *
 from texar.torch.data.tokenizers.sentencepiece_tokenizer import *
+from texar.torch.data.tokenizers.t5_tokenizer import *

texar/torch/data/tokenizers/bert_tokenizer.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def __init__(self,
             vocab_file = os.path.join(self.pretrained_model_dir,
                                       self._VOCAB_FILE_MAP['vocab_file']
                                       [self.pretrained_model_name])
-            assert self.pretrained_model_name is not None
+
             if self._MAX_INPUT_SIZE.get(self.pretrained_model_name):
                 self.max_len = self._MAX_INPUT_SIZE[self.pretrained_model_name]
             else:

texar/torch/data/tokenizers/t5_tokenizer.py

Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-trained T5 tokenizer.
"""

from typing import Any, Dict, Optional

import os
import re

from texar.torch.data.tokenizers.sentencepiece_tokenizer \
    import SentencePieceTokenizer
from texar.torch.modules.pretrained.t5 import PretrainedT5Mixin

__all__ = [
    'T5Tokenizer',
]


class T5Tokenizer(SentencePieceTokenizer, PretrainedT5Mixin):
    r"""Pre-trained T5 Tokenizer.

    Args:
        pretrained_model_name (optional): a `str`, the name of the
            pre-trained model (e.g., `T5-Small`). Please refer to
            :class:`~texar.torch.modules.PretrainedT5Mixin` for
            all supported models.
            If None, the model name in :attr:`hparams` is used.
        cache_dir (optional): the path to a folder in which the
            pre-trained models will be cached. If `None` (default),
            a default directory (``texar_data`` folder under the user's home
            directory) will be used.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure
            and default values.
    """

    _IS_PRETRAINED = True

    _VOCAB_FILE_NAMES = {
        'vocab_file': 'sentencepiece.model'
    }

    _MAX_INPUT_SIZE = {
        'T5-Small': 512,
        'T5-Base': 512,
        'T5-Large': 512,
        'T5-3B': 512,
        'T5-11B': 512
    }

    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):

        self.load_pretrained_config(pretrained_model_name, cache_dir, hparams)

        if self.pretrained_model_dir is not None:
            assert self.pretrained_model_name is not None
            vocab_file = os.path.join(self.pretrained_model_dir,
                                      self._VOCAB_FILE_NAMES['vocab_file'])

            if self._MAX_INPUT_SIZE.get(self.pretrained_model_name):
                self.max_len = self._MAX_INPUT_SIZE[self.pretrained_model_name]
            setattr(self.hparams, 'vocab_file', vocab_file)
        else:
            if self.hparams.get('max_len'):
                self.max_len = self.hparams['max_len']

        # Add extra_ids to the special token list
        additional_special_tokens = []
        extra_ids = self.hparams['extra_ids']
        if extra_ids > 0:
            additional_special_tokens.extend(
                ["<extra_id_{}>".format(i) for i in range(extra_ids)])

        setattr(self.hparams, 'additional_special_tokens',
                additional_special_tokens)

        super().__init__(hparams=None)

    @staticmethod
    def default_hparams() -> Dict[str, Any]:
        r"""Returns a dictionary of hyperparameters with default values.

        * The tokenizer is determined by the constructor argument
          :attr:`pretrained_model_name` if it's specified. In this case,
          `hparams` are ignored.
        * Otherwise, the tokenizer is determined by
          `hparams['pretrained_model_name']` if it's specified. All other
          configurations in `hparams` are ignored.
        * If the above two are `None`, the tokenizer is defined by the
          configurations in `hparams`.

        .. code-block:: python

            {
                "pretrained_model_name": "T5-Small",
                "vocab_file": None,
                "max_len": 512,
                "bos_token": None,
                "eos_token": "</s>",
                "unk_token": "<unk>",
                "pad_token": "<pad>",
                "extra_ids": 100,
                "additional_special_tokens": [],
                "name": "t5_tokenizer",
            }

        Here:

        `"pretrained_model_name"`: str or None
            The name of the pre-trained T5 model.

        `"vocab_file"`: str or None
            The path to a sentencepiece vocabulary file.

        `"max_len"`: int or None
            The maximum sequence length that this model might ever be used with.

        `"bos_token"`: str or None
            Beginning of sentence token. Set None to disable ``bos_token``.

        `"eos_token"`: str
            End of sentence token. Set None to disable ``eos_token``.

        `"unk_token"`: str
            Unknown token. Set None to disable ``unk_token``.

        `"pad_token"`: str
            Padding token. Set None to disable ``pad_token``.

        `"extra_ids"`: int
            The number of extra ids added to the end of the vocabulary for
            use as sentinels. These tokens are accessible as `<extra_id_{%d}>`,
            where `{%d}` is a number between 0 and extra_ids-1. Extra tokens
            are indexed from the end of the vocabulary up to the beginning
            (<extra_id_0> is the last token in the vocabulary), as in T5
            preprocessing; see:
            `https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117`

        `"additional_special_tokens"`: list
            A list of additional special tokens.

        `"name"`: str
            Name of the tokenizer.
        """
        return {
            'pretrained_model_name': 'T5-Small',
            'vocab_file': None,
            'max_len': 512,
            'bos_token': None,
            'eos_token': '</s>',
            'unk_token': '<unk>',
            'pad_token': '<pad>',
            'extra_ids': 100,
            'additional_special_tokens': [],
            'name': 't5_tokenizer',
            '@no_typecheck': ['pretrained_model_name'],
        }

    @property
    def vocab_size(self) -> int:
        return len(self.sp_model) + self.hparams['extra_ids']

    def _map_token_to_id(self, token: str) -> int:
        if token.startswith("<extra_id_"):
            match = re.match(r"<extra_id_(\d+)>", token)
            num = int(match.group(1))  # type: ignore
            return self.vocab_size - num - 1
        return self.sp_model.PieceToId(token)

    def _map_id_to_token(self, index: int) -> str:
        if index < self.sp_model.get_piece_size():
            token = self.sp_model.IdToPiece(index)
        else:
            token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
        return token
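
The sentinel indexing used by _map_token_to_id and _map_id_to_token above can be checked with plain arithmetic. The sketch below is not part of the commit; it assumes an illustrative SentencePiece model with 32,000 pieces and the default extra_ids of 100, so vocab_size is 32,100 and <extra_id_0> is the last id.

# Stand-alone check of the sentinel-id arithmetic in T5Tokenizer (illustrative
# numbers only: 32,000 SentencePiece pieces + 100 extra ids).
sp_size = 32000
extra_ids = 100
vocab_size = sp_size + extra_ids

def extra_token_to_id(token: str) -> int:
    # "<extra_id_{n}>" tokens are indexed from the end of the vocabulary.
    num = int(token[len("<extra_id_"):-1])
    return vocab_size - num - 1

def id_to_extra_token(index: int) -> str:
    # Ids at or beyond the SentencePiece range map back to sentinel tokens.
    return "<extra_id_{}>".format(vocab_size - 1 - index)

assert extra_token_to_id("<extra_id_0>") == 32099    # last id in the vocabulary
assert extra_token_to_id("<extra_id_99>") == 32000   # first id after the SP pieces
assert id_to_extra_token(32099) == "<extra_id_0>"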

texar/torch/data/tokenizers/t5_tokenizer_test.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
"""
Unit tests for T5 tokenizer.
"""

import unittest

import os
import tempfile

from texar.torch.utils.test import pretrained_test
from texar.torch.data.tokenizers.t5_tokenizer import T5Tokenizer
from texar.torch.data.data_utils import maybe_download


class T5TokenizerTest(unittest.TestCase):

    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.SAMPLE_VOCAB = maybe_download(
            'https://github.com/google/sentencepiece/blob/master/'
            'python/test/test_model.model?raw=true', self.tmp_dir.name)

        self.tokenizer = T5Tokenizer.load(self.SAMPLE_VOCAB)

        self.tokenizer.save(self.tmp_dir.name)

    def tearDown(self):
        self.tmp_dir.cleanup()

    @pretrained_test
    def test_model_loading(self):
        for pretrained_model_name in T5Tokenizer.available_checkpoints():
            tokenizer = T5Tokenizer(
                pretrained_model_name=pretrained_model_name)

            info = list(os.walk(tokenizer.pretrained_model_dir))
            _, _, files = info[0]

            self.assertIn('sentencepiece.model', files)

            _ = tokenizer.map_text_to_token(u"This is a test")

    def test_roundtrip(self):
        tokenizer = T5Tokenizer.load(self.tmp_dir.name)

        text = 'I saw a girl with a telescope.'
        ids = tokenizer.map_text_to_id(text)
        tokens = tokenizer.map_text_to_token(text)

        self.assertEqual(text, tokenizer.map_id_to_text(ids))
        self.assertEqual(text, tokenizer.map_token_to_text(tokens))

        text = '<extra_id_32> I saw a girl with a telescope.<extra_id_74>'
        ids = tokenizer.map_text_to_id(text)
        tokens = tokenizer.map_text_to_token(text)

        self.assertEqual(text, tokenizer.map_id_to_text(ids))
        self.assertEqual(text, tokenizer.map_token_to_text(tokens))


if __name__ == "__main__":
    unittest.main()

texar/torch/data/tokenizers/tokenizer_base.py

Lines changed: 6 additions & 1 deletion
@@ -83,7 +83,12 @@ def __init__(self, hparams):
                 assert isinstance(value, (list, tuple)) and \
                     all(isinstance(v, str) for v in value)
             else:
-                assert isinstance(value, str)
+                if value is not None:
+                    assert isinstance(value, str)
+                else:
+                    warnings.warn(f"Trying to set None as value special "
+                                  f"token '{key}'. Proceed only if you"
+                                  f" are sure!", UserWarning)
             setattr(self, key, value)
 
     @classmethod
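
To illustrate the new behavior in isolation: a None value for a special token now triggers a UserWarning instead of failing the isinstance assertion. The helper below is a hypothetical, simplified mirror of the branch added above; it does not call into TokenizerBase itself.

import warnings

def set_special_token(key, value):
    # Simplified mirror of the new TokenizerBase branch: tolerate None with a
    # warning, keep the isinstance check for everything else.
    if value is not None:
        assert isinstance(value, str)
    else:
        warnings.warn(f"Trying to set None as value special token '{key}'. "
                      f"Proceed only if you are sure!", UserWarning)
    return value

set_special_token('eos_token', '</s>')  # passes the isinstance check
set_special_token('bos_token', None)    # emits a UserWarning, returns None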

texar/torch/modules/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -23,3 +23,4 @@
 from texar.torch.modules.networks import *
 from texar.torch.modules.pretrained import *
 from texar.torch.modules.regressors import *
+from texar.torch.modules.encoder_decoders import *

texar/torch/modules/decoders/t5_decoder.py

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@
 
 
 class T5Decoder(TransformerDecoder):
-    r"""T5 decoder that applies multi-head self-attention with #todo rpr for
-    sequence decoding.
+    r"""T5 decoder that applies multi-head self-attention with relative
+    position representation for sequence decoding.
 
     It is a stack of
     :class:`~texar.torch.modules.MultiheadRPRAttention`,

texar/torch/modules/encoder_decoders/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -15,5 +15,5 @@
 Modules of Texar library encoders.
 """
 
-from texar.torch.modules.encoder_decoders.t5_encoder_decoder \
-    import T5EncoderDecoder
+from texar.torch.modules.encoder_decoders.encoder_decoder_base import *
+from texar.torch.modules.encoder_decoders.t5_encoder_decoder import *

0 commit comments