diff --git a/src/transformers/models/blt/configuration_blt.py b/src/transformers/models/blt/configuration_blt.py
index 326176af5e9a..e439d849a3d8 100644
--- a/src/transformers/models/blt/configuration_blt.py
+++ b/src/transformers/models/blt/configuration_blt.py
@@ -335,6 +335,7 @@ def __init__(
         tie_word_embeddings: Optional[bool] = False,
         initializer_range: Optional[float] = 0.02,
         rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
+        use_cache: Optional[bool] = False,
         **kwargs,
     ):
         # Basic model configuration
@@ -406,16 +407,37 @@ def __init__(
         )
         self.rope_parameters = rope_parameters
+        if "use_cache" not in kwargs:
+            kwargs["use_cache"] = use_cache
+        self.use_cache = kwargs["use_cache"]

         # Remove tie_word_embeddings from kwargs to avoid duplicate parameter error
         kwargs.pop("tie_word_embeddings", None)

         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)


+class BltTextConfig(PreTrainedConfig):
+    """
+    Configuration class for the Blt Text component.
+    """
+
+    pass
+
+
+class BltVisionConfig(PreTrainedConfig):
+    """
+    Configuration class for the Blt Vision component.
+    """
+
+    pass
+
+
 __all__ = [
     "BltConfig",
     "BltPatcherConfig",
     "BltLocalEncoderConfig",
     "BltLocalDecoderConfig",
     "BltGlobalTransformerConfig",
+    "BltTextConfig",
+    "BltVisionConfig",
 ]
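As a quick illustration of the new `use_cache` handling (a sketch, not part of the patch; it assumes `BltConfig` is exported from the top-level `transformers` namespace, as its `__all__` entry suggests):

```python
from transformers import BltConfig

# With this change, `use_cache` defaults to False and is stored on the config,
# while still flowing through `kwargs` into the base `__init__`.
config = BltConfig()
print(config.use_cache)  # expected: False, per the new default in this patch

# An explicit value, passed either as the new keyword or via **kwargs, takes precedence.
config = BltConfig(use_cache=True)
print(config.use_cache)  # expected: True
```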
diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py
index b640a1b3db09..a1df2196f062 100644
--- a/src/transformers/models/blt/modeling_blt.py
+++ b/src/transformers/models/blt/modeling_blt.py
@@ -19,6 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import math
 from collections.abc import Callable
 from typing import Optional, Union

@@ -27,6 +28,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -45,6 +47,8 @@
     BltLocalDecoderConfig,
     BltLocalEncoderConfig,
     BltPatcherConfig,
+    BltTextConfig,
+    BltVisionConfig,
 )


@@ -427,6 +431,356 @@ def forward(
         return attn_output, attn_weights


+class BltPrecomputedAspectRatioEmbedding(nn.Module):
+    def __init__(self, config: BltVisionConfig, is_gated: bool = True):
+        super().__init__()
+        self.max_num_tiles = config.max_num_tiles
+        self.hidden_size = config.hidden_size
+        self.max_aspect_ratio_id = config.max_aspect_ratio_id
+        self.is_gated = is_gated
+
+        self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size)
+        if is_gated:
+            self.gate = nn.Parameter(torch.zeros(1))
+
+    def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+        embeddings = self.embedding(aspect_ratio_ids)
+        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
+
+        if self.is_gated:
+            embeddings = embeddings * self.gate.tanh()
+
+        hidden_state = hidden_state + embeddings
+        return hidden_state
+
+
+class BltPrecomputedPositionEmbedding(nn.Module):
+    def __init__(self, config: BltVisionConfig):
+        super().__init__()
+        self.max_num_tiles = config.max_num_tiles
+        self.max_aspect_ratio_id = config.max_aspect_ratio_id
+        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
+        self.hidden_size = config.hidden_size
+        self.scale = config.hidden_size**-0.5
+
+        self.gate = nn.Parameter(torch.zeros(1))
+
+        # position embedding
+        position_embedding = torch.randn(self.num_patches, self.hidden_size)
+        self.embedding = nn.Parameter(self.scale * position_embedding)
+
+        # tile position embedding
+        self.tile_embedding = nn.Embedding(
+            self.max_aspect_ratio_id + 1, self.max_num_tiles * self.num_patches * self.hidden_size
+        )
+
+    def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+        # position embeddings
+        gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
+        hidden_state = hidden_state + gated_position_embedding.view(1, 1, self.num_patches, self.hidden_size)
+
+        # precomputed tile position embeddings
+        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+        batch_size = hidden_state.shape[0]
+        tile_position_embedding = tile_position_embedding.reshape(
+            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
+        )
+        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
+        hidden_state = hidden_state + gated_tile_position_embedding
+
+        return hidden_state
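The gating in `BltPrecomputedPositionEmbedding` blends the shared position embedding with the per-aspect-ratio tile embedding through a single learned scalar: with the gate at its zero initialization, only the shared embedding contributes. A reduced-shape sketch of that mixing rule (toy tensors, not the real module, which works on `(batch, num_tiles, num_patches, hidden_size)` tensors):

```python
import torch

gate = torch.tensor(0.0)            # nn.Parameter(torch.zeros(1)) in the module
pos_embed = torch.randn(5, 8)       # shared "global" position embedding
tile_pos_embed = torch.randn(5, 8)  # per-tile embedding from the aspect-ratio lookup table

mixed = (1 - gate.tanh()) * pos_embed + gate.tanh() * tile_pos_embed
# With gate == 0, tanh(gate) == 0, so only the shared position embedding contributes at init.
assert torch.allclose(mixed, pos_embed)
```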
+
+
+class BltVisionMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.activation_fn = ACT2FN[config.hidden_act]
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class BltVisionAttention(nn.Module):
+    def __init__(self, config: BltVisionConfig):
+        super().__init__()
+
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.attention_heads
+        self.head_dim = config.hidden_size // config.attention_heads
+        self.scaling = self.head_dim**-0.5
+        self.num_key_value_groups = 1
+
+        self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=False)
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        query = self.q_proj(hidden_state)
+        key = self.k_proj(hidden_state)
+        value = self.v_proj(hidden_state)
+
+        batch_size, q_seq_len, _ = query.shape
+        _, kv_seq_len, _ = key.shape
+
+        query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        attention_interface: Callable = eager_attention_forward
+
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query,
+            key,
+            value,
+            attention_mask,
+            dropout=0.0,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(batch_size, q_seq_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
+class BltVisionEncoderLayer(nn.Module):
+    def __init__(self, config: BltVisionConfig, is_gated: bool = False):
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.attention_heads
+        self.is_gated = is_gated
+        self.intermediate_size = config.intermediate_size
+
+        self.self_attn = BltVisionAttention(config)
+        self.mlp = BltVisionMLP(config)
+
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
+
+        if is_gated:
+            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
+            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        # Self Attention
+        residual = hidden_state
+        hidden_state = self.input_layernorm(hidden_state)
+        hidden_state, attn_weights = self.self_attn(hidden_state, attention_mask=attention_mask)
+        if self.is_gated:
+            hidden_state = self.gate_attn.tanh() * hidden_state
+        hidden_state = residual + hidden_state
+
+        # Feed forward
+        residual = hidden_state
+        hidden_state = self.post_attention_layernorm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        if self.is_gated:
+            hidden_state = self.gate_ffn.tanh() * hidden_state
+        hidden_state = residual + hidden_state
+
+        return hidden_state
+
+
+class BltTextRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        BltTextRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class BltTextCrossAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        config: Optional[BltTextConfig] = None,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.num_heads = self.config.num_attention_heads
+        self.num_key_value_heads = self.config.num_key_value_heads
+        self.dropout = config.dropout
+        self.hidden_size = config.hidden_size
+        self.head_dim = config.hidden_size // self.num_heads
+        self.layer_idx = layer_idx
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.scaling = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+        self.q_norm = BltTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = BltTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cross_attention_states: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+        bsz, q_len, _ = hidden_states.size()
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = self.q_norm(query_states)
+
+        if cross_attention_states is not None:
+            key_states = self.k_proj(cross_attention_states)
+            value_states = self.v_proj(cross_attention_states)
+            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+            key_states = self.k_norm(key_states)
+            if past_key_values is not None:
+                # if we have a new image + new tokens, we only computed key_states on that new image
+                # we still update the cross key states, past_image, new_image. And use it!
+                key_states, value_states = past_key_values.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )
+        elif cache_position[0] != 0:
+            key_states, value_states = (
+                past_key_values.layers[self.layer_idx].keys,
+                past_key_values.layers[self.layer_idx].values,
+            )
+        else:
+            raise ValueError(
+                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+            )
+
+        attention_interface: Callable = eager_attention_forward
+
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, attn_weights
+
+
+class BltTextMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        # Ignore copy
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+class BltCrossAttentionDecoderLayer(GradientCheckpointingLayer):
+    """Cross-attention transformer block with tanh-gated attention and feedforward."""
+
+    def __init__(self, config: BltTextConfig, layer_idx: int) -> None:
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.cross_attn = BltTextCrossAttention(config, layer_idx=layer_idx)
+
+        self.input_layernorm = BltTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))
+
+        self.mlp = BltTextMLP(config)
+        self.post_attention_layernorm = BltTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cross_attention_states: torch.Tensor,
+        cross_attention_mask: torch.Tensor,
+        attention_mask: torch.Tensor,
+        full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[torch.Tensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, attn_weights = self.cross_attn(
+            hidden_states=hidden_states,
+            attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            past_key_values=past_key_values,
+            cache_position=cache_position,
+            **kwargs,
+        )
+        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        if full_text_row_masked_out_mask is not None:
+            hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states  # type: ignore
+        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+
+        return hidden_states
+
+
 @auto_docstring
 class BltPreTrainedModel(PreTrainedModel):
     config: BltConfig
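Both `BltCrossAttentionDecoderLayer` gates start at zero (and are reset to zero again in `_init_weights` below), so a freshly initialized cross-attention block leaves the residual stream untouched. A toy stand-in for that tanh-gated residual pattern, purely illustrative and independent of the real layer's config:

```python
import torch
import torch.nn as nn


class ToyGatedBlock(nn.Module):
    # Illustrative stand-in for the tanh-gated residual pattern used above.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.gate = nn.Parameter(torch.zeros(1))  # same zero init as cross_attn_attn_gate

    def forward(self, x):
        return x + self.gate.tanh() * self.proj(x)


block = ToyGatedBlock(16)
x = torch.randn(2, 4, 16)
# tanh(0) == 0, so the block is exactly the identity until the gate learns to open.
assert torch.allclose(block(x), x)
```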
@@ -444,6 +798,183 @@ class BltPreTrainedModel(PreTrainedModel):
         "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
     }

+    @torch.no_grad()
+    def _init_weights(self, module):
+        """
+        Initialize BLT weights following the original ByteLatentTransformer:
+
+        - Most weights are drawn from a truncated normal.
+        - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
+        - Norm layers are set to weight = 1, bias = 0.
+        """
+        class_name = module.__class__.__name__
+
+        # Norms: RMSNorm / LayerNorm
+        if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
+            if getattr(module, "weight", None) is not None:
+                nn.init.ones_(module.weight)
+            if getattr(module, "bias", None) is not None:
+                nn.init.zeros_(module.bias)
+            return
+
+        # Embeddings (encoder / patcher / hash embeddings)
+        if isinstance(module, nn.Embedding):
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+            if hidden_size is None:
+                hidden_size = module.embedding_dim
+
+            std = hidden_size**-0.5
+            nn.init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.padding_idx is not None:
+                nn.init.zeros_(module.weight[module.padding_idx])
+            return
+
+        # Self-attention / cross-attention projections
+        if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
+            "MllamaTextSelfAttention",
+            "MllamaTextCrossAttention",
+        ):
+            dim = getattr(self.config, "hidden_size", None)
+            if dim is None and hasattr(module, "hidden_size"):
+                dim = module.hidden_size
+            if dim is None:
+                for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
+                    proj = getattr(module, name, None)
+                    if proj is not None and hasattr(proj, "weight"):
+                        dim = proj.weight.shape[-1]
+                        break
+            if dim is None:
+                return
+
+            std = dim**-0.5
+
+            # Input projections (q, k, v)
+            for proj_name in ("q_proj", "k_proj", "v_proj"):
+                proj = getattr(module, proj_name, None)
+                if proj is not None and hasattr(proj, "weight"):
+                    nn.init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        nn.init.zeros_(proj.bias)
+
+            # Output projection: o_proj or dense
+            o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
+            if o_proj is not None and hasattr(o_proj, "weight"):
+                nn.init.trunc_normal_(
+                    o_proj.weight,
+                    mean=0.0,
+                    std=std,
+                    a=-3 * std,
+                    b=3 * std,
+                )
+                if getattr(o_proj, "bias", None) is not None:
+                    nn.init.zeros_(o_proj.bias)
+            return
+
+        # MLP / FFN blocks
+        if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "decoder_config"):
+                hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+
+            # Input-side std
+            in_std = None
+            if hidden_size is not None:
+                in_std = hidden_size**-0.5
+
+            gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
+            up_proj = getattr(module, "up_proj", None)
+            down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
+
+            # gate / input projections
+            for proj in (gate_proj, up_proj):
+                if proj is not None and hasattr(proj, "weight"):
+                    std = in_std or (proj.weight.shape[1] ** -0.5)
+                    nn.init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        nn.init.zeros_(proj.bias)
+
+            # output / down projections
+            if down_proj is not None and hasattr(down_proj, "weight"):
+                hidden_dim = down_proj.weight.shape[1]
+                out_std = hidden_dim**-0.5
+                nn.init.trunc_normal_(
+                    down_proj.weight,
+                    mean=0.0,
+                    std=out_std,
+                    a=-3 * out_std,
+                    b=3 * out_std,
+                )
+                if getattr(down_proj, "bias", None) is not None:
+                    nn.init.zeros_(down_proj.bias)
+            return
+
+        # Generic Linear layers (projections, lm_head, etc.)
+        if isinstance(module, nn.Linear):
+            fan_in = module.in_features
+            std = fan_in**-0.5
+            nn.init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+            return
+
+        std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
+
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            init.normal_(module.weight, mean=0.0, std=std)
+            if module.bias is not None:
+                init.zeros_(module.bias)
+        elif isinstance(module, nn.Embedding):
+            init.normal_(module.weight, mean=0.0, std=std)
+            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
+            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
+                init.zeros_(module.weight[module.padding_idx])
+        elif isinstance(module, nn.LayerNorm):
+            init.ones_(module.weight)
+            init.zeros_(module.bias)
+        elif isinstance(module, BltTextRMSNorm):
+            init.ones_(module.weight)
+        elif isinstance(module, BltVisionModel):
+            init.normal_(module.class_embedding, std=std)
+        elif isinstance(module, BltPrecomputedPositionEmbedding):
+            init.normal_(module.embedding, std=std)
+            init.zeros_(module.gate)
+        elif isinstance(module, BltVisionEncoderLayer) and module.is_gated:
+            init.normal_(module.gate_attn, std=std)
+            init.normal_(module.gate_ffn, std=std)
+        elif isinstance(module, BltCrossAttentionDecoderLayer):
+            init.zeros_(module.cross_attn_attn_gate)
+            init.zeros_(module.cross_attn_mlp_gate)
+        elif isinstance(module, BltPrecomputedAspectRatioEmbedding):
+            if module.is_gated:
+                init.zeros_(module.gate)
+

 class BltLocalEncoder(BltPreTrainedModel):
     config: BltLocalEncoderConfig
@@ -1322,4 +1853,12 @@ def forward(
     )


-__all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"]
+class BltVisionModel(BltPreTrainedModel):
+    pass
+
+
+class BltTextModel(BltPreTrainedModel):
+    pass
+
+
+__all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM", "BltVisionModel", "BltTextModel"]
diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py
index c8c0812b00c1..6728aaf77eee 100644
--- a/src/transformers/models/blt/modular_blt.py
+++ b/src/transformers/models/blt/modular_blt.py
@@ -360,8 +360,163 @@ class BltPreTrainedModel(MllamaPreTrainedModel):
         "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
     }

+    # Weight initialization is adapted from:
+    # - https://github.com/facebookresearch/blt/blob/main/bytelatent/model/blt.py
+    # - https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/transformers_modeling_backend/model/model.py
+    #
+    # Both implementations use truncated normal initialization with std ~ 1 / sqrt(d_model)
+    # (or 1 / sqrt(hidden_dim) for FFN outputs), and unit initialization for normalization layers.
+    # We follow the same scheme here, but expressed in the Transformers APIs.
+
+    @torch.no_grad()
     def _init_weights(self, module):
-        raise AttributeError("No need to inherit it!")
+        """
+        Initialize BLT weights following the original ByteLatentTransformer:
+
+        - Most weights are drawn from a truncated normal.
+        - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
+        - Norm layers are set to weight = 1, bias = 0.
+        """
+        class_name = module.__class__.__name__
+
+        # Norms: RMSNorm / LayerNorm
+        if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
+            if getattr(module, "weight", None) is not None:
+                nn.init.ones_(module.weight)
+            if getattr(module, "bias", None) is not None:
+                nn.init.zeros_(module.bias)
+            return
+
+        # Embeddings (encoder / patcher / hash embeddings)
+        if isinstance(module, nn.Embedding):
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+            if hidden_size is None:
+                hidden_size = module.embedding_dim
+
+            std = hidden_size**-0.5
+            nn.init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.padding_idx is not None:
+                nn.init.zeros_(module.weight[module.padding_idx])
+            return
+
+        # Self-attention / cross-attention projections
+        if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
+            "MllamaTextSelfAttention",
+            "MllamaTextCrossAttention",
+        ):
+            dim = getattr(self.config, "hidden_size", None)
+            if dim is None and hasattr(module, "hidden_size"):
+                dim = module.hidden_size
+            if dim is None:
+                for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
+                    proj = getattr(module, name, None)
+                    if proj is not None and hasattr(proj, "weight"):
+                        dim = proj.weight.shape[-1]
+                        break
+            if dim is None:
+                return
+
+            std = dim**-0.5
+
+            # Input projections (q, k, v)
+            for proj_name in ("q_proj", "k_proj", "v_proj"):
+                proj = getattr(module, proj_name, None)
+                if proj is not None and hasattr(proj, "weight"):
+                    nn.init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        nn.init.zeros_(proj.bias)
+
+            # Output projection: o_proj or dense
+            o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
+            if o_proj is not None and hasattr(o_proj, "weight"):
+                nn.init.trunc_normal_(
+                    o_proj.weight,
+                    mean=0.0,
+                    std=std,
+                    a=-3 * std,
+                    b=3 * std,
+                )
+                if getattr(o_proj, "bias", None) is not None:
+                    nn.init.zeros_(o_proj.bias)
+            return
+
+        # MLP / FFN blocks
+        if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "decoder_config"):
+                hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+
+            # Input-side std
+            in_std = None
+            if hidden_size is not None:
+                in_std = hidden_size**-0.5
+
+            gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
+            up_proj = getattr(module, "up_proj", None)
+            down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
+
+            # gate / input projections
+            for proj in (gate_proj, up_proj):
+                if proj is not None and hasattr(proj, "weight"):
+                    std = in_std or (proj.weight.shape[1] ** -0.5)
+                    nn.init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        nn.init.zeros_(proj.bias)
+
+            # output / down projections
+            if down_proj is not None and hasattr(down_proj, "weight"):
+                hidden_dim = down_proj.weight.shape[1]
+                out_std = hidden_dim**-0.5
+                nn.init.trunc_normal_(
+                    down_proj.weight,
+                    mean=0.0,
+                    std=out_std,
+                    a=-3 * out_std,
+                    b=3 * out_std,
+                )
+                if getattr(down_proj, "bias", None) is not None:
+                    nn.init.zeros_(down_proj.bias)
+            return
+
+        # Generic Linear layers (projections, lm_head, etc.)
+        if isinstance(module, nn.Linear):
+            fan_in = module.in_features
+            std = fan_in**-0.5
+            nn.init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+            return
+
+        # Fallback to parent default initialization.
+        super()._init_weights(module)

     def _update_causal_mask(self, module):
         raise AttributeError("No need to inherit it!")
@@ -987,9 +1142,19 @@ def forward(
     )


+class BltVisionModel(BltPreTrainedModel):
+    pass
+
+
+class BltTextModel(BltPreTrainedModel):
+    pass
+
+
 __all__ = [
     "BltPreTrainedModel",
     "BltModel",
     "BltPatcher",
     "BltForCausalLM",
+    "BltVisionModel",
+    "BltTextModel",
 ]
diff --git a/tests/models/blt/test_modeling_blt.py b/tests/models/blt/test_modeling_blt.py
index 56ee012aa98c..b9c96df3f537 100644
--- a/tests/models/blt/test_modeling_blt.py
+++ b/tests/models/blt/test_modeling_blt.py
@@ -177,10 +177,6 @@ class BltModelTest(CausalLMModelTest, unittest.TestCase):
     # used in `test_torch_compile_for_training`
     _torch_compile_train_cls = BltForCausalLM if is_torch_available() else None

-    @unittest.skip("BLT model requires special handling for training overfit test")
-    def test_training_overfit(self):
-        pass
-
     @pytest.mark.generate
     @parameterized.expand([("greedy", 1), ("beam search", 2)])
     @unittest.skip(
diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py
index ac379f618823..9aad5a7d7748 100644
--- a/utils/check_config_attributes.py
+++ b/utils/check_config_attributes.py
@@ -362,6 +362,8 @@
         "IdeficsPerceiverConfig": True,
         # TODO: @Arthur/Joao (`hidden_act` unused)
         "GptOssConfig": True,
+        "BltTextConfig": True,
+        "BltVisionConfig": True,
     }
 )
diff --git a/utils/check_repo.py b/utils/check_repo.py
index 80d6a3f3223f..f3c9f2f97114 100644
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -197,6 +197,8 @@
         "BltLocalDecoder",  # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
         "BltGlobalTransformer",  # Building part of bigger (tested) model. Tested implicitly through BLTForCausalLM.
         "Florence2VisionBackbone",  # Building part of bigger (tested) model. Tested implicitly through Florence2ForConditionalGeneration.
+        "BltTextModel",  # Placeholder wrapper; no dedicated functionality yet.
+        "BltVisionModel",  # Placeholder wrapper; no dedicated functionality yet.
     ]
 )

@@ -403,6 +405,8 @@
     "Qwen3OmniMoeTalkerModel",  # Building part of a bigger model
     "Qwen3OmniMoeThinkerForConditionalGeneration",  # Building part of a bigger model
     "Qwen3OmniMoeThinkerTextModel",  # Building part of a bigger model
+    "BltTextModel",  # Placeholder wrapper; no dedicated functionality yet.
+    "BltVisionModel",  # Placeholder wrapper; no dedicated functionality yet.
 ]