diff --git a/visualize/baselines/ViT/LVViT_LRP.py b/visualize/baselines/ViT/LVViT_LRP.py
index 5bce215..bdcc2b8 100644
--- a/visualize/baselines/ViT/LVViT_LRP.py
+++ b/visualize/baselines/ViT/LVViT_LRP.py
@@ -42,7 +42,7 @@ def compute_rollout_attention(all_layer_matrices, start_layer=0):
     num_tokens = all_layer_matrices[0].shape[1]
     batch_size = all_layer_matrices[0].shape[0]
     eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
-    all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
+    all_layer_matrices = [matrix + eye for matrix in all_layer_matrices]
     # all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
     #                       for i in range(len(all_layer_matrices))]
     joint_attention = all_layer_matrices[start_layer]
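
The simplification must stay in torch rather than detour through numpy: `np.add(all_layer_matrices, eye).tolist()` raises on CUDA tensors and hands back nested Python lists, which the batched matrix multiplications later in the function cannot consume. Below is a minimal self-contained sketch of the edited function for reference; the chaining loop after `joint_attention = all_layer_matrices[start_layer]` is not visible in this hunk, so that tail is an assumption based on the standard attention-rollout formulation.

```python
import torch

def compute_rollout_attention(all_layer_matrices, start_layer=0):
    num_tokens = all_layer_matrices[0].shape[1]
    batch_size = all_layer_matrices[0].shape[0]
    # Identity added to each attention map to account for residual connections.
    eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens)
    eye = eye.to(all_layer_matrices[0].device)
    # Staying in torch preserves device placement and autograd; converting
    # through numpy would fail on CUDA tensors, and .tolist() would hand the
    # bmm() calls below plain Python lists.
    all_layer_matrices = [matrix + eye for matrix in all_layer_matrices]
    joint_attention = all_layer_matrices[start_layer]
    # ASSUMPTION: this loop is not shown in the hunk; it follows the usual
    # attention-rollout formulation of multiplying successive layers together.
    for i in range(start_layer + 1, len(all_layer_matrices)):
        joint_attention = all_layer_matrices[i].bmm(joint_attention)
    return joint_attention

# Smoke test with random stand-in attention maps (batch=2, tokens=5, 4 layers).
layers = [torch.rand(2, 5, 5) for _ in range(4)]
print(compute_rollout_attention(layers).shape)  # torch.Size([2, 5, 5])
```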