-
Notifications
You must be signed in to change notification settings - Fork 2
Description
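Hi, thanks for your work. I previously ran into an issue where I couldn't use PyG's GNNExplainer directly to explain an RGCN model: it threw the same error as the one you encountered in gnnexplainer.ipynb (traceback below). The assertion appears to fail because RGCNConv calls `propagate()` once per relation type with only that relation's edges, so the number of messages no longer matches the size of the explainer's edge mask, which covers all edges. Have you found a solution to this problem when trying to explain RGCN?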
```text
AssertionError Traceback (most recent call last)
Input In [22], in <cell line: 22>()
10 # model_args = (
11 # x_cell_mut,
12 # batch_drug,
13 # edge_features
14 # )
16 kwargs = {
17 "x_cell_mut": x_cell_mut,
18 "batch_drug": batch_drug,
19 "edge_feat": edge_features
20 }
---> 22 node_feature_mask, edge_mask = explainer.explain_graph(x = x, edge_index = edge_index, x_cell_mut = x_cell_mut, edge_feat = edge_features)
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch_geometric/nn/models/gnn_explainer.py:165, in GNNExplainer.explain_graph(self, x, edge_index, **kwargs)
163 print('debug h', h.size())
164 print('debug edge', edge_index.size())
--> 165 out = self.model(x=h, edge_index=edge_index, batch=batch, **kwargs)
166 loss = self.get_loss(out, prediction, None)
167 loss.backward()
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /data/conghao001/FYP/GCN_Drug/models.py:749, in RGCNNet.forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight)
741 def forward(self, x, edge_index, batch, x_cell_mut, edge_feat, edge_weight=None):
742 # get graph input
743 # edge_weight is only used for decoding
744
745 # x, edge_index, batch = data.x, data.edge_index, data.batch
746 # edge_index = edge_index.long()
747 edge_feat = edge_feat.squeeze()
--> 749 x = self.conv1(x, edge_index, edge_type=edge_feat)
750 x = self.relu(x)
751 x = self.conv2(x, edge_index, edge_type=edge_feat)
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch_geometric/nn/conv/rgcn_conv.py:218, in RGCNConv.forward(self, x, edge_index, edge_type)
216 out += self.propagate(tmp, x=weight[i, x_l], size=size)
217 else:
--> 218 h = self.propagate(tmp, x=x_l, size=size)
219 out = out + (h @ weight[i])
221 root = self.root
File /data/conghao001/anaconda3/envs/gnndrug/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py:338, in MessagePassing.propagate(self, edge_index, size, **kwargs)
336 edge_mask = torch.cat([edge_mask, loop], dim=0)
337 print(out.size(self.node_dim), edge_mask.size(0))
--> 338 assert out.size(self.node_dim) == edge_mask.size(0)
339 out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))
341 aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
AssertionError:
```