2 changes: 1 addition & 1 deletion avg_checkpoints.py
@@ -47,7 +47,7 @@ def checkpoint_metric(checkpoint_path):
if not checkpoint_path or not os.path.isfile(checkpoint_path):
return {}
print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path, map_location='cpu')
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
metric = None
if 'metric' in checkpoint:
metric = checkpoint['metric']
224 changes: 215 additions & 9 deletions timm/models/vision_transformer.py
@@ -303,8 +303,9 @@ def __init__(
mlp_layer: Optional[Type[nn.Module]] = None, # not used
attn_layer: Optional[LayerType] = None, # not used
depth: int = 0, # not used
device = None,
dtype = None,
fuse_out_proj: bool = False,  # fuse attention & MLP output projections into a single linear
device=None,
dtype=None,
) -> None:
super().__init__()
dd = {'device': device, 'dtype': dtype}
@@ -330,11 +331,20 @@ def __init__(
self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)

self.mlp_drop = nn.Dropout(proj_drop)
self.mlp_act = act_layer()
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)

if fuse_out_proj:
# Fused output projection for both attention and MLP
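# A single linear over cat([attn_out, mlp_out]) computes the same function as summing two
# separate projections (the fused weight is the two weight matrices side by side), but as one matmul.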
self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)
self.attn_out_proj = None
self.mlp_out_proj = None
else:
# Separate output projections
self.out_proj = None
self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)

self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -371,16 +381,184 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) ->
x_attn = attn @ v

x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
x_attn = self.attn_out_proj(x_attn)

# MLP activation, dropout, fc2
# MLP activation & dropout
x_mlp = self.mlp_act(x_mlp)
x_mlp = self.mlp_drop(x_mlp)
x_mlp = self.mlp_out_proj(x_mlp)

# Output projection (fused or separate)
if self.out_proj is not None:
y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))
else:
y = self.attn_out_proj(x_attn) + self.mlp_out_proj(x_mlp)

# Add residual w/ drop path & layer scale applied
y = self.drop_path(self.ls(x_attn + x_mlp))
x = x + y
x = x + self.drop_path(self.ls(y))
return x


class DiffParallelScalingBlock(nn.Module):
""" Parallel ViT block with Differential Attention (MLP & Attention in parallel).

Combines the parallel MLP+Attention structure from 'Scaling Vision Transformers to
22 Billion Parameters' (https://arxiv.org/abs/2302.05442) with differential attention
from 'Differential Transformer' (https://arxiv.org/abs/2410.05258).
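
Differential attention per head (a sketch of the paper's formulation, matching the code below):
    DiffAttn(X) = (softmax(Q1 @ K1.T * scale) - lambda * softmax(Q2 @ K2.T * scale)) @ V
    lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init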
"""
fused_attn: Final[bool]

def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.,
qkv_bias: bool = False,
qk_norm: bool = False,
scale_attn_norm: bool = False,
scale_mlp_norm: bool = False,
proj_bias: bool = True,
proj_drop: float = 0.,
attn_drop: float = 0.,
init_values: Optional[float] = None,
drop_path: float = 0.,
act_layer: Type[nn.Module] = nn.GELU,
norm_layer: Type[nn.Module] = LayerNorm,
mlp_layer: Optional[Type[nn.Module]] = None,
attn_layer: Optional[LayerType] = None,
depth: int = 0,
dual_lambda: bool = False,
device=None,
dtype=None,
) -> None:
super().__init__()
dd = {'device': device, 'dtype': dtype}
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
self.num_heads = num_heads
self.head_dim = dim // num_heads // 2 # Half head_dim for diff attention
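# q/k are later viewed as 2 * num_heads heads of head_dim, v as num_heads heads of
# 2 * head_dim, so each projection still spans dim channels in total.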
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()
mlp_hidden_dim = int(mlp_ratio * dim)
in_proj_out_dim = mlp_hidden_dim + 3 * dim

self.in_norm = norm_layer(dim, **dd)
self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd)
self.in_split = [mlp_hidden_dim] + [dim] * 3
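# A single matmul yields [mlp_hidden | q | k | v]; in_split recovers the pieces along the channel dim.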
if qkv_bias:
self.register_buffer('qkv_bias', None)
self.register_parameter('mlp_bias', None)
else:
self.register_buffer('qkv_bias', torch.zeros(3 * dim, **dd), persistent=False)
self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim, **dd))
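# When qkv_bias is set the fused in_proj already carries its bias; otherwise only the MLP
# branch gets a learnable bias (a constant zero bias is concatenated for q/k/v in forward()).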

self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.attn_drop_p = attn_drop

# Differential attention specific
self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd)
self.dual_lambda = dual_lambda
if dual_lambda:
self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device))
self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None
else:
self.lambda_a = self.lambda_b = None
self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))
self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device))

self.mlp_drop = nn.Dropout(proj_drop)
self.mlp_act = act_layer()

# Fused output projection for both attention and MLP
self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd)

self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

self.lambda_init = 0.8
self.set_lambda_init(depth)
self.reset_parameters()

def set_lambda_init(self, depth: int):
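# Depth-dependent schedule from the Differential Transformer paper; lambda_init grows toward 0.8 with depth.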
self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)

def reset_parameters(self):
if self.dual_lambda:
nn.init.zeros_(self.lambda_a)
nn.init.zeros_(self.lambda_b)
else:
nn.init.normal_(self.lambda_q1, mean=0, std=0.1)
nn.init.normal_(self.lambda_k1, mean=0, std=0.1)
nn.init.normal_(self.lambda_q2, mean=0, std=0.1)
nn.init.normal_(self.lambda_k2, mean=0, std=0.1)

def _compute_lambda(self) -> torch.Tensor:
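# lambda = exp(.) - exp(.) + lambda_init, computed either from two learned scalars (dual_lambda)
# or from dot products of learned per-head vectors, as re-parameterized in the paper.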
if self.lambda_a is not None:
lambda_1 = torch.exp(self.lambda_a)
lambda_2 = torch.exp(self.lambda_b)
else:
lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float())
lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float())
return lambda_1 - lambda_2 + self.lambda_init

def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
B, N, C = x.shape

# Combined MLP fc1 & qkv projections
y = self.in_norm(x)
if self.mlp_bias is not None:
y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
else:
y = self.in_proj(y)
x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)

# Reshape for differential attention (2x heads with half head_dim for q/k)
q = q.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
k = k.reshape(B, N, 2 * self.num_heads, self.head_dim).transpose(1, 2)
v = v.reshape(B, N, self.num_heads, 2 * self.head_dim).transpose(1, 2)
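# Resulting shapes: q, k -> (B, 2 * num_heads, N, head_dim); v -> (B, num_heads, N, 2 * head_dim)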

q, k = self.q_norm(q), self.k_norm(k)

lambda_full = self._compute_lambda().type_as(q)

if self.fused_attn:
q = q.reshape(B, self.num_heads, 2, N, self.head_dim)
k = k.reshape(B, self.num_heads, 2, N, self.head_dim)
q1, q2 = q.unbind(2)
k1, k2 = k.unbind(2)

dropout_p = self.attn_drop_p if self.training else 0.0
attn1 = F.scaled_dot_product_attention(q1, k1, v, attn_mask=attn_mask, dropout_p=dropout_p)
attn2 = F.scaled_dot_product_attention(q2, k2, v, attn_mask=attn_mask, dropout_p=dropout_p)
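# F.scaled_dot_product_attention permits the value head dim (2 * head_dim here) to differ
# from the q/k head dim, so the shared v needs no extra reshape for the two attention maps.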

x_attn = attn1 - lambda_full * attn2
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = maybe_add_mask(attn, attn_mask)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

attn = attn.view(B, self.num_heads, 2, N, N)
attn = attn[:, :, 0] - lambda_full * attn[:, :, 1]
x_attn = attn @ v

# Per-head RMS norm over the 2 * head_dim output, rescaled by (1 - lambda_init),
# following the normalization used in the Differential Transformer paper
x_attn = self.sub_norm(x_attn)
x_attn = x_attn * (1 - self.lambda_init)
x_attn = x_attn.transpose(1, 2).reshape(B, N, C)

# MLP activation & dropout
x_mlp = self.mlp_act(x_mlp)
x_mlp = self.mlp_drop(x_mlp)

# Fused output projection
y = self.out_proj(torch.cat((x_attn, x_mlp), dim=-1))

# Add residual w/ drop path & layer scale applied
x = x + self.drop_path(self.ls(y))
return x


@@ -2525,9 +2703,15 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_dwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
@@ -4027,6 +4211,17 @@ def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTr
return model


@register_model
def vit_dwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', attn_layer='diff',
)
model = _create_vision_transformer(
'vit_dwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
@@ -4038,6 +4233,17 @@ def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionT
return model


@register_model
def vit_dpwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=DiffParallelScalingBlock,
)
model = _create_vision_transformer(
'vit_dpwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(