Skip to content

Commit c09d901

Browse files
authored
Merge pull request #4 from linjieccc/add_taskflow_ddparser
Replace np.argmax to paddle.argmax
2 parents 22f5224 + fe352cc commit c09d901

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

paddlenlp/taskflow/dependency_parsing.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ class DDParserTask(Task):
7878
Args:
7979
task(string): The name of task.
8080
model(string): The model name in the task.
81-
static_mode(bool): The flag to control in the static/dygraph mode.
8281
tree(bool): Ensure the output conforms to the tree structure.
8382
prob(bool): Whether to return the probability of predicted heads.
8483
use_pos(bool): Whether to return the postag.
@@ -167,7 +166,8 @@ def _construct_model(self, model):
167166
n_rels=len(self.rel_vocab),
168167
n_words=len(self.word_vocab),
169168
pad_index=self.word_pad_index,
170-
eos_index=self.word_eos_index, )
169+
bos_index=self.word_bos_index,
170+
eos_index=self.word_eos_index,)
171171
# Load the model parameter for the predict
172172
state_dict = paddle.load(
173173
os.path.join(self._task_path, self.model, "model.pdparams"))
@@ -249,15 +249,12 @@ def _run_model(self, inputs):
249249
self.input_handles[0].copy_from_cpu(words)
250250
self.input_handles[1].copy_from_cpu(wp)
251251
self.predictor.run()
252-
s_arc = self.output_handle[0].copy_to_cpu()
253-
s_rel = self.output_handle[1].copy_to_cpu()
254-
words = self.output_handle[2].copy_to_cpu()
252+
arc_preds = self.output_handle[0].copy_to_cpu()
253+
rel_preds = self.output_handle[1].copy_to_cpu()
254+
s_arc = self.output_handle[2].copy_to_cpu()
255+
mask = self.output_handle[3].copy_to_cpu().astype('bool')
255256

256-
mask = np.logical_and(
257-
np.logical_and(words != self.word_pad_index,
258-
words != self.word_bos_index),
259-
words != self.word_eos_index, )
260-
arc_preds, rel_preds = decode(s_arc, s_rel, mask, self.tree)
257+
arc_preds, rel_preds = decode(arc_preds, rel_preds, s_arc, mask, self.tree)
261258

262259
arcs.extend([arc_pred[m] for arc_pred, m in zip(arc_preds, mask)])
263260
rels.extend([rel_pred[m] for rel_pred, m in zip(rel_preds, mask)])
@@ -458,16 +455,13 @@ def probability(s_arc, arc_preds):
458455
return arc_probs
459456

460457

461-
def decode(s_arc, s_rel, mask, tree=True):
462-
463-
lens = np.sum(mask.astype(int), axis=-1)
464-
arc_preds = np.argmax(s_arc, axis=-1)
458+
def decode(arc_preds, rel_preds, s_arc, mask, tree):
459+
"""decode"""
460+
lens = np.sum(mask, -1)
465461

466462
bad = [not istree(seq[:i + 1]) for i, seq in zip(lens, arc_preds)]
467463
if tree and any(bad):
468464
arc_preds[bad] = eisner(s_arc[bad], mask[bad])
469-
470-
rel_preds = np.argmax(s_rel, axis=-1)
471465
rel_preds = [
472466
rel_pred[np.arange(len(arc_pred)), arc_pred]
473467
for arc_pred, rel_pred in zip(arc_preds, rel_preds)
@@ -704,4 +698,4 @@ def inorder_traversal(self, node):
704698

705699
def istree(sequence):
706700
"""Is the sequence a project tree"""
707-
return DepTree(sequence).judge_legal()
701+
return DepTree(sequence).judge_legal()

paddlenlp/taskflow/models/dependency_parsing_model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@ def __init__(self,
2525
n_rels,
2626
n_words,
2727
pad_index,
28+
bos_index,
2829
eos_index,
2930
n_mlp_arc=500,
3031
n_mlp_rel=100):
3132
super(BiAffineParser, self).__init__()
3233
self.pad_index = pad_index
34+
self.bos_index = bos_index
3335
self.eos_index = eos_index
3436

3537
if encoding_model == "lstm-pe":
@@ -70,7 +72,14 @@ def forward(self, words, wp):
7072
s_arc_mask = paddle.unsqueeze(mask, 1)
7173
s_arc = s_arc * s_arc_mask + paddle.scale(
7274
paddle.cast(s_arc_mask, 'int32'), scale=1e5, bias=-1, bias_after_scale=False)
73-
return s_arc, s_rel, words
75+
76+
mask = paddle.cast(paddle.logical_and(
77+
paddle.logical_and(words != self.pad_index, words != self.bos_index),
78+
words != self.eos_index,
79+
), 'int32')
80+
arc_preds = paddle.argmax(s_arc, axis=-1)
81+
rel_preds = paddle.argmax(s_rel, axis=-1)
82+
return arc_preds, rel_preds, s_arc, mask
7483

7584

7685
class MLP(nn.Layer):
@@ -236,5 +245,4 @@ def index_sample(x, index):
236245
out = paddle.reshape(out, shape=[x_s[0], x_s[1], -1])
237246
else:
238247
out = paddle.reshape(out, shape=[x_s[0], -1])
239-
return out
240-
248+
return out

0 commit comments

Comments
 (0)