diff --git a/model_dd.py b/model_dd.py
index 41a632b..722b236 100644
--- a/model_dd.py
+++ b/model_dd.py
@@ -59,17 +59,10 @@ def __init__(self, cpt_ids, Configs, cuda_=True):
         if self.use_future:
             self.relAtt = RelAtt(1, 1, (self.window, self.input_dim), heads=self.num_head, dim_head=self.input_dim,
                                  dropout=Configs.att_dropout)
-            # self.relAtt = RelAtt(self.window, 1, (1, self.input_dim), heads=self.num_head, dim_head=self.input_dim,
-            #                      dropout=Configs.att_dropout)
+
         else:
             self.relAtt = RelAtt(1, 1, (self.slide_win+1, self.input_dim), heads=self.num_head, dim_head=self.input_dim,
                                  dropout=Configs.att_dropout)
-            # self.relAtt = RelAtt(self.slide_win+1, 1, (1, self.input_dim), heads=self.num_head, dim_head=self.input_dim,
-            #                      dropout=Configs.att_dropout)
-
-
-        # self.relAtt = Trans_RelAtt(1, 1, (self.window, self.input_dim), heads=self.num_head, dim_head=self.input_dim // 2,
-        #                            dropout=Configs.att_dropout)
 
         self.r = nn.Parameter(nn.init.uniform_(torch.zeros(3, self.input_dim)), requires_grad=True)
         self.num_feature = Configs.num_features
@@ -121,9 +114,7 @@ def __init__(self, cpt_ids, Configs, cuda_=True):
                                                self.input_dim))  # nn.ParameterList([nn.Parameter(torch.randn(self.input_dim, self.input_dim)) for _ in range(3)])
 
     def forward(self, inputs, str_src, str_dst, str_edge_type, chunks, label, loss_func, train=True, eps=1e-8):
-        # torch.autograd.set_detect_anomaly(True)
-
-        # len_dial = len(inputs['input_ids'])
+
         if self.model_type == 'albert':
             out = self.bert_encoder(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                     token_type_ids=inputs['token_type_ids'])
@@ -196,8 +187,6 @@ def forward(self, inputs, str_src, str_dst, str_edge_type, chunks, label, loss_f
             src_mask = torch.sum(masks, dim=-1) > 0
 
             att_score = torch.softmax(self.get_att_masked(dot_sum, src_mask), dim=-1) * src_masks.ne(0)
-            # att_score = torch.softmax(dot_sum, dim=-1) * src_mask
-            # sent_mask_sum = torch.sum(src_masks.sum(dim=-1).ne(0)) + eps
 
             symbolic_repr = torch.sum(att_score.unsqueeze(2) * src_emb, dim=1)  # /sent_mask_sum
 
@@ -223,52 +212,13 @@ def forward(self, inputs, str_src, str_dst, str_edge_type, chunks, label, loss_f
 
             # feature fusion
             if self.num_feature == 3:
-                # feat = torch.cat((out_[utt_idx], hidden_rgcn[utt_idx],
-                #                   relatt_out[utt_idx], symbolic_repr), dim=-1)
 
                 feat_ = torch.stack([out_[utt_idx], hidden_rgcn[utt_idx], symbolic_repr], dim=1).unsqueeze(2)
                 feat = self.CoAtt(feat_).squeeze(1).squeeze(1)
 
                 output = torch.log_softmax(self.linear(self.ac_tanh(self.dropout(self.linear_out(feat)))), dim=1)
+
+            else:
-                # if self.use_layer_norm:
-                #     output = torch.log_softmax(self.linear(self.ac_tanh(self.layer_norm(self.dropout(self.fusion(feat))))), dim=1)
-                # else:
-                #     output = torch.log_softmax(self.linear(self.ac_tanh(self.dropout(self.fusion(feat)))), dim=1)
-
-            elif self.num_feature == 4:
-                feat_l = torch.cat((out_[utt_idx], hidden_rgcn[utt_idx], symbolic_repr), dim=-1)
-
-                feat_ = torch.stack([out_[utt_idx], hidden_rgcn[utt_idx], symbolic_repr], dim=1).unsqueeze(2)
-                feat = self.CoAtt(feat_).squeeze(1).squeeze(1)
-
-                feat_x = torch.cat((feat_l, feat), dim=-1)
-
-                output = torch.log_softmax(self.linear(self.ac_tanh(self.dropout(self.linear_out(feat_x)))), dim=1)
-
-                # if self.use_layer_norm:
-                #     output = torch.log_softmax(self.linear(self.ac_tanh(self.layer_norm(self.dropout(self.fusion(feat))))), dim=1)
-                # else:
-                #     output = torch.log_softmax(self.linear(self.ac_tanh(self.dropout(self.fusion(feat)))), dim=1)
-            elif self.num_feature == 2:
-                # feat = torch.cat((out_[utt_idx], hidden_rgcn[utt_idx]), dim=-1)
-
-                feat_ = torch.stack([out_[utt_idx], hidden_rgcn[utt_idx]], dim=1).unsqueeze(2)
-                feat = self.CoAtt(feat_).squeeze(1).squeeze(1)
-
-                output = torch.log_softmax(self.linear(self.ac(self.dropout(self.linear_out(feat)))), dim=1)
-
-                # if self.use_layer_norm:
-                #     output = torch.log_softmax(self.linear(self.ac_tanh(self.layer_norm(self.dropout(self.fusion(feat))))), dim=1)
-                # else:
-                #     output = torch.log_softmax(self.linear(self.ac_tanh(self.dropout(self.fusion(feat)))), dim=1)
-            elif self.num_feature == 1:
-                feat = out_[utt_idx]
-                if self.use_layer_norm:
-                    output = torch.log_softmax(self.linear(self.ac_tanh(self.layer_norm(self.dropout(self.fusion(feat))))), dim=1)
-                else:
-                    output = torch.log_softmax(self.linear(self.ac_tanh(self.dropout(self.fusion(feat)))), dim=1)
-            else:
-                # feat = out_[utt_idx] + hidden_rgcn[utt_idx] + relatt_out[utt_idx] + symbolic_repr
                 feat = out_[utt_idx] + hidden_rgcn[utt_idx] + symbolic_repr
                 if self.use_layer_norm:
                     output = torch.log_softmax(self.linear_2(self.ac_tanh(self.layer_norm(self.dropout(self.fusion_2(feat))))), dim=1)
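
Note for reviewers: the patch removes the num_feature == 4/2/1 branches and the commented-out alternatives, leaving only the co-attention fusion (num_feature == 3) and the additive fallback. The sketch below illustrates those two surviving paths in isolation. It is a minimal, self-contained approximation: the attribute names (CoAtt, linear_out, ac_tanh, linear, fusion_2, linear_2, layer_norm, dropout) mirror what the diff references in model_dd.py, but the dimensions, the hypothetical FusionSketch class, and the stand-in co-attention (a plain nn.MultiheadAttention) are assumptions for illustration only, not the repository's actual CoAtt implementation.

# Minimal sketch of the two fusion paths kept by this patch (assumed shapes and a
# stand-in co-attention; not the project's real CoAtt module).
import torch
import torch.nn as nn


class FusionSketch(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=300, num_classes=7, num_feature=3):
        super().__init__()
        self.num_feature = num_feature
        self.use_layer_norm = True
        # stand-in for CoAtt: attends over the stacked feature views
        self.co_att = nn.MultiheadAttention(input_dim, num_heads=4, batch_first=True)
        self.linear_out = nn.Linear(input_dim, hidden_dim)
        self.linear = nn.Linear(hidden_dim, num_classes)
        self.fusion_2 = nn.Linear(input_dim, hidden_dim)
        self.linear_2 = nn.Linear(hidden_dim, num_classes)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)
        self.ac_tanh = nn.Tanh()

    def forward(self, text_feat, graph_feat, symbolic_feat):
        # each input: (batch, input_dim) -- encoder, RGCN and symbolic views of an utterance
        if self.num_feature == 3:
            # stack the three views and let the co-attention weigh them, as in the kept branch
            stacked = torch.stack([text_feat, graph_feat, symbolic_feat], dim=1)  # (B, 3, D)
            att_out, _ = self.co_att(stacked, stacked, stacked)
            feat = att_out.mean(dim=1)  # collapse the view axis; CoAtt's own reduction may differ
            return torch.log_softmax(self.linear(self.ac_tanh(self.dropout(self.linear_out(feat)))), dim=1)
        # fallback branch kept by the patch: simple additive fusion of the three views
        feat = text_feat + graph_feat + symbolic_feat
        h = self.dropout(self.fusion_2(feat))
        if self.use_layer_norm:
            h = self.layer_norm(h)
        return torch.log_softmax(self.linear_2(self.ac_tanh(h)), dim=1)


if __name__ == "__main__":
    m = FusionSketch()
    x = torch.randn(4, 768)
    print(m(x, x, x).shape)  # torch.Size([4, 7]) -- per-utterance log-probabilities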