diff --git a/contrib/models/Mistral-Small-3.1-24B-Instruct-2503/src/modeling_mistral3.py b/contrib/models/Mistral-Small-3.1-24B-Instruct-2503/src/modeling_mistral3.py index b7f64504..fba4fbdc 100644 --- a/contrib/models/Mistral-Small-3.1-24B-Instruct-2503/src/modeling_mistral3.py +++ b/contrib/models/Mistral-Small-3.1-24B-Instruct-2503/src/modeling_mistral3.py @@ -44,7 +44,9 @@ NeuronBaseForCausalLM, NeuronBaseModel, ) -from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase +from neuronx_distributed_inference.modules.attention.attention_base import ( + NeuronAttentionBase, +) from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm @@ -70,6 +72,7 @@ class Mistral3NeuronConfig(NeuronConfig): """ Mistral3-specific NeuronConfig that sets the attention class. """ + def __init__(self, **kwargs): super().__init__(**kwargs) self.attn_cls = NeuronMistral3Attention @@ -78,15 +81,15 @@ def __init__(self, **kwargs): class Mistral3InferenceConfig(InferenceConfig): """ Configuration class for Mistral3 inference on NeuronX. - + This config handles the text portion of the Mistral3 multimodal model. It reads from the nested text_config in the Mistral3 config.json. """ - + def add_derived_config(self): """Add derived configuration parameters""" self.num_cores_per_group = 1 - + def get_required_attributes(self) -> List[str]: """List of required attributes for the configuration""" return [ @@ -103,45 +106,45 @@ def get_required_attributes(self) -> List[str]: "intermediate_size", "head_dim", ] - + @classmethod - def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + def get_neuron_config_cls(cls) -> Type["Mistral3NeuronConfig"]: """Return the NeuronConfig class to use""" - return NeuronConfig - + return Mistral3NeuronConfig + @classmethod def from_pretrained(cls, model_path: str, **kwargs) -> "Mistral3InferenceConfig": """ Load configuration from a pretrained Mistral3 model directory. - + Mistral3 uses a nested config structure with text_config and vision_config. We extract the text_config for text-only inference. - + Args: model_path: Path to the model directory containing config.json **kwargs: Additional arguments to override configuration - + Returns: Mistral3InferenceConfig: Configuration object """ # Extract neuron_config from kwargs if it exists neuron_config = kwargs.pop("neuron_config", None) - + # Read config file config_path = os.path.join(model_path, "config.json") if not os.path.exists(config_path): raise FileNotFoundError(f"Configuration file not found at {config_path}") - + with open(config_path, "r") as f: config_dict = json.load(f) - + # Mistral3 has nested config structure - extract text_config if "text_config" in config_dict: text_config = config_dict["text_config"] else: # If no text_config, assume this is already a text config text_config = config_dict - + # Create config dict with defaults from text config inference_config = { "hidden_size": text_config.get("hidden_size", 5120), @@ -149,7 +152,9 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "Mistral3InferenceConfig" "num_hidden_layers": text_config.get("num_hidden_layers", 40), "num_key_value_heads": text_config.get("num_key_value_heads", 8), "vocab_size": text_config.get("vocab_size", 131072), - "max_position_embeddings": text_config.get("max_position_embeddings", 131072), + "max_position_embeddings": text_config.get( + "max_position_embeddings", 131072 + ), "rope_theta": text_config.get("rope_theta", 1000000000.0), "rms_norm_eps": text_config.get("rms_norm_eps", 1e-05), "hidden_act": text_config.get("hidden_act", "silu"), @@ -169,10 +174,10 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "Mistral3InferenceConfig" "bos_token_id": text_config.get("bos_token_id", 1), "eos_token_id": text_config.get("eos_token_id", 2), } - + # Override with any provided kwargs inference_config.update(kwargs) - + # Create config object # If neuron_config is None, create a default one for inference if neuron_config is None: @@ -183,7 +188,7 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "Mistral3InferenceConfig" batch_size=1, seq_len=128, ) - + config = cls(neuron_config=neuron_config, **inference_config) return config @@ -191,29 +196,31 @@ def from_pretrained(cls, model_path: str, **kwargs) -> "Mistral3InferenceConfig" class NeuronMistral3Attention(NeuronAttentionBase): """ Mistral3 attention implementation for NeuronX. - + Uses the same attention mechanism as standard Mistral with: - Grouped Query Attention (GQA) with 32 query heads and 8 KV heads - Rotary Position Embeddings (RoPE) with very high theta (1B) - Optional sliding window attention - + Inherits from NeuronAttentionBase which provides: - Flash attention computation - KV cache management - Tensor parallel support """ - + def __init__(self, config: InferenceConfig): # Create rotary embeddings with Mistral3's high rope_theta # Use explicit head_dim from config instead of calculating it - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + rotary_emb = RotaryEmbedding( head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ) - + # Initialize base attention with Mistral3 parameters super().__init__( config=config, @@ -229,26 +236,26 @@ def __init__(self, config: InferenceConfig): class NeuronMistral3DecoderLayer(nn.Module): """ Mistral3 decoder layer for NeuronX. - + Architecture: - Pre-norm architecture with RMSNorm - Self-attention with GQA - MLP with SwiGLU activation - Residual connections - + This matches the standard Mistral decoder layer architecture. """ - + def __init__(self, config: InferenceConfig): super().__init__() self.hidden_size = config.hidden_size - + # Self-attention with GQA self.self_attn = NeuronMistral3Attention(config) - + # MLP with SwiGLU activation (same as Llama/Mistral) self.mlp = NeuronLlamaMLP(config) - + # Layer normalization (RMSNorm) self.input_layernorm = get_rmsnorm_cls()( config.hidden_size, @@ -258,7 +265,7 @@ def __init__(self, config: InferenceConfig): config.hidden_size, eps=config.rms_norm_eps, ) - + def forward( self, hidden_states: torch.Tensor, @@ -266,23 +273,25 @@ def forward( position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: """ Forward pass for Mistral3 decoder layer. - + Args: hidden_states: Input tensor of shape (batch, seq_len, hidden_size) attention_mask: Attention mask position_ids: Position IDs for RoPE past_key_value: Cached key/value states for generation - + Returns: Tuple of (hidden_states, present_key_value, cos_cache, sin_cache, None) """ # Pre-norm + Self Attention + Residual residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, @@ -291,13 +300,15 @@ def forward( **kwargs, ) hidden_states = residual + hidden_states - + # Pre-norm + MLP + Residual residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states)[0] # MLP returns tuple, take first element + hidden_states = self.mlp(hidden_states)[ + 0 + ] # MLP returns tuple, take first element hidden_states = residual + hidden_states - + # Return format expected by framework outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) return outputs @@ -306,28 +317,30 @@ def forward( class NeuronMistral3Model(NeuronBaseModel): """ Mistral3 model for NeuronX (text-only). - + This is the base transformer model without the LM head. It consists of: - Token embeddings - Stack of decoder layers - Final layer normalization - LM head (for causal language modeling) - + The model follows the NeuronX pattern: - setup_attr_for_model: Set up model attributes - init_model: Initialize model components - No custom forward method (handled by base class) """ - + def setup_attr_for_model(self, config: Mistral3InferenceConfig): """ Setup attributes required by the NeuronX framework. - + This method is called during initialization to set up model-specific attributes needed for compilation and inference. """ - self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) self.tp_degree = config.neuron_config.tp_degree self.hidden_size = config.hidden_size self.num_attention_heads = config.num_attention_heads @@ -335,11 +348,11 @@ def setup_attr_for_model(self, config: Mistral3InferenceConfig): self.max_batch_size = config.neuron_config.max_batch_size self.buckets = config.neuron_config.buckets self.sliding_window = getattr(config, "sliding_window", None) - + def init_model(self, config: Mistral3InferenceConfig): """ Initialize the model components. - + This method creates all the model layers: - Token embeddings - Transformer decoder layers @@ -348,7 +361,7 @@ def init_model(self, config: Mistral3InferenceConfig): """ self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - + # Token embeddings with vocabulary parallelism self.embed_tokens = ParallelEmbedding( config.vocab_size, @@ -359,18 +372,21 @@ def init_model(self, config: Mistral3InferenceConfig): pad=True, sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, ) - + # Stack of decoder layers self.layers = nn.ModuleList( - [NeuronMistral3DecoderLayer(config) for _ in range(config.num_hidden_layers)] + [ + NeuronMistral3DecoderLayer(config) + for _ in range(config.num_hidden_layers) + ] ) - + # Final layer normalization self.norm = get_rmsnorm_cls()( config.hidden_size, eps=config.rms_norm_eps, ) - + # Language modeling head self.lm_head = ColumnParallelLinear( config.hidden_size, @@ -385,31 +401,31 @@ def init_model(self, config: Mistral3InferenceConfig): class NeuronMistral3ForCausalLM(NeuronBaseForCausalLM): """ Mistral3 For Causal Language Modeling on NeuronX. - + This is the main class for text generation with Mistral3. It wraps the base model and provides: - Weight loading from HuggingFace checkpoints - State dict conversion to NeuronX format - Compilation and inference APIs - + Usage: # Load and compile config = Mistral3InferenceConfig.from_pretrained(model_path, neuron_config=neuron_config) model = NeuronMistral3ForCausalLM(config) model.compile() - + # Generate text output = model.generate(input_ids, max_length=100) """ - + # Specify the model class to use _model_cls = NeuronMistral3Model - + @staticmethod def load_hf_model(model_path, **kwargs): """ Load the HuggingFace model for weight extraction. - + Note: This is used for weight loading, not for inference. We can't directly use transformers.Mistral3ForConditionalGeneration since we only need the text model weights. @@ -418,42 +434,45 @@ def load_hf_model(model_path, **kwargs): # The base class will handle extracting the relevant weights try: from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) except Exception as e: print(f"Warning: Could not load HF model: {e}") # Return None to allow manual weight loading return None - + @staticmethod - def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) -> dict: + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: InferenceConfig + ) -> dict: """ Convert HuggingFace state dict to NeuronX format. - + This handles: 1. Extracting text model weights from multimodal checkpoint 2. Adding rank utilities for tensor parallelism 3. Handling any weight name mapping if needed - + Mistral3 multimodal checkpoint structure: - language_model.model.layers.X.self_attn.q_proj.weight -> layers.X.self_attn.qkv_proj.q_proj.weight - language_model.model.embed_tokens.weight -> embed_tokens.weight - language_model.lm_head.weight -> lm_head.weight - + Args: state_dict: HuggingFace checkpoint state dict config: Model configuration - + Returns: Converted state dict in NeuronX format """ neuron_config = config.neuron_config - + # Handle multimodal checkpoint - extract language_model weights converted_state_dict = {} for key, value in state_dict.items(): # Strip multimodal prefixes new_key = key - + # Remove language_model prefix if present if key.startswith("language_model.model."): # language_model.model.layers.X -> layers.X @@ -468,25 +487,31 @@ def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) - new_key = key.replace("text_model.", "") elif key.startswith("model."): new_key = key.replace("model.", "") - + # Map attention weight names to qkv_proj structure expected by NeuronX # HF: layers.X.self_attn.q_proj.weight # NeuronX: layers.X.self_attn.qkv_proj.q_proj.weight if ".self_attn.q_proj." in new_key: - new_key = new_key.replace(".self_attn.q_proj.", ".self_attn.qkv_proj.q_proj.") + new_key = new_key.replace( + ".self_attn.q_proj.", ".self_attn.qkv_proj.q_proj." + ) elif ".self_attn.k_proj." in new_key: - new_key = new_key.replace(".self_attn.k_proj.", ".self_attn.qkv_proj.k_proj.") + new_key = new_key.replace( + ".self_attn.k_proj.", ".self_attn.qkv_proj.k_proj." + ) elif ".self_attn.v_proj." in new_key: - new_key = new_key.replace(".self_attn.v_proj.", ".self_attn.qkv_proj.v_proj.") - + new_key = new_key.replace( + ".self_attn.v_proj.", ".self_attn.qkv_proj.v_proj." + ) + converted_state_dict[new_key] = value - + # Add rank utilities for vocabulary parallelism if neuron_config.vocab_parallel: converted_state_dict["embed_tokens.rank_util.rank"] = torch.arange( 0, neuron_config.local_ranks_size ) - + # Add rank utilities for attention layers (needed for tensor parallelism) num_layers = config.num_hidden_layers tp_degree = neuron_config.tp_degree @@ -494,21 +519,23 @@ def convert_hf_to_neuron_state_dict(state_dict: dict, config: InferenceConfig) - converted_state_dict[f"layers.{i}.self_attn.rank_util.rank"] = torch.arange( 0, tp_degree, dtype=torch.int32 ) - + # Add rank utility for base model - converted_state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) - + converted_state_dict["rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + return converted_state_dict - + @staticmethod def update_state_dict_for_tied_weights(state_dict): """ Handle tied weights between embeddings and LM head. - + If tie_word_embeddings is True, copy embedding weights to LM head. """ state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() - + @classmethod def get_config_cls(cls): """Return the configuration class"""