@@ -1857,7 +1857,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
18571857 _alignment_history else ()
18581858 for alignment in initial_alignments ))
18591859
1860- def call (self , inputs , state ):
1860+ def call (self , inputs , state , ** kwargs ):
18611861 """Perform a step of attention-wrapped RNN.
18621862
18631863 - Step 1: Mix the `inputs` and previous step's `attention` output via
@@ -1878,6 +1878,7 @@ def call(self, inputs, state):
18781878 step.
18791879 state: An instance of `AttentionWrapperState` containing
18801880 tensors from the previous time step.
1881+ **kwargs: Dict, other keyword arguments for the cell call method.
18811882
18821883 Returns:
18831884 A tuple `(attention_or_cell_output, next_state)`, where:
@@ -1898,7 +1899,8 @@ def call(self, inputs, state):
18981899 # previous attention value.
18991900 cell_inputs = self ._cell_input_fn (inputs , state .attention )
19001901 cell_state = state .cell_state
1901- cell_output , next_cell_state = self ._cell (cell_inputs , cell_state )
1902+ cell_output , next_cell_state = self ._cell (
1903+ cell_inputs , cell_state , ** kwargs )
19021904
19031905 cell_batch_size = (tf .compat .dimension_value (cell_output .shape [0 ])
19041906 or tf .shape (cell_output )[0 ])
0 commit comments