@@ -78,7 +78,6 @@ class DDParserTask(Task):
7878 Args:
7979 task(string): The name of task.
8080 model(string): The model name in the task.
81- static_mode(bool): The flag to control in the static/dygraph mode.
8281 tree(bool): Ensure the output conforms to the tree structure.
8382 prob(bool): Whether to return the probability of predicted heads.
8483 use_pos(bool): Whether to return the postag.
@@ -167,7 +166,8 @@ def _construct_model(self, model):
167166 n_rels = len (self .rel_vocab ),
168167 n_words = len (self .word_vocab ),
169168 pad_index = self .word_pad_index ,
170- eos_index = self .word_eos_index , )
169+ bos_index = self .word_bos_index ,
170+ eos_index = self .word_eos_index ,)
171171 # Load the model parameter for the predict
172172 state_dict = paddle .load (
173173 os .path .join (self ._task_path , self .model , "model.pdparams" ))
@@ -249,15 +249,12 @@ def _run_model(self, inputs):
249249 self .input_handles [0 ].copy_from_cpu (words )
250250 self .input_handles [1 ].copy_from_cpu (wp )
251251 self .predictor .run ()
252- s_arc = self .output_handle [0 ].copy_to_cpu ()
253- s_rel = self .output_handle [1 ].copy_to_cpu ()
254- words = self .output_handle [2 ].copy_to_cpu ()
252+ arc_preds = self .output_handle [0 ].copy_to_cpu ()
253+ rel_preds = self .output_handle [1 ].copy_to_cpu ()
254+ s_arc = self .output_handle [2 ].copy_to_cpu ()
255+ mask = self .output_handle [3 ].copy_to_cpu ().astype ('bool' )
255256
256- mask = np .logical_and (
257- np .logical_and (words != self .word_pad_index ,
258- words != self .word_bos_index ),
259- words != self .word_eos_index , )
260- arc_preds , rel_preds = decode (s_arc , s_rel , mask , self .tree )
257+ arc_preds , rel_preds = decode (arc_preds , rel_preds , s_arc , mask , self .tree )
261258
262259 arcs .extend ([arc_pred [m ] for arc_pred , m in zip (arc_preds , mask )])
263260 rels .extend ([rel_pred [m ] for rel_pred , m in zip (rel_preds , mask )])
@@ -458,16 +455,13 @@ def probability(s_arc, arc_preds):
458455 return arc_probs
459456
460457
461- def decode (s_arc , s_rel , mask , tree = True ):
462-
463- lens = np .sum (mask .astype (int ), axis = - 1 )
464- arc_preds = np .argmax (s_arc , axis = - 1 )
458+ def decode (arc_preds , rel_preds , s_arc , mask , tree ):
459+ """decode"""
460+ lens = np .sum (mask , - 1 )
465461
466462 bad = [not istree (seq [:i + 1 ]) for i , seq in zip (lens , arc_preds )]
467463 if tree and any (bad ):
468464 arc_preds [bad ] = eisner (s_arc [bad ], mask [bad ])
469-
470- rel_preds = np .argmax (s_rel , axis = - 1 )
471465 rel_preds = [
472466 rel_pred [np .arange (len (arc_pred )), arc_pred ]
473467 for arc_pred , rel_pred in zip (arc_preds , rel_preds )
@@ -704,4 +698,4 @@ def inorder_traversal(self, node):
704698
705699def istree (sequence ):
706700 """Is the sequence a project tree"""
707- return DepTree (sequence ).judge_legal ()
701+ return DepTree (sequence ).judge_legal ()
0 commit comments