Commit c804ca8

guillaumekln authored and seanpmorgan committed
Pass kwargs to wrapped cell in AttentionWrapper (#272)
1 parent: 84bb63f

File tree

1 file changed: +4 -2 lines changed


tensorflow_addons/seq2seq/attention_wrapper.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -1857,7 +1857,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
                 if self._alignment_history else ()
                 for alignment in initial_alignments))
 
-    def call(self, inputs, state):
+    def call(self, inputs, state, **kwargs):
         """Perform a step of attention-wrapped RNN.
 
         - Step 1: Mix the `inputs` and previous step's `attention` output via
@@ -1878,6 +1878,7 @@ def call(self, inputs, state):
             step.
           state: An instance of `AttentionWrapperState` containing
             tensors from the previous time step.
+          **kwargs: Dict, other keyword arguments for the cell call method.
 
         Returns:
           A tuple `(attention_or_cell_output, next_state)`, where:
@@ -1898,7 +1899,8 @@ def call(self, inputs, state):
         # previous attention value.
         cell_inputs = self._cell_input_fn(inputs, state.attention)
         cell_state = state.cell_state
-        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
+        cell_output, next_cell_state = self._cell(
+            cell_inputs, cell_state, **kwargs)
 
         cell_batch_size = (tf.compat.dimension_value(cell_output.shape[0])
                            or tf.shape(cell_output)[0])
```
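In practice this change lets keyword arguments such as `training` flow through `AttentionWrapper.call` into the wrapped cell instead of being dropped. Below is a minimal sketch of how that can be exercised; the tensor sizes, the `LuongAttention` mechanism, and the dropout-enabled `LSTMCell` are illustrative assumptions, not part of this commit:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Illustrative sizes (assumptions for this sketch).
batch_size, max_time, depth, units = 4, 7, 16, 32

# Build an attention mechanism over some random memory.
memory = tf.random.normal([batch_size, max_time, depth])
attention_mechanism = tfa.seq2seq.LuongAttention(units, memory=memory)

# An LSTMCell with dropout changes behavior based on the `training` kwarg,
# so it observes whether the wrapper forwards keyword arguments to it.
cell = tf.keras.layers.LSTMCell(units, dropout=0.2)
wrapper = tfa.seq2seq.AttentionWrapper(cell, attention_mechanism)

inputs = tf.random.normal([batch_size, depth])
state = wrapper.get_initial_state(batch_size=batch_size, dtype=tf.float32)

# Before this commit the kwarg stopped at AttentionWrapper.call; with the
# change, `training=True` reaches the wrapped cell and enables its dropout.
output, next_state = wrapper(inputs, state, training=True)
```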
