@@ -282,6 +282,57 @@ def encode_text(self,
 
         return input_ids, segment_ids, input_mask
 
+    def encode_text_for_generation(
+            self,
+            text: str,
+            max_seq_length: Optional[int] = None,
+            append_eos_token: bool = True) -> Tuple[List[int], int]:
+        r"""Adds special tokens to a sequence and computes the corresponding
+        sequence length for XLNet-specific tasks. The sequence will be
+        truncated if its length is larger than ``max_seq_length``.
+
+        An XLNet sequence has the following format:
+        `[bos_token]` X `[eos_token]` `[pad_token]`
+
+        Args:
+            text: Input text.
+            max_seq_length: Maximum sequence length.
+            append_eos_token: Whether to append ``eos_token`` after the
+                sequence.
+
+        Returns:
+            A tuple of `(input_ids, seq_len)`, where
+
+            - ``input_ids``: A list of input token ids with added
+              special tokens.
+            - ``seq_len``: The sequence length.
+        """
+        if max_seq_length is None:
+            max_seq_length = self.max_len
+
+        token_ids = self.map_text_to_id(text)
+        assert isinstance(token_ids, list)
+
+        bos_token_id = self._map_token_to_id(self.bos_token)
+        eos_token_id = self._map_token_to_id(self.eos_token)
+        pad_token_id = self._map_token_to_id(self.pad_token)
+
+        if append_eos_token:
+            input_ids = token_ids[:max_seq_length - 2]
+            input_ids = [bos_token_id] + input_ids + [eos_token_id]
+        else:
+            input_ids = token_ids[:max_seq_length - 1]
+            input_ids = [bos_token_id] + input_ids
+
+        seq_len = len(input_ids)
+
+        # Pad up to the maximum sequence length.
+        input_ids = input_ids + [pad_token_id] * (max_seq_length - seq_len)
+
+        assert len(input_ids) == max_seq_length
+
+        return input_ids, seq_len
+
     @staticmethod
     def default_hparams() -> Dict[str, Any]:
         r"""Returns a dictionary of hyperparameters with default values.
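
A minimal usage sketch of the new method, for reference. It assumes the method lands on the XLNet tokenizer class, that the class is importable as `texar.torch.data.XLNetTokenizer`, and that it can be constructed from a `pretrained_model_name`; these names are inferred from the surrounding API style, not shown in this diff, so adjust them if the actual class differs.

from texar.torch.data import XLNetTokenizer  # assumed import path

# Assumed pretrained model name; any XLNet checkpoint the tokenizer
# supports should behave the same way.
tokenizer = XLNetTokenizer(pretrained_model_name="xlnet-base-cased")

# Encode a prompt for generation: the result always contains exactly
# ``max_seq_length`` ids, padded with ``pad_token`` after the
# ``[bos_token] X [eos_token]`` prefix described in the docstring.
input_ids, seq_len = tokenizer.encode_text_for_generation(
    "New York is a city in", max_seq_length=16)

# ``seq_len`` counts only the non-padding positions, so a length mask
# for the generator can be derived directly from it.
length_mask = [1] * seq_len + [0] * (len(input_ids) - seq_len)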