Skip to content

Commit b2daf1d

Browse files
gpengzhiZhitingHu
andauthored
Fix docs issues (#290)
* Fix a few doc issues - expose `DecoderBase` in doc - delete unnecessary code in examples/gpt-2 - update doc of TransformerDecoder and RNNDecoderBase * Make CI happy * Make CI happy again Co-authored-by: Zhiting Hu <zhitinghu@gmail.com>
1 parent 6906119 commit b2daf1d

File tree

4 files changed

+32
-3
lines changed

4 files changed

+32
-3
lines changed

docs/code/modules.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ Decoders
107107
Bahdanau
108108
Gumbel
109109

110+
:hidden:`DecoderBase`
111+
~~~~~~~~~~~~~~~~~~~~~~~~
112+
.. autoclass:: texar.torch.modules.DecoderBase
113+
:members:
114+
110115
:hidden:`RNNDecoderBase`
111116
~~~~~~~~~~~~~~~~~~~~~~~~
112117
.. autoclass:: texar.torch.modules.RNNDecoderBase

examples/gpt-2/prepare_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
help="The output directory where the pickle files will be generated. "
3636
"By default it is set to be the same as `--data-dir`.")
3737
parser.add_argument(
38-
"--pretrained-model-name", type=str, default="gpt2-small",
38+
'--pretrained-model-name', type=str, default='gpt2-small',
3939
choices=tx.modules.GPT2Decoder.available_checkpoints(),
4040
help="Name of the pre-trained checkpoint to load.")
4141
parser.add_argument(

texar/torch/modules/decoders/rnn_decoder_base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,15 @@ def forward(self, # type: ignore
107107
<https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/dynamic_decode>`_.
108108
109109
See Also:
110-
Arguments of :meth:`create_helper`.
110+
Arguments of :meth:`create_helper`, for arguments like
111+
:attr:`decoding_strategy`.
111112
112113
Args:
113114
inputs (optional): Input tensors for teacher forcing decoding.
114115
Used when :attr:`decoding_strategy` is set to
115116
``"train_greedy"``, or when `hparams`-configured helper is used.
116117
117-
The attr:`inputs` is a :tensor:`LongTensor` used as index to
118+
The :attr:`inputs` is a :tensor:`LongTensor` used as index to
118119
look up embeddings and feed in the decoder. For example, if
119120
:attr:`embedder` is an instance of
120121
:class:`~texar.torch.modules.WordEmbedder`, then :attr:`inputs`
@@ -143,6 +144,10 @@ def forward(self, # type: ignore
143144
that defines the decoding strategy. If given,
144145
``decoding_strategy`` and helper configurations in
145146
:attr:`hparams` are ignored.
147+
148+
:meth:`create_helper` can be used to create some of the common
149+
helpers for, e.g., teacher-forcing decoding, greedy decoding,
150+
sample decoding, etc.
146151
infer_mode (optional): If not `None`, overrides mode given by
147152
`self.training`.
148153
**kwargs: Other keyword arguments for constructing helpers

texar/torch/modules/decoders/transformer_decoders.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,25 @@ def forward(self, # type: ignore
451451
:attr:`hparams` are ignored.
452452
infer_mode (optional): If not `None`, overrides mode given by
453453
:attr:`self.training`.
454+
**kwargs (optional, dict): Other keyword arguments. Typically ones
455+
such as:
456+
457+
- **start_tokens**: A :tensor:`LongTensor` of shape
458+
``[batch_size]``, the start tokens.
459+
Used when :attr:`decoding_strategy` is ``"infer_greedy"`` or
460+
``"infer_sample"`` or when :attr:`beam_search` is set.
461+
Ignored when :attr:`context` is set.
462+
463+
When used with the Texar data module, to get ``batch_size``
464+
samples where ``batch_size`` is changing according to the
465+
data module, this can be set as
466+
:python:`start_tokens=torch.full_like(batch['length'],
467+
bos_token_id)`.
468+
469+
- **end_token**: An integer or 0D :tensor:`LongTensor`, the
470+
token that marks the end of decoding.
471+
Used when :attr:`decoding_strategy` is ``"infer_greedy"`` or
472+
``"infer_sample"``, or when :attr:`beam_search` is set.
454473
455474
Returns:
456475

0 commit comments

Comments
 (0)