From 7a4619eb22c23449e05ddc89747170468baf4e2c Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 7 Dec 2025 13:44:38 -0800
Subject: [PATCH 1/3] Trying out differential + parallel block, and fused
 output projection option for ParallelScalingBlock to see if it's worthwhile
 now...

---
 timm/models/vision_transformer.py | 209 ++++++++++++++++++++++++++++--
 1 file changed, 200 insertions(+), 9 deletions(-)
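
Note on the fused out-proj option: concatenating the attention and MLP
features and running one Linear is arithmetically the same as summing two
separate projections, so fuse_out_proj only trades two smaller GEMMs for one
wider one. A standalone sketch of the identity this relies on (tensor and
weight names here are illustrative, not the module's attributes; the two
separate biases fold into the single fused bias):

    import torch

    dim, mlp_hidden_dim = 8, 40
    x_attn = torch.randn(2, 16, dim)
    x_mlp = torch.randn(2, 16, mlp_hidden_dim)
    w_attn = torch.randn(dim, dim)            # plays the role of attn_out_proj.weight
    w_mlp = torch.randn(dim, mlp_hidden_dim)  # plays the role of mlp_out_proj.weight
    bias = torch.randn(dim)                   # sum of the two separate biases

    # Two GEMMs summed into the residual branch
    y_sep = x_attn @ w_attn.T + x_mlp @ w_mlp.T + bias
    # One fused GEMM over the concatenated features
    w_fused = torch.cat([w_attn, w_mlp], dim=1)  # (dim, dim + mlp_hidden_dim)
    y_fused = torch.cat([x_attn, x_mlp], dim=-1) @ w_fused.T + bias

    torch.testing.assert_close(y_sep, y_fused, rtol=1e-4, atol=1e-4)
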
+ """ + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, + proj_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = LayerNorm, + mlp_layer: Optional[Type[nn.Module]] = None, + attn_layer: Optional[LayerType] = None, + depth: int = 0, + dual_lambda: bool = False, + device=None, + dtype=None, + ) -> None: + super().__init__() + dd = {'device': device, 'dtype': dtype} + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported' + self.num_heads = num_heads + self.head_dim = dim // num_heads // 2 # Half head_dim for diff attention + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + mlp_hidden_dim = int(mlp_ratio * dim) + in_proj_out_dim = mlp_hidden_dim + 3 * dim + + self.in_norm = norm_layer(dim, **dd) + self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd) + self.in_split = [mlp_hidden_dim] + [dim] * 3 + if qkv_bias: + self.register_buffer('qkv_bias', None) + self.register_parameter('mlp_bias', None) + else: + self.register_buffer('qkv_bias', torch.zeros(3 * dim, **dd), persistent=False) + self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim, **dd)) + + self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.attn_drop_p = attn_drop + + # Differential attention specific + self.sub_norm = RmsNorm(2 * self.head_dim, eps=1e-5, **dd) + self.dual_lambda = dual_lambda + if dual_lambda: + self.lambda_a = nn.Parameter(torch.empty((), dtype=torch.float32, device=device)) + self.lambda_b = nn.Parameter(torch.empty((), dtype=torch.float32, device=device)) + self.lambda_q1 = self.lambda_k1 = self.lambda_q2 = self.lambda_k2 = None + else: + self.lambda_a = self.lambda_b = None + self.lambda_q1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) + self.lambda_k1 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) + self.lambda_q2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) + self.lambda_k2 = nn.Parameter(torch.empty(self.head_dim, dtype=torch.float32, device=device)) + + self.mlp_drop = nn.Dropout(proj_drop) + self.mlp_act = act_layer() + + # Fused output projection for both attention and MLP + self.out_proj = nn.Linear(dim + mlp_hidden_dim, dim, bias=proj_bias, **dd) + + self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. 
From ab2f4b713b77b7f7e9f89be7c4e84b0d392870f2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 8 Dec 2025 09:33:07 -0800
Subject: [PATCH 2/3] Add dwee model def, and weights for dwee and dpwee

---
 timm/models/vision_transformer.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 5f5c5a6535..7243082d13 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -2703,11 +2703,14 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_dwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        #hf_hub_id='timm/',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
         hf_hub_id='timm/',
@@ -4208,6 +4211,17 @@ def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_dwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', attn_layer='diff',
+    )
+    model = _create_vision_transformer(
+        'vit_dwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
@@ -4229,6 +4243,7 @@ def vit_dpwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
+
 @register_model
 def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
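
With the hub ids flipped on above, the new defs should load like any other
timm model (assuming the sbb_in1k weights are live on the HF hub and a timm
build that includes these patches):

    import timm
    import torch

    model = timm.create_model('vit_dwee_patch16_reg1_gap_256.sbb_in1k', pretrained=True).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))  # configs above use 256x256 inputs
    print(out.shape)  # torch.Size([1, 1000]) for the in1k head
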
From a5f40c9055b39915711d751c998b1817bc2fa9f1 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 8 Dec 2025 09:33:28 -0800
Subject: [PATCH 3/3] Add weights_only=False to avg_checkpoints, needs to
 load training state dict

---
 avg_checkpoints.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/avg_checkpoints.py b/avg_checkpoints.py
index 6cedcb7da5..821e25b5c2 100755
--- a/avg_checkpoints.py
+++ b/avg_checkpoints.py
@@ -47,7 +47,7 @@ def checkpoint_metric(checkpoint_path):
     if not checkpoint_path or not os.path.isfile(checkpoint_path):
         return {}
     print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
-    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
     metric = None
     if 'metric' in checkpoint:
         metric = checkpoint['metric']
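
On the avg_checkpoints change: training checkpoints bundle plain Python
objects (epoch, metric, optimizer state) alongside the tensors, and the
weights_only=True default in newer PyTorch (2.6+) refuses to unpickle those,
hence the explicit opt-out. A sketch of the distinction (path and key names
illustrative; only do this with checkpoints from a trusted source):

    import torch

    # Full training state: needs weights_only=False on PyTorch 2.6+
    ckpt = torch.load('checkpoint.pth.tar', map_location='cpu', weights_only=False)
    print(sorted(ckpt.keys()))  # e.g. ['arch', 'epoch', 'metric', 'optimizer', 'state_dict', ...]

    # Bare model weights remain loadable under the safe default
    state_dict = ckpt['state_dict']
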