```python
x_res = x
cls_token = x[:, 0:1]
cls_token = self.gamma2 * self.mlp(cls_token)
x = torch.cat([cls_token, x[:, 1:]], dim=1)
x = x_res + self.drop_path(x)
```
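Hi! Thank you for your great work! In the lines quoted above, `x_res` is added to `self.drop_path(x)`, but `x` still contains the unchanged patch tokens `x[:, 1:]`. The patch tokens are therefore added to themselves, and they are doubled whenever `drop_path` is the identity (e.g. at inference). In CaiT's class-attention blocks, only the class token receives the MLP residual update. Following CaiT, I think the code should be in the following form:

```python
cls_token = x[:, 0:1] + self.drop_path(self.gamma2 * self.mlp(x[:, 0:1]))
x = torch.cat([cls_token, x[:, 1:]], dim=1)
```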
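A minimal sketch (not FAN's actual module) comparing the two forms on a random tensor. `mlp`, `gamma2`, and `drop_path` are hypothetical stand-ins for the block's submodules, with `drop_path` taken as the identity, which is how timm's `DropPath` behaves at inference or with a drop rate of 0. It checks that the current lines double the patch tokens, while the proposed form leaves them unchanged and yields the same class token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins (not FAN's real classes): a small MLP, a LayerScale
# vector gamma2, and drop_path as the identity (DropPath at inference).
dim = 8
mlp = nn.Sequential(nn.Linear(dim, 16), nn.GELU(), nn.Linear(16, dim))
gamma2 = torch.full((dim,), 1e-5)
drop_path = nn.Identity()


def current_update(x):
    # Mirrors the quoted lines: the residual is added to the whole sequence,
    # so the untouched patch tokens x[:, 1:] are added to themselves.
    x_res = x
    cls_token = x[:, 0:1]
    cls_token = gamma2 * mlp(cls_token)
    x = torch.cat([cls_token, x[:, 1:]], dim=1)
    return x_res + drop_path(x)


def proposed_update(x):
    # CaiT-style: only the class token receives the MLP residual update.
    cls_token = x[:, 0:1] + drop_path(gamma2 * mlp(x[:, 0:1]))
    return torch.cat([cls_token, x[:, 1:]], dim=1)


x = torch.randn(2, 5, dim)  # (batch, 1 class token + 4 patch tokens, dim)
with torch.no_grad():
    out_current = current_update(x)
    out_proposed = proposed_update(x)

print(torch.allclose(out_current[:, 1:], 2 * x[:, 1:]))         # True: patch tokens doubled
print(torch.allclose(out_proposed[:, 1:], x[:, 1:]))            # True: patch tokens unchanged
print(torch.allclose(out_current[:, 0:1], out_proposed[:, 0:1]))  # True: same class token
```

During training with a nonzero drop rate, the patch tokens in the current form are instead added to a randomly dropped or rescaled copy of themselves, so they still change even though the MLP never touches them.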
Referenced code: `FAN/models/fan.py`, lines 311 to 315 at commit `ee1b7df`.