From 2c0a466657092b930ef44f4d24d170134d3d90bf Mon Sep 17 00:00:00 2001 From: vx120 <893600387@qq.com> Date: Thu, 13 Nov 2025 14:07:18 +0800 Subject: [PATCH 01/29] add muon clip optimizer --- swift/plugin/muonclip.py | 378 ++++++++++++++++++++++++++++++++ swift/plugin/optimizer.py | 34 +++ swift/ui/llm_train/llm_train.py | 9 + swift/ui/llm_train/optimizer.py | 26 ++- 4 files changed, 446 insertions(+), 1 deletion(-) create mode 100644 swift/plugin/muonclip.py diff --git a/swift/plugin/muonclip.py b/swift/plugin/muonclip.py new file mode 100644 index 0000000000..f2febe68c8 --- /dev/null +++ b/swift/plugin/muonclip.py @@ -0,0 +1,378 @@ +import os +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from typing import Optional, Callable +import copy + + +def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor: + """ + Newton-Schulz iteration for matrix orthogonalization. + """ + # Coefficients from Muon paper + a, b, c = (3.4445, -4.7750, 2.0315) + + # Convert to float for precision + X = G.float() + X /= (X.norm() + eps) + + # Handle rectangular matrices by transposing + if G.size(0) > G.size(1): + X = X.T + transposed = True + else: + transposed = False + + # Newton-Schulz iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + + # Transpose back if needed + if transposed: + X = X.T + + return X.to(G.dtype) + + +@torch.compile +def zeropower_via_newtonschulz5(G, steps): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. + Optimized version with torch.compile. + """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + if G.size(0) > G.size(1): + X = X.T + # Ensure spectral norm is at most 1 + X = X / (X.norm() + 1e-7) + # Perform the NS iterations + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + + if G.size(0) > G.size(1): + X = X.T + return X + + +class MuonClip(torch.optim.Optimizer): + """ + Fixed MuonClip Optimizer - Properly combines Muon optimizer with QK-Clip. + + This implementation includes fixes for the deepcopy issue with weight_norm. + """ + + def __init__( + self, + params, + lr: float = 1e-3, + momentum: float = 0.95, + weight_decay: float = 0.01, + tau: float = 100.0, + ns_steps: int = 5, + eps: float = 1e-8, + nesterov: bool = True, + adamw_betas: tuple = (0.9, 0.95), + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= momentum <= 1.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + if not 0.0 < tau: + raise ValueError(f"Invalid tau value: {tau}") + + defaults = dict( + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + tau=tau, + ns_steps=ns_steps, + eps=eps, + nesterov=nesterov, + adamw_betas=adamw_betas, + ) + super(MuonClip, self).__init__(params, defaults) + + # For QK-Clip functionality + self.model = None + self.attention_layers = [] + self.step_count = 0 + + # Store parameter names for classification + self.param_names = {} + + # 修复:避免在初始化时立即分类参数,等待set_model调用 + self._params_classified = False + + def _classify_parameters(self): + """Properly classify parameters into Muon and AdamW groups.""" + if self._params_classified: + return + + for group in self.param_groups: + muon_params = [] + adamw_params = [] + + for p in group['params']: + if p.requires_grad: + # 修复:使用更安全的方式获取参数名称 + param_name = self._get_param_name(p) + + # Use Muon for 2D+ parameters that are not embeddings or lm_head + if (p.ndim >= 2 and + param_name is not None and + not any(name in param_name for name in ['embed', 'lm_head', 'weight_g', 'weight_v'])): + self.state[p]['use_muon'] = True + muon_params.append(p) + else: + # Use AdamW for 1D parameters, embeddings, and output layers + # 特别处理weight_norm相关的参数 + self.state[p]['use_muon'] = False + adamw_params.append(p) + + # Store the classified parameters + group['muon_params'] = muon_params + group['adamw_params'] = adamw_params + + self._params_classified = True + + def _get_param_name(self, param): + """Get parameter name by finding it in the model.""" + if self.model is None: + return None + + try: + for name, p in self.model.named_parameters(): + if p is param: + return name + except RuntimeError as e: + # 处理可能的deepcopy错误 + if "deepcopy" in str(e) or "weight_norm" in str(e): + print(f"Warning: Could not get parameter name due to deepcopy issue: {e}") + return None + raise e + return None + + def set_model(self, model: nn.Module): + """ + Set model reference for QK-Clip functionality and parameter name resolution. + """ + self.model = model + + # 修复:先移除可能的weight_norm,然后再进行参数操作 + self._handle_weight_norm_issues() + + # Try to get attention layers from model + if hasattr(model, 'get_attention_layers'): + self.attention_layers = model.get_attention_layers() + else: + # Fallback: try to find attention layers automatically + self.attention_layers = self._find_attention_layers(model) + + # Now classify parameters + self._classify_parameters() + + def _handle_weight_norm_issues(self): + """处理weight_norm相关的深度拷贝问题""" + if self.model is None: + return + + # 检查模型中是否使用了weight_norm + has_weight_norm = False + for module in self.model.modules(): + if hasattr(module, 'weight_g') or hasattr(module, 'weight_v'): + has_weight_norm = True + break + + if has_weight_norm: + print("Warning: Model may contain weight_norm layers which can cause deepcopy issues.") + print("Consider using torch.nn.utils.remove_weight_norm if possible.") + + def _find_attention_layers(self, model): + """Try to find attention layers in the model automatically.""" + attention_layers = [] + for name, module in model.named_modules(): + # Support both Qwen2 (q_proj, k_proj, v_proj) and standard attention + if (hasattr(module, 'q_proj') and hasattr(module, 'k_proj') and hasattr(module, 'v_proj')) or \ + (hasattr(module, 'query') and hasattr(module, 'key') and hasattr(module, 'value')): + attention_layers.append((name, module)) + return attention_layers + + def adjust_lr_for_muon(self, lr: float, param_shape: tuple) -> float: + """ + Adjust learning rate for Muon parameters based on matrix dimensions. + """ + if len(param_shape) >= 2: + A, B = param_shape[0], param_shape[1] + adjusted_ratio = 0.2 * math.sqrt(max(A, B)) + adjusted_lr = lr * adjusted_ratio + return adjusted_lr + return lr + + def _apply_muon_update(self, p, grad, group): + """Apply Muon update for 2D+ parameters.""" + lr = group['lr'] + momentum = group['momentum'] + weight_decay = group['weight_decay'] + ns_steps = group['ns_steps'] + nesterov = group['nesterov'] + + state = self.state[p] + + # Initialize momentum buffer + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(grad) + + buf = state['momentum_buffer'] + + # Apply momentum: M_t = μM_{t-1} + G_t + buf.mul_(momentum).add_(grad) + + # Prepare gradient for orthogonalization + if nesterov: + g = grad + momentum * buf + else: + g = buf + + # Flatten to 2D if needed for orthogonalization + original_shape = g.shape + if g.ndim > 2: + g_2d = g.view(g.shape[0], -1) + else: + g_2d = g + + # Apply Newton-Schulz orthogonalization + orthogonal_update = zeropower_via_newtonschulz5(g_2d, ns_steps) + + # Reshape back to original dimensions if needed + if g.ndim > 2: + orthogonal_update = orthogonal_update.view(original_shape) + + # Adjust learning rate for Muon + adjusted_lr = self.adjust_lr_for_muon(lr, p.shape) + + # Apply weight decay (AdamW style) + p.data.mul_(1 - lr * weight_decay) + + # Apply orthogonal update + p.data.add_(orthogonal_update, alpha=-adjusted_lr) + + def _apply_adamw_update(self, p, grad, group): + """Apply AdamW update for 1D parameters, embeddings, and output layers.""" + lr = group['lr'] + beta1, beta2 = group['adamw_betas'] + eps = group['eps'] + weight_decay = group['weight_decay'] + + state = self.state[p] + + # Initialize AdamW state + if 'step' not in state: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(grad) + state['exp_avg_sq'] = torch.zeros_like(grad) + + state['step'] += 1 + step = state['step'] + exp_avg = state['exp_avg'] + exp_avg_sq = state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = exp_avg_sq.sqrt().add_(eps) + + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + step_size = lr * math.sqrt(bias_correction2) / bias_correction1 + + # Apply weight decay + p.data.mul_(1 - lr * weight_decay) + + # Apply update + p.data.addcdiv_(exp_avg, denom, value=-step_size) + + def _apply_qk_clip(self): + """Apply QK-Clip to attention layers to prevent logit explosion.""" + if not self.attention_layers: + return + + tau = self.param_groups[0]['tau'] + + for layer_name, attention_layer in self.attention_layers: + # For Qwen2-style attention + if hasattr(attention_layer, 'q_proj') and hasattr(attention_layer, 'k_proj'): + max_logits = getattr(attention_layer, 'max_logits', 0.0) + + if max_logits > tau: + gamma = tau / max_logits + sqrt_gamma = math.sqrt(gamma) + + # Apply scaling to query and key projection weights + with torch.no_grad(): + attention_layer.q_proj.weight.data *= sqrt_gamma + attention_layer.k_proj.weight.data *= sqrt_gamma + + # Reset max_logits + if hasattr(attention_layer, 'max_logits'): + attention_layer.max_logits = 0.0 + + # For standard attention + elif hasattr(attention_layer, 'query') and hasattr(attention_layer, 'key'): + max_logits = getattr(attention_layer, 'max_logits', 0.0) + + if max_logits > tau: + gamma = tau / max_logits + sqrt_gamma = math.sqrt(gamma) + + with torch.no_grad(): + attention_layer.query.weight.data *= sqrt_gamma + attention_layer.key.weight.data *= sqrt_gamma + + if hasattr(attention_layer, 'max_logits'): + attention_layer.max_logits = 0.0 + + @torch.no_grad() + def step(self, closure: Optional[Callable] = None) -> Optional[float]: + """ + Performs a single optimization step with proper parameter classification. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # 确保参数已经分类 + if not self._params_classified and self.model is not None: + self._classify_parameters() + + for group in self.param_groups: + # Process Muon parameters (2D+) + for p in group.get('muon_params', []): + if p.grad is not None and p.grad.is_sparse is False: + self._apply_muon_update(p, p.grad, group) + + # Process AdamW parameters (1D, embeddings, output layers) + for p in group.get('adamw_params', []): + if p.grad is not None and p.grad.is_sparse is False: + self._apply_adamw_update(p, p.grad, group) + + # Apply QK-Clip for attention stability + self._apply_qk_clip() + + # Increment step counter + self.step_count += 1 + + return loss + diff --git a/swift/plugin/optimizer.py b/swift/plugin/optimizer.py index 522afe324b..862b784629 100644 --- a/swift/plugin/optimizer.py +++ b/swift/plugin/optimizer.py @@ -97,6 +97,39 @@ def create_muon_optimizer(args: 'TrainingArguments', model, dataset): **optim_args, ), None +def create_muon_clip_optimizer(args: 'TrainingArguments', model, dataset): + from swift.plugin.muonclip import MuonClip + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(' ', '').split(','): + key, value = mapping.split('=') + optim_args[key] = value + + # Set default values for MuonClip parameters + lr = optim_args.get('lr', args.learning_rate) + momentum = float(optim_args.get('momentum', 0.95)) + weight_decay = float(optim_args.get('weight_decay', args.weight_decay)) + tau = float(optim_args.get('tau', 100.0)) + ns_steps = int(optim_args.get('ns_steps', 5)) + eps = float(optim_args.get('eps', 1e-7)) + + # Create MuonClip optimizer with all parameters + optimizer = MuonClip( + model.parameters(), + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + tau=tau, + ns_steps=ns_steps, + eps=eps, + ) + + # Set model reference for QK-Clip functionality + optimizer.set_model(model) + + return optimizer, None def get_param_startswith(model, chosen_prefix: List[str], @@ -161,5 +194,6 @@ def create_multimodal_optimizer(args: 'TrainingArguments', model, dataset): 'galore': create_galore_optimizer, 'lorap': create_lorap_optimizer, 'muon': create_muon_optimizer, + 'muonclip': create_muon_clip_optimizer, 'multimodal': create_multimodal_optimizer, } diff --git a/swift/ui/llm_train/llm_train.py b/swift/ui/llm_train/llm_train.py index 3f21cbba04..092de2f3d5 100644 --- a/swift/ui/llm_train/llm_train.py +++ b/swift/ui/llm_train/llm_train.py @@ -392,11 +392,14 @@ def train(cls, *args): kwargs.pop('use_liger_kernel') if other_kwargs.get('use_muon'): kwargs['use_muon'] = other_kwargs.pop('use_muon') + if other_kwargs.get('use_muonclip'): + kwargs['use_muonclip'] = other_kwargs.pop('use_muonclip') # filter kwargs tabs_relation_dict = cls.prepare_sub_to_filter() cls.remove_useless_args(kwargs, tabs_relation_dict) use_muon = kwargs.pop('use_muon', None) + use_muonclip = kwargs.pop('use_muonclip', None) if cls.group == 'llm_rlhf': cls.filter_rlhf_args(kwargs) try: @@ -431,6 +434,9 @@ def train(cls, *args): if use_muon: params += f'--optimizer {cls.quote}muon{cls.quote} ' command.extend(['--optimizer', 'muon']) + if use_muonclip: + params += f'--optimizer {cls.quote}muonclip{cls.quote} ' + command.extend(['--optimizer', 'muonclip']) more_params_cmd = more_params_cmd.strip() if more_params_cmd != '': params += f'{more_params_cmd} ' @@ -566,6 +572,9 @@ def remove_useless_args(cls, uncleaned_kwargs, tabs_relation_dict): target_value = 'multimodal' if uncleaned_kwargs.get('use_muon'): target_value = 'muon' + if uncleaned_kwargs.get('use_muonclip'): + target_value = 'muonclip' + for tab_key in tabs_to_filter.keys(): if tab_key == 'lora' and target_value in ('longlora', 'adalora'): diff --git a/swift/ui/llm_train/optimizer.py b/swift/ui/llm_train/optimizer.py index 15ad94246f..edcd91326c 100644 --- a/swift/ui/llm_train/optimizer.py +++ b/swift/ui/llm_train/optimizer.py @@ -85,6 +85,22 @@ class Optimizer(BaseUI): 'en': 'Using the Muon optimizer, set `--optimizer muon` in the command line' } }, + 'muonclip_tab': { + 'label': { + 'zh': 'MuonClip参数设置', + 'en': 'MuonClip Settings' + }, + }, + 'use_muonclip': { + 'label': { + 'zh': '使用MuonClip', + 'en': 'Use MuonClip' + }, + 'info': { + 'zh': '使用MuonClip优化器,将在命令行参数中设置`--optimizer muonclip`', + 'en': 'Using the MuonClip optimizer, set `--optimizer muonclip` in the command line' + } + }, 'multimodal_tab': { 'label': { 'zh': '多模态参数设置', @@ -115,7 +131,8 @@ class Optimizer(BaseUI): 'galore': ['use_galore', 'galore_with_embedding', 'galore_rank', 'galore_update_proj_gap'], 'lorap': ['lorap_lr_ratio'], 'multimodal': ['vit_lr', 'aligner_lr'], - 'muon': ['use_muon'] + 'muon': ['use_muon'], + 'muonclip': ['use_muonclip'] } @classmethod @@ -138,3 +155,10 @@ def do_build_ui(cls, base_tab: Type['BaseUI']): with gr.TabItem(elem_id='muon_tab'): with gr.Row(): gr.Checkbox(elem_id='use_muon', scale=4) + with gr.TabItem(elem_id='multimodal_tab'): + with gr.Row(): + gr.Textbox(elem_id='vit_lr', lines=1, scale=20) + gr.Textbox(elem_id='aligner_lr', lines=1, scale=20) + with gr.TabItem(elem_id='muonclip_tab'): + with gr.Row(): + gr.Checkbox(elem_id='use_muonclip', scale=4) From b565ecbde7fb0ce33edebf39cffbedc8e73798f2 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 13 Nov 2025 15:06:14 +0800 Subject: [PATCH 02/29] [bugfix] fix mcore-bridge vpp (#6581) --- swift/megatron/model/gpt_bridge.py | 11 ++++++++--- swift/megatron/trainers/kto_trainer.py | 5 ++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index e15ea770a6..5d69a10df6 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -978,8 +978,9 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd else: yield from list(self._add_prefix(hf_state_dict, hf_prefix).items()) hf_state_dict = {} - for layer_idx in tqdm( - range(self.args.num_layers), dynamic_ncols=True, desc=tqdm_desc, disable=self.disable_tqmd): + layer_idx = 0 + prog_bar = tqdm(range(self.args.num_layers), dynamic_ncols=True, desc=tqdm_desc, disable=self.disable_tqmd) + while layer_idx < self.args.num_layers: lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model if len(lm_model.decoder.layers) > 0: start_idx = lm_model.decoder.layers[0].layer_number - 1 @@ -990,6 +991,8 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd mg_layer = lm_model.decoder.layers[layer_idx - start_idx] else: if to_mcore: + layer_idx += 1 + prog_bar.update() continue else: mg_layer = None @@ -997,9 +1000,11 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd has_model = torch.tensor([mg_layer is not None], dtype=torch.bool, device='cuda') dist.all_reduce(has_model, group=self.pp_group) if not has_model: - mg_model = next(mg_models) + mg_model = next(mg_models) # compat vpp continue res = self._set_layer_state(mg_layer, hf_state_dict, f'{self.hf_layers_prefix}.', layer_idx, to_mcore) + layer_idx += 1 + prog_bar.update() if to_mcore: yield else: diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index a3d8cd2f01..d0a385aa41 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -50,7 +50,7 @@ def _kto_get_logps(self, output_tensor, data, is_KL: bool, is_ref: bool, length: return self.get_logps(output, labels, packed_seq_params, packed_seq_params.num_samples) def loss_func(self, output_tensor, *, data, kl_data, label): - length = data['packed_seq_params'].cu_seqlens_q[-1] + length = data['packed_seq_params'].cu_seqlens_q[-1] // self.args.context_parallel_size policy_logps = self._kto_get_logps(output_tensor, data, False, False, length) ref_logps = self._kto_get_logps(output_tensor, data, False, True, length) if self.args.calculate_KL: @@ -121,8 +121,7 @@ def forward_step(self, data_iterator, model): data.pop('loss_scale', None) kl_data.pop('loss_scale', None) - length = data['packed_seq_params'].cu_seqlens_q[-1] - + length = data['packed_seq_params'].cu_seqlens_q[-1] // self.args.context_parallel_size with torch.no_grad(), self.null_ref_context() as ref_models: ref_model = ref_models[vp_stage or 0] if self.args.calculate_KL: From 9d67dbbcb531be4ab32acd48ee15cefd87c19167 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 13 Nov 2025 16:46:06 +0800 Subject: [PATCH 03/29] qwen2.5-vl compat qwen_vl_utils 0.14.0 (#6584) --- .../Megatron-SWIFT/Command-line-parameters.md | 1 + .../Megatron-SWIFT/Command-line-parameters.md | 1 + swift/llm/model/model/qwen.py | 30 ++++++++++++++++++- swift/llm/model/register.py | 5 ++-- 4 files changed, 34 insertions(+), 3 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index a827e16aec..c43e60923b 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -36,6 +36,7 @@ - **注意:推荐flash_attn版本:2.7.4.post1/2.8.1**。在"ms-swift<3.7"的版本中,该参数的默认为'auto'。 - 如果安装'flash_attention_3',`--attention_backend flash`则优先使用fa3。训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/flash_attention_3)。 - optimizer: 优化器类型,可选为'adam'、'sgd'。默认为adam。 + - 注意:此'adam'为'adamw',参考[这里](https://github.com/NVIDIA/TransformerEngine/blob/d8f1e68f7c414f3e7985a8b41de4443b2f819af3/transformer_engine/pytorch/optimizers/fused_adam.py#L69-L70)。 - 🔥optimizer_cpu_offload: 将优化器状态卸载到 CPU,例如设置:`--use_precision_aware_optimizer true --optimizer_cpu_offload true --optimizer_offload_fraction 0.7`。默认为False。 - 该参数可以显著降低显存占用(但增加内存占用)。若global_batch_size较大,则对训练速度的影响不大。 - 🔥optimizer_offload_fraction: 卸载到 CPU 的优化器状态所占比例。默认为1.。 diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 27f511b4b8..59e94335d9 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -37,6 +37,7 @@ - **Note: The recommended `flash_attn` version is 2.7.4.post1/2.8.1**. In versions of `ms-swift` prior to 3.7, the default value for this parameter is `'auto'`. - If `flash_attention_3` is installed, specifying `--attention_backend flash` will prioritize using FA3. Refer to the training script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/flash_attention_3). - optimizer: Optimizer type, options are 'adam', 'sgd'. Default is adam. + - Note: This 'adam' is actually 'adamw'. See [here](https://github.com/NVIDIA/TransformerEngine/blob/d8f1e68f7c414f3e7985a8b41de4443b2f819af3/transformer_engine/pytorch/optimizers/fused_adam.py#L69-L70) for reference. - 🔥optimizer_cpu_offload: Offloads optimizer states to the CPU. For example, set: `--use_precision_aware_optimizer true --optimizer_cpu_offload true --optimizer_offload_fraction 0.7`. Defaults to `False`. - This parameter can significantly reduce GPU memory usage (at the cost of increased CPU memory consumption). When the `global_batch_size` is large, its impact on training speed is minimal. - 🔥optimizer_offload_fraction: The fraction of the optimizer state to offload to CPU. Default is `1.0`. diff --git a/swift/llm/model/model/qwen.py b/swift/llm/model/model/qwen.py index 98a3ed1934..b90e5743a6 100644 --- a/swift/llm/model/model/qwen.py +++ b/swift/llm/model/model/qwen.py @@ -1,9 +1,11 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import importlib.metadata import os from types import MethodType from typing import Any, Dict, Optional, Tuple, Type, Union import torch +from packaging import version from PIL import Image from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizerBase from transformers.dynamic_module_utils import get_class_from_dynamic_module @@ -737,6 +739,21 @@ def _new_read_video_decord(ele: dict): return res +def compat_qwen_vl_utils(image_patch_size: int): + spatial_merge_size = int(os.getenv('SPATIAL_MERGE_SIZE', '2')) + image_factor = image_patch_size * spatial_merge_size + env_vars_to_process = { + 'MAX_PIXELS': 'IMAGE_MAX_TOKEN_NUM', + 'MIN_PIXELS': 'IMAGE_MIN_TOKEN_NUM', + 'VIDEO_MAX_PIXELS': 'VIDEO_MAX_TOKEN_NUM', + 'VIDEO_MIN_PIXELS': 'VIDEO_MIN_TOKEN_NUM', + } + for source_var, target_var in env_vars_to_process.items(): + value = os.getenv(source_var) + if value and not os.getenv(target_var): + os.environ[target_var] = str(int(value) // image_factor**2) + + def get_model_tokenizer_qwen2_vl(*args, **kwargs): from transformers import Qwen2VLForConditionalGeneration kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen2VLForConditionalGeneration @@ -746,9 +763,18 @@ def get_model_tokenizer_qwen2_vl(*args, **kwargs): patch_get_input_embeddings(base_model.visual, 'patch_embed') from qwen_vl_utils import vision_process + import qwen_vl_utils check_qwen_vl_utils = kwargs.get('_check_qwen_vl_utils', True) if check_qwen_vl_utils: - require_version('qwen_vl_utils<0.0.12') + try: + qwen_vl_utils_version = importlib.metadata.version('qwen_vl_utils') + except importlib.metadata.PackageNotFoundError: + raise importlib.metadata.PackageNotFoundError( + "The 'qwen_vl_utils' distribution was not found and is required by this application.") + if version.parse(qwen_vl_utils_version) >= version.parse('0.0.14'): + compat_qwen_vl_utils(image_patch_size=14) + else: + require_version('qwen_vl_utils<0.0.12') global_vars = patch_qwen_vl_utils(vision_process) tokenizer.global_vars = global_vars # In order to have different hashes for the template. return model, tokenizer @@ -1060,6 +1086,7 @@ def forward( def get_model_tokenizer_qwen3_vl(model_dir, *args, **kwargs): from transformers import Qwen3VLForConditionalGeneration require_version('qwen_vl_utils>=0.0.14') + compat_qwen_vl_utils(image_patch_size=16) kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen3VLForConditionalGeneration kwargs['_check_qwen_vl_utils'] = False model, processor = get_model_tokenizer_qwen2_vl(model_dir, *args, **kwargs) @@ -1101,6 +1128,7 @@ def get_model_tokenizer_qwen3_vl(model_dir, *args, **kwargs): def get_model_tokenizer_qwen3_moe_vl(model_dir, *args, **kwargs): from transformers import Qwen3VLMoeForConditionalGeneration require_version('qwen_vl_utils>=0.0.14') + compat_qwen_vl_utils(image_patch_size=16) kwargs['automodel_class'] = kwargs['automodel_class'] or Qwen3VLMoeForConditionalGeneration kwargs['_check_qwen_vl_utils'] = False model, processor = get_model_tokenizer_qwen2_vl(model_dir, *args, **kwargs) diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py index 1bbc9e8cce..fdb47e3b75 100644 --- a/swift/llm/model/register.py +++ b/swift/llm/model/register.py @@ -592,8 +592,9 @@ def _get_model_info(model_dir: str, model_type: Optional[str], quantization_conf architectures = HfConfigFactory.get_config_attr(config, 'architectures') model_types = get_matched_model_types(architectures) if len(model_types) > 1: - raise ValueError('Please explicitly pass the model_type. For reference, ' - f'the available model_types: {model_types}.') + raise ValueError('Failed to automatically match `model_type`. ' + f'Please explicitly pass the `model_type` for `{model_dir}`. ' + f'Recommended `model_types` include: {model_types}.') elif len(model_types) == 1: model_type = model_types[0] elif model_type not in MODEL_MAPPING: From 46cd05172f49b012249f68f455891717d8cb9605 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 13 Nov 2025 21:16:49 +0800 Subject: [PATCH 04/29] [bugfix] fix packing_length (#6594) --- docs/source/Megatron-SWIFT/Command-line-parameters.md | 1 + docs/source/Megatron-SWIFT/Mcore-Bridge.md | 1 + docs/source_en/Megatron-SWIFT/Command-line-parameters.md | 2 ++ docs/source_en/Megatron-SWIFT/Mcore-Bridge.md | 1 + swift/llm/train/sft.py | 1 + swift/megatron/convert.py | 2 +- 6 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index c43e60923b..18033f3eb1 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -5,6 +5,7 @@ **训练参数**: - 🔥micro_batch_size: 每个device的批次大小,默认为1。 - 🔥global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。 + - 其中,`数据并行大小 (DP) = 总GPU数 / (TP × PP × CP)`。 - 🔥recompute_granularity: 重新计算激活的粒度,可选项为'full', 'selective'。其中full代表重新计算整个transformer layer,selective代表只计算transformer layer中的核心注意力部分。通常'selective'是推荐的。默认为'selective'。 - 当你设置为'selective'时,你可以通过指定`--recompute_modules`来选择对哪些部分进行重新计算。 - 🔥recompute_method: 该参数需将recompute_granularity设置为'full'才生效,可选项为'uniform', 'block'。默认为None。 diff --git a/docs/source/Megatron-SWIFT/Mcore-Bridge.md b/docs/source/Megatron-SWIFT/Mcore-Bridge.md index c13a62ccac..5579f43f2d 100644 --- a/docs/source/Megatron-SWIFT/Mcore-Bridge.md +++ b/docs/source/Megatron-SWIFT/Mcore-Bridge.md @@ -193,6 +193,7 @@ swift infer \ ## 导出与转换精度测试 Mcore-Bridge除了支持在训练中进行safetensors的转换和保存,也支持了`megatron export`命令用于单独的权重导出。`megatron export`支持在权重转换时,对转换精度进行测试,这在接入新模型时验证接入准确性很有帮助。通常,Megatron-SWIFT已经接入的模型不会出现精度不对齐的情况,你可以放心设置`--test_convert_precision false`。 +- 提示:多模态模型请关注`mean_diff (with loss)`字段,`mean_diff`因包含图像tokens且该部分不计算损失,有较大的diff。 全参数权重: ```shell diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 59e94335d9..ee6111bc94 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -6,6 +6,7 @@ - 🔥micro_batch_size: Batch size per device, default is 1. - 🔥global_batch_size: Total batch size, equivalent to `micro_batch_size * data parallel size * gradient accumulation steps`. Default is 16. + - Here, `Data Parallelism size (DP) = Total number of GPUs / (TP × PP × CP)`. - 🔥recompute_granularity: Granularity of activation recomputation, options are 'full', 'selective'. 'full' means recomputing the entire transformer layer, while 'selective' means only recomputing the core attention part of the transformer layer. 'selective' is generally recommended. Default is 'selective'. - When you set it to 'selective', you can specify `--recompute_modules` to choose which parts to recompute. - 🔥recompute_method: This parameter takes effect only when recompute_granularity is set to 'full', options are 'uniform', 'block'. Default is None. @@ -315,6 +316,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, DPO, KTO and RM. - Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next. - Note: **Packing reduces the number of samples in the dataset; please adjust the gradient accumulation steps and learning rate accordingly**. +- packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - streaming: Stream data loading and processing, default is False. - Note: Since the length of a streaming dataset cannot be determined, the `--train_iters` parameter must be set. Also set the `max_epochs` parameter to ensure training exits after the specified number of epochs, and to validate and save the model weights accordingly. - Note: Streaming datasets can skip preprocessing wait time by overlapping preprocessing with training. Preprocessing for streaming datasets is performed only on rank 0 and then synchronized to other processes via data distribution. **This is generally less efficient than the data sharding approach used in non-streaming datasets.** When the training world_size is large, preprocessing and data distribution can become a training bottleneck. diff --git a/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md b/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md index 1e42db1aab..54a3d05694 100644 --- a/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md +++ b/docs/source_en/Megatron-SWIFT/Mcore-Bridge.md @@ -203,6 +203,7 @@ swift infer \ ## Export and Conversion Precision Testing In addition to supporting safetensors conversion and saving during training, Mcore-Bridge also supports the `megatron export` command for standalone weight export. `megatron export` supports conversion precision testing during weight conversion, which is very helpful for verifying accuracy when integrating new models. Typically, models already integrated into Megatron-SWIFT will not have precision misalignment issues, so you can confidently set `--test_convert_precision false`. +- Note: For multimodal models, please focus on the `mean_diff (with loss)` field. The `mean_diff` may show a large difference because it includes image tokens, and loss is not calculated for that portion. Full parameter weights: diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 502b895646..8b5c631a6b 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -161,6 +161,7 @@ def _post_process_datasets(self, datasets: List) -> List: template, dataset, num_proc=args.dataset_num_proc, + packing_length=args.packing_length, strict=args.strict, load_from_cache_file=args.load_from_cache_file) elif args.streaming: diff --git a/swift/megatron/convert.py b/swift/megatron/convert.py index 1e943d724f..2dc2263552 100644 --- a/swift/megatron/convert.py +++ b/swift/megatron/convert.py @@ -220,7 +220,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float print(f'token_mean_diff: {token_mean_diff}') print(f'mean_diff: {mean_diff}, max_diff: {max_diff}') print(f'mean_diff (with loss): {mean_diff_with_loss}, max_diff (with loss): {max_diff_with_loss} ' - '(Please check that mean_diff is less than 0.1).') + '(Please check that mean_diff (with loss) is less than 0.1).') hf_tokens = hf_logits.argmax(-1) mg_tokens = mg_logits.argmax(-1) print(f'hf_tokens: {hf_tokens[0].tolist()}\nmg_tokens: {mg_tokens[0].tolist()}') From ddbcbba3f7662513b719d6decfc5a8b55a9c5eed Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 13 Nov 2025 23:37:41 +0800 Subject: [PATCH 05/29] [dataset] support packing_num_proc (#6592) --- .../Instruction/Command-line-parameters.md | 1 + .../Megatron-SWIFT/Command-line-parameters.md | 1 + .../Instruction/Command-line-parameters.md | 1 + .../Megatron-SWIFT/Command-line-parameters.md | 1 + swift/llm/argument/base_args/base_args.py | 1 + swift/llm/dataset/utils.py | 73 +++++++++++++------ swift/llm/train/sft.py | 1 + 7 files changed, 57 insertions(+), 22 deletions(-) diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index c44563037c..53e4d349c8 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -447,6 +447,7 @@ Vera使用`target_modules`、`target_regex`、`modules_to_save`三个参数, - 注意:使用packing请结合`--attn_impl flash_attn`使用且"transformers>=4.44",具体查看[该PR](https://github.com/huggingface/transformers/pull/31629)。 - 注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 - packing_length: packing的长度。默认为None,设置为max_length。 +- packing_num_proc: packing的进程数,默认为1。需要注意的是,不同的`packing_num_proc`,最终形成的packed数据集是不同的。(该参数在流式packing时不生效) - lazy_tokenize: 是否使用lazy_tokenize。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(多模态模型则包括从磁盘中读取图片)。该参数默认为None,在LLM训练中默认为False,而MLLM训练默认为True,节约内存。 - 注意:若你要进行图像的数据增强,你需要将lazy_tokenize(或streaming)设置为True,并修改Template类中的encode方法。 - cached_dataset: 训练中使用缓存数据集(使用`swift export --to_cached_dataset true ...`命令产生),避免大型数据集训练时,tokenize过程占用gpu时间。默认为`[]`。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。 diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 18033f3eb1..e1de0d6d27 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -300,6 +300,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - 注意:**同一batch的不同序列之间依旧是不可见的**,除了Qwen3-Next。 - 注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 - packing_length: packing的长度。默认为None,设置为max_length。 +- packing_num_proc: packing的进程数,默认为1。需要注意的是,不同的`packing_num_proc`,最终形成的packed数据集是不同的。(该参数在流式packing时不生效) - streaming: 流式读取并处理数据集,默认False。 - 注意:因为流式数据集无法获得其长度,因此需要设置`--train_iters`参数。设置`max_epochs`参数确保训练到对应epochs时退出训练,并对权重进行验证和保存。 - 注意:流式数据集可以跳过预处理等待,将预处理时间与训练时间重叠。流式数据集的预处理只在rank0上进行,并通过数据分发的方式同步到其他进程,**其通常效率不如非流式数据集采用的数据分片读取方式**。当训练的world_size较大时,预处理和数据分发将成为训练瓶颈。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index d91f026ffe..74634419e3 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -455,6 +455,7 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine - Note: When using packing, please combine it with `--attn_impl flash_attn` and ensure "transformers>=4.44". For details, see [this PR](https://github.com/huggingface/transformers/pull/31629). - Note: **Packing reduces the number of samples in the dataset; please adjust the gradient accumulation steps and learning rate accordingly**. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. +- packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing) - lazy_tokenize: Whether to use lazy tokenization. If set to `False`, all dataset samples will be tokenized (and for multimodal models, images will be loaded from disk) before training begins. Default is `None`: in LLM training, it defaults to `False`; in MLLM training, it defaults to `True` to save memory. - Note: If you want to perform image data augmentation, you need to set `lazy_tokenize` (or `streaming`) to True and modify the `encode` method in the Template class. - cached_dataset: Use a cached dataset (generated with `swift export --to_cached_dataset true ...`) during training to avoid GPU time spent on tokenizing large datasets. Default is `[]`. Example: [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index ee6111bc94..2c5e465576 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -317,6 +317,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next. - Note: **Packing reduces the number of samples in the dataset; please adjust the gradient accumulation steps and learning rate accordingly**. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. +- packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing) - streaming: Stream data loading and processing, default is False. - Note: Since the length of a streaming dataset cannot be determined, the `--train_iters` parameter must be set. Also set the `max_epochs` parameter to ensure training exits after the specified number of epochs, and to validate and save the model weights accordingly. - Note: Streaming datasets can skip preprocessing wait time by overlapping preprocessing with training. Preprocessing for streaming datasets is performed only on rank 0 and then synchronized to other processes via data distribution. **This is generally less efficient than the data sharding approach used in non-streaming datasets.** When the training world_size is large, preprocessing and data distribution can become a training bottleneck. diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index 31c647224f..0695f4d68c 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -84,6 +84,7 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat # dataset packing: bool = False packing_length: Optional[int] = None + packing_num_proc: int = 1 lazy_tokenize: Optional[bool] = None cached_dataset: List[str] = field(default_factory=list) custom_register_path: List[str] = field(default_factory=list) # .py diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index aaee489c37..a0ccd8338b 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -1,5 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import math import multiprocessing as mp +from itertools import chain from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union import numpy as np @@ -8,7 +10,7 @@ from torch.utils.data import Dataset, IterableDataset from tqdm import tqdm -from swift.utils import get_logger, is_dist, is_master +from swift.utils import get_logger, is_dist, is_master, split_list from ..template import MaxLengthError from .preprocessor import RowPreprocessor @@ -128,6 +130,7 @@ def calculate_matched_group(template, sequences, packing_length: int, is_finishe class PackingDataset(Dataset): + PACKING_BATCH_SIZE = 1000 def __init__( self, @@ -138,6 +141,7 @@ def __init__( strict: bool = False, load_from_cache_file: bool = True, packing_length: Optional[int] = None, + packing_num_proc: int = 1, **kwargs, ): template.packing = True @@ -148,33 +152,58 @@ def __init__( self.strict = strict self.load_from_cache_file = load_from_cache_file self.packing_length = packing_length or self.template.max_length - self.workers = [] - self.packed_idx, self.packed_length = self.create_packed_idx() if is_master() else (None, None) + self.packing_num_proc = min(packing_num_proc, math.ceil(len(dataset) / self.PACKING_BATCH_SIZE)) + self._out_queue = mp.Queue() + if is_master(): + lengths = self.dataset['length'] + offset = 0 + chunked_lengths = split_list(lengths, self.packing_num_proc) + for i in range(self.packing_num_proc): + worker = mp.Process( + target=self.create_packed_idx, args=( + i, + offset, + chunked_lengths[i], + ), daemon=True) + worker.start() + offset += len(chunked_lengths[i]) + self.packed_idx = [[] for _ in range(self.packing_num_proc)] + self.packed_length = [[] for _ in range(self.packing_num_proc)] + desc = 'Packing: ' if self.packing_num_proc == 1 else f'Packing (num_proc={self.packing_num_proc}): ' + with tqdm(total=len(lengths), dynamic_ncols=True, desc=desc) as prog_bar: + finished_workers = 0 + while finished_workers < self.packing_num_proc: + rank, sequences, data_len = self._out_queue.get() + if data_len == -1: + finished_workers += 1 + continue + prog_bar.update(data_len) + self.packed_idx[rank] += [[x[0] for x in seq] for seq in sequences] + self.packed_length[rank] += [sum(x[1] for x in seq) for seq in sequences] + self.packed_idx = list(chain.from_iterable(self.packed_idx)) + self.packed_length = list(chain.from_iterable(self.packed_length)) + else: + self.packed_idx, self.packed_length = None, None if dist.is_initialized() and is_dist(): obj_list = [(self.packed_idx, self.packed_length)] dist.broadcast_object_list(obj_list) self.packed_idx, self.packed_length = obj_list[0] - def create_packed_idx(self): - lengths = self.dataset['length'] - data = [(i, length) for i, length in enumerate(lengths)] + def create_packed_idx(self, rank, offset, lengths): + data = [(i + offset, length) for i, length in enumerate(lengths)] i = 0 - PACKING_BATCH_SIZE = 1000 - input_data, packed_idx, packed_length = [], [], [] - with tqdm(total=len(data), dynamic_ncols=True, desc='Packing: ') as prog_bar: - while True: - new_data = data[i:i + PACKING_BATCH_SIZE] - input_data += new_data - prog_bar.update(len(new_data)) - if not input_data: - break - i += PACKING_BATCH_SIZE - is_finished = i >= len(data) - sequences, input_data = calculate_matched_group( - self.template, input_data, self.packing_length, is_finished=is_finished) - packed_idx += [[x[0] for x in seq] for seq in sequences] - packed_length += [sum(x[1] for x in seq) for seq in sequences] - return packed_idx, packed_length + input_data = [] + while True: + new_data = data[i:i + self.PACKING_BATCH_SIZE] + input_data += new_data + if not input_data: + break + i += self.PACKING_BATCH_SIZE + is_finished = i >= len(data) + sequences, input_data = calculate_matched_group( + self.template, input_data, self.packing_length, is_finished=is_finished) + self._out_queue.put((rank, sequences, len(new_data))) + self._out_queue.put((rank, [], -1)) def __getitem__(self, index): sequence = self.packed_idx[index] diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 8b5c631a6b..6b224ef0e5 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -162,6 +162,7 @@ def _post_process_datasets(self, datasets: List) -> List: dataset, num_proc=args.dataset_num_proc, packing_length=args.packing_length, + packing_num_proc=args.packing_num_proc, strict=args.strict, load_from_cache_file=args.load_from_cache_file) elif args.streaming: From 71fd2ea2cb065ae7bad5f6fda4e3487ce22fc00a Mon Sep 17 00:00:00 2001 From: tastelikefeet <58414341+tastelikefeet@users.noreply.github.com> Date: Fri, 14 Nov 2025 00:20:47 +0800 Subject: [PATCH 06/29] Fix emb loss scale (#6597) --- swift/plugin/loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index 0dc72882eb..61b3eb4cda 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -446,7 +446,7 @@ def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kw similarity_matrix = torch.cat(logits_list, dim=1) # temperature scaling and CE similarity_matrix = similarity_matrix / temperature - loss = nn.CrossEntropyLoss()(similarity_matrix, labels) / world_size # avoid duplicate + loss = nn.CrossEntropyLoss()(similarity_matrix, labels) else: all_tensors = [] for tensor in split_tensors: @@ -499,7 +499,6 @@ def infonce_loss(outputs, labels, loss_scale=None, num_items_in_batch=None, **kw # next positive is neg+1 length += tensor.size(0) - 1 loss /= len(split_tensors) - loss /= world_size # avoid duplicate return loss From 1ee2cd43514ad005cace2baf09f066b491193694 Mon Sep 17 00:00:00 2001 From: Jintao Date: Fri, 14 Nov 2025 16:02:12 +0800 Subject: [PATCH 07/29] [megatron] compat megatron-core 0.12-0.14 (#6599) --- .../Megatron-SWIFT/Command-line-parameters.md | 2 +- docs/source/Megatron-SWIFT/Quick-start.md | 3 +- .../Megatron-SWIFT/Command-line-parameters.md | 2 +- docs/source_en/Megatron-SWIFT/Quick-start.md | 3 +- examples/models/qwen3_next/mcore.sh | 13 +++- swift/megatron/argument/megatron_args.py | 2 + swift/megatron/init.py | 24 +++--- swift/megatron/model/gpt/qwen3_next.py | 70 +++++++++++------ swift/megatron/model/gpt_bridge.py | 16 +++- swift/megatron/model/gpt_model.py | 78 ++++++++++++++----- swift/megatron/model/model_provider.py | 21 +++-- swift/megatron/trainers/base.py | 17 ++-- swift/megatron/trainers/kto_trainer.py | 6 +- swift/megatron/trainers/reward_trainer.py | 1 + swift/megatron/trainers/trainer.py | 4 +- swift/megatron/trainers/utils.py | 14 +++- swift/megatron/tuners/lora.py | 4 +- 17 files changed, 195 insertions(+), 85 deletions(-) diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index e1de0d6d27..f551b9d906 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -298,7 +298,7 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - gradient_checkpointing_kwargs: 传入`torch.utils.checkpoint`中的参数。例如设置为`--gradient_checkpointing_kwargs '{"use_reentrant": false}'`。默认为None。该参数只对`vit_gradient_checkpointing`生效。 - 🔥packing: 是否使用序列packing提升计算效率(不同节点与进程更负载均衡,GPU利用率更高;但需要额外的预处理时间)并稳定显存占用,默认为False。当前支持CPT/SFT/DPO/KTO/RM。 - 注意:**同一batch的不同序列之间依旧是不可见的**,除了Qwen3-Next。 - - 注意:**packing会导致数据集样本数减少,请自行调节梯度累加数和学习率**。 + - 注意:**packing会导致数据集样本数减少,请自行调节global_batch_size和学习率**。 - packing_length: packing的长度。默认为None,设置为max_length。 - packing_num_proc: packing的进程数,默认为1。需要注意的是,不同的`packing_num_proc`,最终形成的packed数据集是不同的。(该参数在流式packing时不生效) - streaming: 流式读取并处理数据集,默认False。 diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md index 9161bdaf55..8c92e2b6b9 100644 --- a/docs/source/Megatron-SWIFT/Quick-start.md +++ b/docs/source/Megatron-SWIFT/Quick-start.md @@ -27,6 +27,7 @@ pip install --no-build-isolation transformer_engine[pytorch] # pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5#egg=transformer_engine[pytorch] # apex +# 提示:Megatron-SWIFT可以在不含apex的环境下运行,额外设置`--no_gradient_accumulation_fusion true`即可。 git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ @@ -65,7 +66,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 | torch | >=2.0 | 2.7.1/2.8.0 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | | 0.14 | | +| megatron_core | >=0.12 | 0.14 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 2c5e465576..8e0ef3085a 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -315,7 +315,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled. - 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, DPO, KTO and RM. - Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next. - - Note: **Packing reduces the number of samples in the dataset; please adjust the gradient accumulation steps and learning rate accordingly**. + - Note: **Packing will reduce the number of dataset samples. Please adjust global_batch_size and learning rate accordingly**. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing) - streaming: Stream data loading and processing, default is False. diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 292922b0d6..ed46f0471f 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -26,6 +26,7 @@ pip install --no-build-isolation transformer_engine[pytorch] # pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@release_v2.5#egg=transformer_engine[pytorch] # apex +# Note: Megatron-SWIFT can run in environments without apex by setting `--no_gradient_accumulation_fusion true`. git clone https://github.com/NVIDIA/apex cd apex pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ @@ -65,7 +66,7 @@ Recommended Operating Environment: | torch | >=2.0 | 2.7.1/2.8.0 | | | transformer_engine | >=2.3 | | | | apex | | 0.1 | | -| megatron_core | | 0.14 | | +| megatron_core | >=0.12 | 0.14 | | | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | diff --git a/examples/models/qwen3_next/mcore.sh b/examples/models/qwen3_next/mcore.sh index 6b36795beb..f520429868 100644 --- a/examples/models/qwen3_next/mcore.sh +++ b/examples/models/qwen3_next/mcore.sh @@ -11,7 +11,10 @@ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=8 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ megatron sft \ - --load Qwen3-Next-80B-A3B-Instruct-mcore \ + --model Qwen/Qwen3-Next-80B-A3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --merge_lora false \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT#2000' \ 'swift/self-cognition#1000' \ --load_from_cache_file true \ @@ -23,7 +26,7 @@ megatron sft \ --moe_permute_fusion true \ --moe_grouped_gemm true \ --moe_shared_expert_overlap true \ - --moe_aux_loss_coeff 1e-3 \ + --moe_aux_loss_coeff 1e-6 \ --micro_batch_size 2 \ --global_batch_size 16 \ --recompute_granularity full \ @@ -47,3 +50,9 @@ megatron sft \ --attention_backend flash \ --model_author swift \ --model_name swift-robot + + +# CUDA_VISIBLE_DEVICES=0,1,2,3 \ +# swift infer \ +# --adapters megatron_output/Qwen3-Next-80B-A3B-Instruct/vx-xxx/checkpoint-xxx \ +# --stream true diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index d9f6ba9e22..556a631f36 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -454,6 +454,8 @@ def __post_init__(self): MegatronTunerMixin.__post_init__(self) os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' self._set_default() + if self.optimizer_cpu_offload: + require_version('megatron-core>=0.13') self.model_info, self.model_meta = get_model_info_meta( self.model, model_type=self.model_type, use_hf=self.use_hf, hub_token=self.hub_token) self.model_type = self.model_info.model_type diff --git a/swift/megatron/init.py b/swift/megatron/init.py index 6a591f4429..fcf602ed00 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -66,7 +66,7 @@ def _patch_mla_attention(): gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region, ) - megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') # Code borrowed from NVIDIA/Megatron-LM def forward( @@ -112,7 +112,7 @@ def forward( # Adjust key, value for inference # =================================================== # rotary_pos_emb = None - if megatron_core_013: + if mcore_013: query, key, value, _, attn_mask_type, _ = self._adjust_key_value_for_inference( inference_context, query, key, value, rotary_pos_emb=None) else: @@ -430,7 +430,7 @@ def _patch_TransformerLayer(): from megatron.training import get_args from megatron.core.transformer import TransformerLayer _origin_forward = TransformerLayer.forward - megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') def forward(self, *_args, **kwargs): """ @@ -439,7 +439,7 @@ def forward(self, *_args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. """ - if not megatron_core_013: + if not mcore_013: return _origin_forward(self, *_args, **kwargs) hidden_states, context = self._forward_attention(*_args, **kwargs) args = get_args() @@ -551,11 +551,14 @@ def build_train_valid_test_datasets(build_train_valid_test_datasets_provider): def _patch_mrope(): from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding from megatron.core import parallel_state + import megatron.core from megatron.core.models.common.embeddings.rope_utils import (get_pos_emb_on_this_cp_rank, _apply_rotary_pos_emb_bshd) from megatron.core.models.common.embeddings import rope_utils from megatron.training import get_args + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + # Code borrowed from huggingface/transformers def apply_interleaved_mrope(freqs, mrope_section): """Apply interleaved MRoPE to 3D rotary embeddings. @@ -638,13 +641,16 @@ def _apply_rotary_pos_emb_thd( Returns: Tensor: Shape [t, h, d]. The input tensor after applying RoPE. """ - use_batched_rope = False if cp_group is not None: cp_size = cp_group.size() - cu_seqlens_for_batched = cu_seqlens // cp_size - use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item() + else: + args = get_args() + cp_size = args.context_parallel_size + cu_seqlens_for_batched = cu_seqlens // cp_size + use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item() if not use_batched_rope: logger.warning_once('Using non-batched RoPE, which may affect performance.') + kwargs = {'cp_group': cp_group} if mcore_013 else {} return _origin_apply_rotary_pos_emb_thd( t, cu_seqlens, @@ -652,10 +658,8 @@ def _apply_rotary_pos_emb_thd( rotary_interleaved=rotary_interleaved, multi_latent_attention=multi_latent_attention, mscale=mscale, - cp_group=cp_group, + **kwargs, ) - if cp_group is None: - raise ValueError('cp_group must be provided for THD format RoPE') return _apply_rotary_pos_emb_bshd( t.unsqueeze(1), diff --git a/swift/megatron/model/gpt/qwen3_next.py b/swift/megatron/model/gpt/qwen3_next.py index ab95d2a4d0..7a1419c5c5 100644 --- a/swift/megatron/model/gpt/qwen3_next.py +++ b/swift/megatron/model/gpt/qwen3_next.py @@ -2,6 +2,7 @@ from copy import deepcopy from typing import Optional, Tuple, Union +import megatron.core import torch from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, _get_extra_te_kwargs from megatron.core.inference.contexts import BaseInferenceContext @@ -17,6 +18,7 @@ from megatron.core.transformer.transformer_layer import get_transformer_layer_offset from megatron.core.utils import deprecate_inference_params, is_fa_min_version from megatron.training import get_args +from packaging import version from swift.llm import ModelType from swift.utils import get_logger @@ -24,6 +26,7 @@ from ..gpt_bridge import GPTBridge from ..register import MegatronModelMeta, register_megatron_model +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') try: from flashattn_hopper.flash_attn_interface import _flash_attn_forward from flashattn_hopper.flash_attn_interface import flash_attn_with_kvcache as flash_attn3_with_kvcache @@ -58,6 +61,7 @@ class Qwen3NextSelfAttention(SelfAttention): def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodules, *args, **kwargs): super(SelfAttention, self).__init__(config, submodules, *args, attention_type='self', **kwargs) + kwargs = {'tp_group': self.model_comm_pgs.tp} if mcore_013 else {} self.linear_qkv = build_module( submodules.linear_qkv, self.config.hidden_size, @@ -69,7 +73,7 @@ def __init__(self, config: TransformerConfig, submodules: SelfAttentionSubmodule skip_bias_add=False, is_expert=False, tp_comm_buffer_name='qkv', - tp_group=self.model_comm_pgs.tp, + **kwargs, ) if submodules.q_layernorm is not None: @@ -130,12 +134,22 @@ def forward( (Tuple[Tensor, Tensor]) Attention output and bias. """ - from megatron.core.utils import nvtx_range_pop, nvtx_range_push + try: + from megatron.core.utils import nvtx_range_pop, nvtx_range_push + except ImportError: + + def nvtx_range_pop(*args, **kwargs): + return + + def nvtx_range_push(*args, **kwargs): + return + # Check if we need to skip RoPE # no_rope is 0-indexed array and self.layer_number is 1-indexed - no_rope = (self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False) - if no_rope: - rotary_pos_emb = None + if hasattr(self.config, 'no_rope_freq'): + no_rope = (self.config.no_rope_freq[self.layer_number - 1] if self.config.no_rope_freq else False) + if no_rope: + rotary_pos_emb = None inference_context = deprecate_inference_params(inference_context, inference_params) @@ -194,17 +208,20 @@ def forward( if (in_decode_mode and self.config.enable_cuda_graph and inference_context.is_static_batching()): raise ValueError('CUDA graphs must use flash decode with static batching!') - query, key, value, rotary_pos_emb, attn_mask_type, block_table = ( - self._adjust_key_value_for_inference( - inference_context, - query, - key, - value, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - sequence_len_offset, - )) + result = self._adjust_key_value_for_inference( + inference_context, + query, + key, + value, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + ) + if mcore_013: + query, key, value, rotary_pos_emb, attn_mask_type, block_table = result + else: + query, key, value, rotary_pos_emb, attn_mask_type = result if packed_seq_params is not None: query = query.squeeze(1) @@ -215,6 +232,7 @@ def forward( # ================================================ # relative positional embedding (rotary embedding) # ================================================ + kwargs = {'cp_group': self.model_comm_pgs.cp} if mcore_013 else {} nvtx_range_push(suffix='rotary_pos_emb') if rotary_pos_emb is not None and not self.config.flash_decode: q_pos_emb, k_pos_emb = rotary_pos_emb @@ -239,18 +257,18 @@ def forward( q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q, - cp_group=self.model_comm_pgs.cp, + **kwargs, ) else: query = inference_context.apply_rotary_emb_query(query, q_pos_emb, self.config, cu_seqlens_q, - self.model_comm_pgs.cp) + **kwargs) if k_pos_emb is not None: key = apply_rotary_pos_emb( key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv, - cp_group=self.model_comm_pgs.cp, + **kwargs, ) # TODO, can apply positional embedding to value_layer so it has @@ -418,16 +436,17 @@ def forward(self, hidden_states: torch.Tensor, **kwargs): def get_local_layer_specs(config, layer_specs, vp_stage=None): - from megatron.core.transformer.enums import LayerType - num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage) + kwargs = {'vp_stage': vp_stage} if mcore_013 else {} + num_layers_to_build = get_num_layers_to_build(config, **kwargs) - if config.pipeline_model_parallel_layout is not None: + if getattr(config, 'pipeline_model_parallel_layout', None) is not None: + from megatron.core.transformer.enums import LayerType local_layer_specs = [ layer_specs[layer_id] for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list( - layer_type=LayerType.decoder, vp_stage=vp_stage) + layer_type=LayerType.decoder, **kwargs) ] else: - offset = get_transformer_layer_offset(config, vp_stage=vp_stage) + offset = get_transformer_layer_offset(config, **kwargs) local_layer_specs = layer_specs[offset:offset + num_layers_to_build] return local_layer_specs @@ -446,13 +465,14 @@ def get_qwen3_next_transformer_layer_spec(config, vp_stage=None): config.linear_conv_kernel_dim = args.linear_conv_kernel_dim layer_norm_impl = TENorm + kwargs = {'use_kitchen': config.use_kitchen} if mcore_013 else {} moe_layer_spec = get_gpt_layer_with_transformer_engine_spec( num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, - use_kitchen=config.use_kitchen, + **kwargs, ) layer_specs = [] for layer_type in args.layer_types: diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 5d69a10df6..b86c16e188 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -20,6 +20,8 @@ logger = get_logger() +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + # Some ideas for LoRA conversion are referenced from: https://github.com/modelscope/ms-swift/pull/6225 class GPTBridge: @@ -43,7 +45,7 @@ def __init__(self, disable_tqmd: bool = False): self._init_meta_hf_model() self.hf_layers = deep_getattr(self.hf_model, self.hf_layers_prefix) self.module_mapping = {} - self.megatron_core_014 = version.parse(megatron.core.__version__) >= version.parse('0.14.0rc0') + self.mcore_014 = version.parse(megatron.core.__version__) >= version.parse('0.14.0rc0') megatron_model_meta = get_megatron_model_meta(self.args.hf_model_type) if self.args.is_multimodal and megatron_model_meta.visual_cls is not None: self.module_mapping = megatron_model_meta.visual_cls.module_mapping @@ -81,7 +83,7 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: } if self.args.task_type == 'causal_lm': dim0_keys.add('output_layer') - if not self.megatron_core_014: + if not self.mcore_014: # https://github.com/NVIDIA/Megatron-LM/commit/720c8b40d8e7e2de1dd303d792f29093101c5e72 dim0_keys.update({'linear_q_down_proj', 'linear_kv_down_proj'}) # RowLinear @@ -971,7 +973,13 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd hf_state_dict = {} mg_models = iter(mg_models) mg_model = next(mg_models) - if not to_mcore or mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if mcore_013: + is_pp_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage) + is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage) + else: + is_pp_first_stage = mpu.is_pipeline_first_stage() + is_pp_last_stage = mpu.is_pipeline_last_stage() + if not to_mcore or is_pp_first_stage: hf_state_dict.update(self._convert_pre_process(mg_model, hf_state_dict, '', to_mcore)) if to_mcore: yield @@ -1010,7 +1018,7 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd else: yield from list(self._add_prefix(res, hf_prefix).items()) hf_state_dict = {} - if not to_mcore or mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=mg_model.vp_stage): + if not to_mcore or is_pp_last_stage: hf_state_dict.update(self._convert_post_process(mg_model, hf_state_dict, '', to_mcore)) if to_mcore: yield diff --git a/swift/megatron/model/gpt_model.py b/swift/megatron/model/gpt_model.py index 0aaa563277..b529a73337 100644 --- a/swift/megatron/model/gpt_model.py +++ b/swift/megatron/model/gpt_model.py @@ -3,6 +3,7 @@ from contextlib import contextmanager from typing import Any, Dict, Literal, Optional, Tuple +import megatron.core import torch from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -16,12 +17,15 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params from megatron.training import get_args +from packaging import version from swift.utils import get_logger from .rope import dynamic_rope_update, get_rope_inv_freq logger = get_logger() +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + class OutputLayerLinear(TELinear): @@ -77,6 +81,12 @@ def __init__( config.mscale_all_dim = hf_rope_scaling['mscale_all_dim'] config.rotary_scaling_factor = hf_rope_scaling['factor'] self.hf_rope_scaling = hf_rope_scaling + if mcore_013: + kwargs = {'vp_stage': vp_stage} + else: + self.vp_stage = vp_stage + assert vp_stage is None, 'megatron-core==0.12 does not support vp_stage' + kwargs = {} super().__init__( config, transformer_layer_spec, @@ -95,7 +105,7 @@ def __init__( scatter_embedding_sequence_parallel=scatter_embedding_sequence_parallel, seq_len_interpolation_factor=seq_len_interpolation_factor, mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, + **kwargs, ) if config.multi_latent_attention: self.rotary_pos_emb = RotaryEmbedding( @@ -293,25 +303,53 @@ def forward( ) args = get_args() - return self._postprocess( - hidden_states=hidden_states, - input_ids=input_ids, - position_ids=position_ids, - labels=labels if args.task_type == 'causal_lm' else None, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - mtp_in_postprocess=self.mtp_process, - loss_mask=loss_mask, - decoder_input=decoder_input, - attention_mask=attention_mask, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - runtime_gather_output=runtime_gather_output, - extra_block_kwargs=extra_block_kwargs, - inference_context=inference_context, - ) + labels = labels if args.task_type == 'causal_lm' else None + if mcore_013: + return self._postprocess( + hidden_states=hidden_states, + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + mtp_in_postprocess=self.mtp_process, + loss_mask=loss_mask, + decoder_input=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + runtime_gather_output=runtime_gather_output, + extra_block_kwargs=extra_block_kwargs, + inference_context=inference_context, + ) + else: + if not self.post_process: + return hidden_states + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + logits, _ = self.output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) + if has_config_logger_enabled(self.config): + payload = OrderedDict({ + 'input_ids': input_ids, + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'decoder_input': decoder_input, + 'logits': logits, + }) + log_config_to_disk(self.config, payload, prefix='input_and_logits') + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + loss = self.compute_language_model_loss(labels, logits) + + return loss def get_input_tensor(self): return self.decoder.input_tensor diff --git a/swift/megatron/model/model_provider.py b/swift/megatron/model/model_provider.py index 8edb17c21f..997f53a231 100644 --- a/swift/megatron/model/model_provider.py +++ b/swift/megatron/model/model_provider.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import TYPE_CHECKING, Optional, Union +import megatron.core import megatron.legacy import torch from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, get_gpt_layer_local_spec, @@ -11,6 +12,9 @@ from megatron.training import get_args, print_rank_0 from megatron.training.arguments import core_transformer_config_from_args from megatron.training.yaml_arguments import core_transformer_config_from_yaml +from packaging import version + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') if TYPE_CHECKING: from .gpt_model import GPTModel @@ -29,14 +33,17 @@ def _get_transformer_layer_spec(use_te, config): """ args = get_args() if use_te: + if mcore_013: + kwargs = {'qk_l2_norm': args.qk_l2_norm, 'use_kitchen': config.use_kitchen} + else: + kwargs = {} return get_gpt_layer_with_transformer_engine_spec( args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, - qk_l2_norm=args.qk_l2_norm, - use_kitchen=config.use_kitchen, + **kwargs, ) else: return get_gpt_layer_local_spec( @@ -110,13 +117,13 @@ def oom_observer(device, alloc, device_alloc, device_free): transformer_layer_spec = megatron_model_meta.get_transformer_layer_spec(config, vp_stage=vp_stage) else: if args.num_experts: + if mcore_013: + kwargs = {'qk_l2_norm': args.qk_l2_norm, 'vp_stage': vp_stage} + else: + kwargs = {} # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( - config, - use_transformer_engine=use_te, - normalization=args.normalization, - qk_l2_norm=args.qk_l2_norm, - vp_stage=vp_stage) + config, use_transformer_engine=use_te, normalization=args.normalization, **kwargs) elif args.heterogeneous_layers_config_path is not None: transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te) else: diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index c4d0dc2aba..164fe0ee0a 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -14,7 +14,7 @@ from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.num_microbatches_calculator import get_num_microbatches -from megatron.core.optimizer import _update_min_and_max_lr_in_param_groups, param_group_identifier_keys +from megatron.core.optimizer import _update_min_and_max_lr_in_param_groups from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import RerunMode, get_rerun_state_machine from megatron.core.transformer.module import MegatronModule @@ -40,6 +40,11 @@ from .utils import (get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_packed_seq_params, get_swift_datasets_provider) +try: + from megatron.core.optimizer import param_group_identifier_keys +except ImportError: + param_group_identifier_keys = None + logger = get_logger() @@ -64,7 +69,7 @@ def _get_mean_metric(): 'train': collections.defaultdict(_get_mean_metric), 'eval': collections.defaultdict(_get_mean_metric) } - self.megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + self.mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @property def bridge(self): @@ -363,7 +368,8 @@ def _get_param_groups( } # Ensure param_group has required keys for matching when loading optimizer state # See MegatronOptimizer._filter_and_reorder_param_groups. - assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} + if param_group_identifier_keys is not None: + assert set(param_group.keys()) - set(param_group_identifier_keys) == {'params'} param_groups.append(param_group) param_groups = _update_min_and_max_lr_in_param_groups( @@ -471,8 +477,7 @@ def _initialize_embedding(model): def _all_reduce_metric(self, metric: Dict[str, torch.Tensor], reduction=torch.distributed.ReduceOp.AVG) -> Dict[str, torch.Tensor]: - values = list(metric.values()) - reporting_metric = values[0].new_tensor(values) + reporting_metric = torch.stack(list(metric.values()), dim=0) torch.distributed.all_reduce(reporting_metric, reduction, group=mpu.get_data_parallel_group()) return {k: reporting_metric[i] for i, k in enumerate(metric.keys())} @@ -559,7 +564,7 @@ def evaluate( torch.cuda.empty_cache() if mpu.is_pipeline_last_stage(ignore_virtual=True): - if self.megatron_core_013: + if self.mcore_013: for key in loss_dicts[0].keys(): if key not in total_loss_dict: total_loss_dict[key] = torch.tensor([0.0, 0.0], dtype=torch.float).cuda() diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index d0a385aa41..f201767d3e 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -143,7 +143,11 @@ def forward_step(self, data_iterator, model): unwrapped_model.set_input_tensor(self._get_input_tensor(input_tensor, False, False, length, 0)) with self.stimer: output_tensor = model(**data) - dim = 1 if mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) else 0 + if self.mcore_013: + is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) + else: + is_pp_last_stage = mpu.is_pipeline_last_stage() + dim = 1 if is_pp_last_stage else 0 if self.args.calculate_KL: res = torch.concat([output_tensor, ref_output_tensor, KL_output_tensor, ref_KL_output_tensor], dim=dim) else: diff --git a/swift/megatron/trainers/reward_trainer.py b/swift/megatron/trainers/reward_trainer.py index 852f488ed2..08800826e7 100644 --- a/swift/megatron/trainers/reward_trainer.py +++ b/swift/megatron/trainers/reward_trainer.py @@ -16,6 +16,7 @@ class MegatronRewardTrainer(MegatronRLHFTrainer): def __init__(self, args, template): super().__init__(args, template) assert args.padding_free, 'Currently `rlhf_type="rm"` only supports padding_free.' + assert args.context_parallel_size == 1, 'Currently `rlhf_type="rm"` does not support context parallelism.' def loss_func(self, output_tensor, *, data): packed_seq_params = data.get('packed_seq_params') diff --git a/swift/megatron/trainers/trainer.py b/swift/megatron/trainers/trainer.py index 98422b8c43..0fc193fd21 100644 --- a/swift/megatron/trainers/trainer.py +++ b/swift/megatron/trainers/trainer.py @@ -76,7 +76,7 @@ def loss_func(self, loss = torch.cat([torch.sum(losses * loss_mask).view(1), loss_mask.sum().view(1)]) - if args.context_parallel_size > 1 and not self.megatron_core_013: + if args.context_parallel_size > 1 and not self.mcore_013: loss = all_reduce(loss, group=mpu.get_context_parallel_group()) # Check individual rank losses are not NaN prior to DP all-reduce. @@ -114,7 +114,7 @@ def loss_func(self, # Reduce loss for logging. reporting_loss = loss.detach().clone() lm_loss = loss[0] - if not self.megatron_core_013: + if not self.mcore_013: # fix megatron-lm bug # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index abfcfbd0cc..6879fe23bf 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,15 +1,19 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import Any, Dict +import megatron.core import torch from megatron.core import mpu from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank from megatron.training import get_args +from packaging import version from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + def get_swift_datasets_provider(train_dataset, val_dataset): @@ -37,9 +41,15 @@ def get_batch_on_this_tp_rank(data, vp_stage=None): batch = to_device(data, 'cuda', non_blocking=True) if args.pipeline_model_parallel_size == 1: return batch - if not mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage): + if mcore_013: + is_pp_first_stage = mpu.is_pipeline_first_stage(ignore_virtual=False, vp_stage=vp_stage) + is_pp_last_stage = mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage) + else: + is_pp_first_stage = mpu.is_pipeline_first_stage() + is_pp_last_stage = mpu.is_pipeline_last_stage() + if not is_pp_first_stage: batch['input_ids'] = None - if not mpu.is_pipeline_last_stage(ignore_virtual=False, vp_stage=vp_stage): + if not is_pp_last_stage: batch['labels'] = None batch['loss_scale'] = None diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index 2222a465be..815fa63d5c 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -29,7 +29,7 @@ from swift.utils import get_current_device from ..utils import tuners_sharded_state_dict -megatron_core_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') class LoraParallelLinear(MegatronModule, LoraLayer): @@ -99,7 +99,7 @@ def update_layer(self, adapter_name, r, *, lora_alpha, lora_dropout, init_lora_w 'config': self.config, 'is_expert': self.is_expert, } - if megatron_core_013: + if mcore_013: kwargs['tp_group'] = self.base_layer.tp_group if isinstance(self.base_layer, TopKRouter): router_shape = self.base_layer.weight.shape From d932941f7edc96cfa38f25e68cba6c874d7d0b62 Mon Sep 17 00:00:00 2001 From: Jintao Date: Fri, 14 Nov 2025 16:21:45 +0800 Subject: [PATCH 08/29] [kto] fix kto apo_zero_unpaired (#6601) --- swift/llm/template/base.py | 16 ++++++++++++---- swift/llm/train/kto.py | 21 +++++++++++---------- swift/megatron/trainers/kto_trainer.py | 8 ++++++-- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index fce20eb7d2..a402e56af6 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -357,11 +357,16 @@ def get_base_model(model): else: return model - def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: + def _rlhf_encode(self, inputs: TemplateInputs, check_rejected=True) -> Dict[str, Any]: chosen = inputs.chosen margin = chosen.margin chosen_encoded = self._encode_truncated(chosen) - rejected_encoded = self._encode_truncated(inputs.rejected) + if inputs.rejected is None: + if check_rejected: + raise ValueError('inputs.rejected is None') + rejected_encoded = {} + else: + rejected_encoded = self._encode_truncated(inputs.rejected) encoded = {} for prefix in ['chosen', 'rejected']: @@ -373,7 +378,7 @@ def _rlhf_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: return encoded def _kto_encode(self, inputs: TemplateInputs) -> Dict[str, Any]: - encoded = self._rlhf_encode(inputs) + encoded = self._rlhf_encode(inputs, check_rejected=False) encoded['label'] = bool(inputs.chosen.label) return encoded @@ -1485,7 +1490,10 @@ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona kl_batch = self._fetch_inputs_startswith(batch, 'rejected_') res = self._data_collator(new_batch, padding_to=padding_to) - kl_res = self._data_collator(kl_batch, padding_to=padding_to) + if any(kl_batch): + kl_res = self._data_collator(kl_batch, padding_to=padding_to) + else: + kl_res = {} res = { **{f'completion_{k}': v for k, v in res.items()}, diff --git a/swift/llm/train/kto.py b/swift/llm/train/kto.py index 966c11cb61..63da51013b 100644 --- a/swift/llm/train/kto.py +++ b/swift/llm/train/kto.py @@ -41,16 +41,17 @@ def _get_kl_dataset(dataset: Optional[HfDataset], def prepare_kto_dataset(args, train_dataset, val_dataset): - world_size = get_dist_setting()[2] - if hasattr(args, 'global_batch_size') and args.global_batch_size is not None: - total_batch_size = args.global_batch_size - else: - total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps) - if total_batch_size <= 1: - raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term ' - 'will be equivalent to the implied reward.') - train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed) - val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed) + if args.loss_type != 'apo_zero_unpaired': + world_size = get_dist_setting()[2] + if hasattr(args, 'global_batch_size') and args.global_batch_size is not None: + total_batch_size = args.global_batch_size + else: + total_batch_size = (world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps) + if total_batch_size <= 1: + raise ValueError('Batch size is 1 (too small). KTO will not work properly because the KL term ' + 'will be equivalent to the implied reward.') + train_dataset = _get_kl_dataset(train_dataset, total_batch_size, args.dataset_num_proc, args.data_seed) + val_dataset = _get_kl_dataset(val_dataset, total_batch_size, args.dataset_num_proc, args.data_seed) label = train_dataset['label'] num_desirable = max(sum(label), 1) diff --git a/swift/megatron/trainers/kto_trainer.py b/swift/megatron/trainers/kto_trainer.py index f201767d3e..9ddb8ab343 100644 --- a/swift/megatron/trainers/kto_trainer.py +++ b/swift/megatron/trainers/kto_trainer.py @@ -76,7 +76,7 @@ def loss_func(self, output_tensor, *, data, kl_data, label): loss = loss.mean() mean_metric = { 'loss': loss.detach().clone(), - 'kl': kl.detach(), + 'kl': kl.squeeze().detach(), } metric = self._all_reduce_metric(mean_metric) sum_metric = { @@ -159,7 +159,11 @@ def _prepare_batch(self, data, vp_stage): num_samples = data.pop('num_samples') for key in ['completion_', 'KL_completion_']: _data = {k[len(key):]: v for k, v in data.items() if k.startswith(key)} - res.append(super()._prepare_batch(_data, vp_stage, num_samples)) + if not self.args.calculate_KL and key == 'KL_completion_': + _data = {} + else: + _data = super()._prepare_batch(_data, vp_stage, num_samples) + res.append(_data) res[0]['label'] = data['label'] return res From 3785cb963f32dcb935ce34ad542a7080faef5cb5 Mon Sep 17 00:00:00 2001 From: slin000111 <127832064+slin000111@users.noreply.github.com> Date: Fri, 14 Nov 2025 16:49:42 +0800 Subject: [PATCH 09/29] Fix a bug in the command line display on the UI. (#6603) --- swift/ui/llm_train/llm_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/ui/llm_train/llm_train.py b/swift/ui/llm_train/llm_train.py index 092de2f3d5..4426fbebb9 100644 --- a/swift/ui/llm_train/llm_train.py +++ b/swift/ui/llm_train/llm_train.py @@ -473,8 +473,8 @@ def train(cls, *args): else: cuda_param = '' if envs: - envs = envs.split(' ') - for env in envs: + env_list = envs.split(' ') + for env in env_list: k, v = env.split('=') all_envs[k] = v log_file = os.path.join(sft_args.logging_dir, 'run.log') From 1149f7311e6a53377d278d7d2e25514438e9e58f Mon Sep 17 00:00:00 2001 From: jinghanhu Date: Fri, 14 Nov 2025 20:47:44 +0800 Subject: [PATCH 10/29] Support Megatron GRPO (#6025) --- README.md | 3 +- README_CN.md | 3 +- .../Instruction/Command-line-parameters.md | 14 +- .../Instruction/GRPO/AdvancedResearch/GSPO.md | 4 +- docs/source/Instruction/Use-tuners.md | 2 +- .../Megatron-SWIFT/Command-line-parameters.md | 95 +- docs/source/Megatron-SWIFT/GRPO.md | 61 + .../source/Megatron-SWIFT/Multimodal-Model.md | 2 +- docs/source/Megatron-SWIFT/Quick-start.md | 1 + docs/source/index.rst | 1 + .../Instruction/Command-line-parameters.md | 8 +- .../Instruction/GRPO/AdvancedResearch/GSPO.md | 4 +- docs/source_en/Instruction/Use-tuners.md | 2 +- .../Megatron-SWIFT/Command-line-parameters.md | 100 +- docs/source_en/Megatron-SWIFT/GRPO.md | 61 + .../Megatron-SWIFT/Multimodal-Model.md | 2 +- docs/source_en/Megatron-SWIFT/Quick-start.md | 3 +- docs/source_en/index.rst | 1 + examples/megatron/grpo/dense_colocate.sh | 65 + examples/megatron/grpo/dense_server.sh | 72 + examples/megatron/grpo/moe_colocate_full.sh | 55 + examples/megatron/grpo/moe_colocate_lora.sh | 53 + swift/llm/dataset/dataset/llm.py | 7 + swift/llm/template/base.py | 2 + swift/megatron/argument/megatron_args.py | 188 ++- swift/megatron/argument/rlhf_args.py | 2 +- swift/megatron/train/rlhf.py | 45 +- swift/megatron/trainers/__init__.py | 1 + swift/megatron/trainers/base.py | 10 +- swift/megatron/trainers/grpo_trainer.py | 1405 +++++++++++++++++ swift/megatron/trainers/rlhf_mixin.py | 69 +- swift/megatron/trainers/utils.py | 257 ++- swift/megatron/tuners/lora.py | 1 + swift/trainers/arguments.py | 2 +- swift/trainers/rlhf_trainer/__init__.py | 2 + swift/trainers/rlhf_trainer/grpo_trainer.py | 2 +- swift/trainers/rlhf_trainer/rollout_mixin.py | 1 + swift/trainers/rlhf_trainer/vllm_client.py | 7 +- 38 files changed, 2523 insertions(+), 90 deletions(-) create mode 100644 docs/source/Megatron-SWIFT/GRPO.md create mode 100644 docs/source_en/Megatron-SWIFT/GRPO.md create mode 100644 examples/megatron/grpo/dense_colocate.sh create mode 100644 examples/megatron/grpo/dense_server.sh create mode 100644 examples/megatron/grpo/moe_colocate_full.sh create mode 100644 examples/megatron/grpo/moe_colocate_lora.sh create mode 100644 swift/megatron/trainers/grpo_trainer.py diff --git a/README.md b/README.md index b263879aec..455ff142d3 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ You can contact us and communicate with us by adding our group: - **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ. - 🍊 **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, GKD, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models. - 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding. -- 🥥 **Megatron Parallelism**: Supports accelerating CPT/SFT/DPO/KTO/RM using Megatron parallelism techniques, currently compatible with 200+ pure text large models, 100+ multi-modal large models. +- 🥥 **Megatron Parallelism**: Supports accelerating CPT/SFT/GRPO/DPO/KTO/RM using Megatron parallelism techniques, currently compatible with 200+ pure text large models, 100+ multi-modal large models. - **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline. - **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer. - 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment. @@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group: ## 🎉 News +- 🎁 2025.11.14: Megatron GRPO is now available! Check out the [docs](./docs/source_en/Megatron-SWIFT/GRPO.md) and [examples](examples/megatron/grpo). - 🎁 2025.11.04: Support for [Mcore-Bridge](docs/source_en/Megatron-SWIFT/Mcore-Bridge.md), making Megatron training as simple and easy to use as transformers. - 🎁 2025.10.28: Ray [here](docs/source_en/Instruction/Ray.md). - 🎁 2025.10.28: Support [use yaml](examples/yaml) to configure command line parameters. diff --git a/README_CN.md b/README_CN.md index da2b914169..08a7f1b93d 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,7 +62,7 @@ - **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。 - 🍊 **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、GKD、KTO、CPO、SimPO、ORPO等人类对齐训练方法。 - 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。 -- 🥥 **Megatron并行技术**:支持使用Megatron并行技术对CPT/SFT/DPO/KTO/RM进行加速,现支持200+纯文本大模型和100+多模态大模型。 +- 🥥 **Megatron并行技术**:支持使用Megatron并行技术对CPT/SFT/GRPO/DPO/KTO/RM进行加速,现支持200+纯文本大模型和100+多模态大模型。 - **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。 - **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。 - 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。 @@ -71,6 +71,7 @@ - **模型量化**:支持AWQ、GPTQ、FP8和BNB的量化导出,导出的模型支持使用vLLM/SGLang/LmDeploy推理加速,并支持继续训练。 ## 🎉 新闻 +- 🎁 2025.11.14: Megatron GRPO现已支持!查看[文档](./docs/source/Megatron-SWIFT/GRPO.md)和[示例](examples/megatron/grpo)。 - 🎁 2025.11.04: 支持[Mcore-Bridge](docs/source/Megatron-SWIFT/Mcore-Bridge.md),使Megatron训练像transformers一样简单易用。 - 🎁 2025.10.28: Ray [已支持](docs/source/Instruction/Ray.md)。 - 🎁 2025.10.28: 已支持[使用yaml](examples/yaml)配置命令行参数。 diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 53e4d349c8..2a2f9db2d4 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -566,13 +566,13 @@ reward模型参数将在PPO、GRPO中使用。 - use_vllm: 是否使用 vLLM 作为 GRPO 生成的 infer_backend,默认为False。 - vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时, - vllm_mode server 参数 + - vllm_server_host: vLLM server host地址,默认为None。 + - vllm_server_port: vLLM server 服务端口,默认为8000。 - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。 - - vllm_server_host:vLLM server host地址,默认为None。 - - vllm_server_port vLLM server 服务端口,默认为8000。 - - vllm_server_timeout 连接vLLM server的超时时间,默认为 240s。 + - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。 - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`. - - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE:环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: 环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 - vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。) - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 - vllm_max_model_len: vllm透传参数,默认为None。 @@ -581,7 +581,7 @@ reward模型参数将在PPO、GRPO中使用。 - vllm_enable_prefix_caching: vllm透传参数,默认为True。 - vllm_tensor_parallel_size: tp并行数,默认为`1`。 - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](./GRPO/GetStarted/GRPO.md#权重同步加速)。 - - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放。 + - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1, 2], 默认为0,不释放。 - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 - completion_length_limit_scope: 在多轮对话中,`max_completion_length` 的限制范围。 @@ -593,7 +593,7 @@ reward模型参数将在PPO、GRPO中使用。 - max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 - overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 - delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 -- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://www.arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 - advantage_estimator: 优势计算函数,默认为 `grpo`,即计算组内相对优势,可选项为 `grpo`、[`rloo`](./GRPO/AdvancedResearch/RLOO.md)、[`reinforce_plus_plus`](./GRPO/AdvancedResearch/REINFORCEPP.md)。 - kl_in_reward: 控制 KL 散度正则项的处理位置;`false`表示作为损失函数的独立正则项,`true`表示将 KL 直接并入奖励(从奖励中扣除)。默认情况与advantage_estimator绑定,`grpo`下默认为`false`,`rloo` 和 `reinforce_plus_plus` 下默认为 `true`。 - scale_rewards:指定奖励的缩放策略。可选值包括 `group`(按组内标准差缩放)、`batch`(按整个批次的标准差缩放)、`none`(不进行缩放)。在 ms-swift < 3.10 版本中,该参数为布尔类型,`true` 对应 `group`,`false` 对应 `none`。默认值与 `advantage_estimator` 绑定:`grpo` 对应 `group`,`rloo` 对应 `none`,`reinforce_plus_plus` 对应 `batch`。 @@ -606,6 +606,8 @@ reward模型参数将在PPO、GRPO中使用。 - top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md) - log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) +##### 奖励函数参数 +内置的奖励函数参考[文档](./GRPO/DeveloperGuide/reward_function.md) cosine 奖励参数 - cosine_min_len_value_wrong:cosine 奖励函数参数,生成错误答案时,最小长度对应的奖励值。默认值为-0.5。 - cosine_max_len_value_wrong:生成错误答案时,最大长度对应的奖励值。默认值为0.0。 diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md index 1f21f2abfe..9bc9df2f80 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md +++ b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md @@ -2,7 +2,7 @@ **版本依赖**:ms-swift>=3.7 -[Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071)中指出GRPO在计算重要性采样权重时,是在token级别进行操作的。然而,这种做法由于每个token仅采样一次,无法实现有效的分布校正,反而会在模型训练过程中引入高方差噪声,极易导致模型的梯度估计不稳定,最终造成模型训练的崩塌。因此,论文认为,优化目标的单位应该与奖励的单位保持一致。由于奖励通常是在序列级别(即完整生成的回复)给出的,因此更合理的做法是将 off-policy 校正和优化也提升到序列级别,而非 token 级别。以下是三种计算策略对比: +[Group Sequence Policy Optimization](https://arxiv.org/abs/2507.18071)中指出GRPO在计算重要性采样权重时,是在token级别进行操作的。然而,这种做法由于每个token仅采样一次,无法实现有效的分布校正,反而会在模型训练过程中引入高方差噪声,极易导致模型的梯度估计不稳定,最终造成模型训练的崩塌。因此,论文认为,优化目标的单位应该与奖励的单位保持一致。由于奖励通常是在序列级别(即完整生成的回复)给出的,因此更合理的做法是将 off-policy 校正和优化也提升到序列级别,而非 token 级别。以下是三种计算策略对比: 1. GRPO 对每个 token 独立计算重要性采样比,具体公式为 @@ -54,7 +54,7 @@ importance_weights = torch.exp(log_importance_weights) - `importance_sampling_level sequence` (GSPO) - `importance_sampling_level sequence_token` (GSPO-token) -其中 sequence_token 要求 ms-swift > 3.7 (源码安装) +其中 sequence_token 要求 ms-swift >= 3.8 论文其他超参 ```bash diff --git a/docs/source/Instruction/Use-tuners.md b/docs/source/Instruction/Use-tuners.md index c84ca6fe0c..7461877fc8 100644 --- a/docs/source/Instruction/Use-tuners.md +++ b/docs/source/Instruction/Use-tuners.md @@ -15,7 +15,7 @@ tuner是指附加在模型上的额外结构部分,用于减少训练参数量 - Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751) - Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119) - Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503) -- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | [Usage](ResTuning.md) > +- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) > - [PEFT](https://github.com/huggingface/peft)提供的tuners, 如AdaLoRA、DoRA、Fourierft等 ## 接口列表 diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index f551b9d906..5c75aa28c0 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -246,36 +246,13 @@ lora训练: - lora_bias: 默认为`'none'`,可以选择的值: 'none'、'all'。如果你要将bias全都设置为可训练,你可以设置为`'all'`。 - use_rslora: 默认为`False`,是否使用`RS-LoRA`。 - -**DPO参数**: -- ref_load: ref_model的加载路径。采用DPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 -- ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。 -- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。 -- 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。 - - **注意**:在"ms-swift<3.8",其默认值为`1.`。在"ms-swift>=3.8"该默认值修改为`None`。 -- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。 -- label_smoothing: 默认为0.。 -- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 -- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。 - -**KTO参数**: -- ref_load: 含义同DPO。 -- ref_adapter_load: 含义同DPO。 -- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。 -- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。 -- desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 -- undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 - -**RM参数**: -- center_rewards_coefficient: 用于激励奖励模型输出均值为零的奖励的系数,具体查看这篇[论文](https://huggingface.co/papers/2312.09244)。推荐值:0.01。 - **Mcore-Bridge参数** - 🔥load_safetensors: 默认为False,是否直接从safetensors加载权重。 - 🔥save_safetensors: 默认为False,是否直接保存成safetensors权重。注意,若该参数设置为True,则不会存储优化器权重、随机数状态等断点续训内容。 - model: safetensors权重的model_id或者model_path。默认为None。 - model_type: 模型类型。介绍参考[ms-swift命令行参数文档](../Instruction/Command-line-parameters.md)。 - adapters: safetensors格式的LoRA增量权重的adapter_id或者adapter_path。默认为`[]`。 -- ref_model: ref_model safetensors权重的model_id或者model_path。采用dpo、kto算法且使用全参数训练时需要传入。默认为None,设置为`--model`。 +- ref_model: ref_model safetensors权重的model_id或者model_path。采用grpo、dpo、kto算法且使用全参数训练时需要传入。默认为None,设置为`--model`。 - ref_adapters: ref_adapters safetensors权重的adapter_id或者adapter_path的列表(目前只支持长度为1),默认为`[]`。 - use_hf: 控制模型下载、数据集下载、模型推送使用ModelScope还是HuggingFace。默认为False,使用ModelScope。 - hub_token: hub token. modelscope的hub token可以查看[这里](https://modelscope.cn/my/myaccesstoken)。默认为None。 @@ -318,11 +295,79 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 ## RLHF参数 除了继承训练参数外,还支持以下参数: -- 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo'、'kto'和'rm'。 +- 🔥rlhf_type: 默认为'dpo'。目前可选择为'dpo'、'grpo'、'kto'和'rm'。 - loss_scale: 覆盖[基本参数](../Instruction/Command-line-parameters.md)中的loss_scale。默认为'last_round'。 - calculate_per_token_loss: 覆盖Megatron参数,默认为False。 +### DPO参数 +- ref_load: ref_model的加载路径。采用DPO/GRPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 +- ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。 +- beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。 +- 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。 + - **注意**:在"ms-swift<3.8",其默认值为`1.`。在"ms-swift>=3.8"该默认值修改为`None`。 +- reference_free: 是否忽略提供的参考模型,并隐式地使用一个对所有响应赋予相等概率的参考模型。默认为False。 +- label_smoothing: 默认为0.。 +- f_divergence_type: 默认为`reverse_kl`。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer)。 +- loss_type: 默认为'sigmoid'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions)。 + +### KTO参数 +- ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: 控制与 ref_model 偏离程度的参数。较高的 beta 表示与 ref_model 偏离更小。默认为`0.1`。 +- loss_type: 默认为'kto'。可选值参考[TRL文档](https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type)。 +- desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 +- undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 + +### RM参数 +- center_rewards_coefficient: 用于激励奖励模型输出均值为零的奖励的系数,具体查看这篇[论文](https://huggingface.co/papers/2312.09244)。推荐值:0.01。 + +### GRPO参数 +- ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 +- micro_batch_size: 每个device的批次大小,默认为1。 +- global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。 +- steps_per_generation:每轮生成的优化步数,即采样批量大小相对global_batch_size的倍数,默认为1。 +- generation_batch_size: 采样批量大小,需要是global_batch_size的倍数,默认等于global_batch_size*steps_per_generation。 +- num_generations: 每个prompt采样的数量,论文中的G值,默认为8。 +- reward_funcs: GRPO算法奖励函数,可选项为`accuracy`、`format`、`cosine`、`repetition`和`soft_overlong`,见swift/plugin/orm.py。你也可以在plugin中自定义自己的奖励函数。默认为`[]`。 +- reward_weights: 每个奖励函数的权重。必须与奖励函数和奖励模型的总数量匹配。默认为 None,即所有奖励的权重都相等,为`1.0`。 + - 提示:如果GRPO训练中包含`--reward_model`,则其加在奖励函数的最后位置。 +- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo'], 默认为'grpo', 具体查看该[pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)。 +- log_completions: 是否记录训练中的模型生成内容,默认为False。 +- vllm_mode: vLLM 集成模式,可选项为 `server` 和 `colocate`。server 模式使用 `swift rollout` 拉起的 vLLM 服务器进行采样,colocate 模式在程序内部署 vLLM。使用server端时, +- vllm_mode server 参数 + - vllm_server_host: vLLM server host地址,默认为None。 + - vllm_server_port: vLLM server 服务端口,默认为8000。 + - vllm_server_base_url: vLLM server的Base URL(比如 http://local_host:8000), 默认为None。设置后,忽略host和port设置。 + - vllm_server_timeout: 连接vLLM server的超时时间,默认为 240s。 + - vllm_server_pass_dataset: 透传额外的数据集信息到vLLM server,用于多轮训练。 + - async_generate: 异步rollout以提高训练速度,注意开启时采样会使用上一轮更新的模型进行采样,不支持多轮场景。默认`false`. + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: 环境变量,用于控制权重同步时的传输桶大小(bucket size),适用于 Server Mode 下的全参数训练,单位为 MB,默认值为 512 MB。 +- vllm_mode colocate 参数(更多参数支持参考[vLLM参数](#vLLM参数)。) + - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 + - vllm_max_model_len: vllm透传参数,默认为None。 + - vllm_enforce_eager: vllm透传参数,默认为False。 + - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。 + - vllm_enable_prefix_caching: vllm透传参数,默认为True。 + - vllm_tensor_parallel_size: tp并行数,默认为`1`。 + - vllm_enable_lora: 支持vLLM Engine 加载 LoRA adapter,默认为False。用于加速LoRA训练的权重同步,具体参考[文档](../Instruction/GRPO/GetStarted/GRPO.md#权重同步加速)。 + - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1, 2], 默认为0,不释放。 + - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 + - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 +- num_iterations: 每条数据的更新次数,[GRPO论文](https://arxiv.org/abs/2402.03300)中的 $\mu$ 值,默认为1。 +- epsilon: clip 系数,默认为0.2。 +- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 +- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 +- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 +- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 +- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 +- scale_rewards:指定奖励的缩放策略。可选值包括 `group`(按组内标准差缩放)、`batch`(按整个批次的标准差缩放)、`none`(不进行缩放)。在 ms-swift < 3.10 版本中,该参数为布尔类型,`true` 对应 `group`,`false` 对应 `none`。默认值与 `advantage_estimator` 绑定:`grpo` 对应 `group`,`rloo` 对应 `none`,`reinforce_plus_plus` 对应 `batch`。 + +内置奖励函数参数参考[文档](../Instruction/Command-line-parameters.md#奖励函数参数) + ## 导出参数 这里介绍`megatron export`的参数(需"ms-swift>=3.10"),若要使用`swift export`导出命令,请参考[ms-swift命令行参数文档](../Instruction/Command-line-parameters.md#导出参数)。`megatron export`相比`swift export`,支持分布式和多机导出。Megatron导出参数继承自Megatron参数和基本参数。 - 🔥to_mcore: HF格式权重转成Megatron格式。默认为False。 diff --git a/docs/source/Megatron-SWIFT/GRPO.md b/docs/source/Megatron-SWIFT/GRPO.md new file mode 100644 index 0000000000..a8aa4df0e4 --- /dev/null +++ b/docs/source/Megatron-SWIFT/GRPO.md @@ -0,0 +1,61 @@ +# GRPO + +**版本依赖**:ms-swift >= 3.11 + +如果你是首次使用 GRPO,请先参考 [GRPO文档](../Instruction/GRPO/GetStarted/GRPO.md)。 + +Megatron GRPO 当前已支持以下功能: + +- **训练模式**:全参数训练与 LoRA 微调 +- **并行策略**:支持上下文并行(CP)、流水线并行(PP)、张量并行(TP)和专家并行(EP) +- **推理加速**:支持 vLLM 的 colocate 模式和 server 模式 +- **模型支持**:兼容 Megatron Swift 中的 LLM 及 MLLM(多模态大模型) +- **算法支持**:涵盖 swift GRPO 的大部分功能 + +以下参数或功能将在后续版本中逐步支持: + +- **Entropy 相关配置**:如 `top_entropy_quantile`、`log_entropy` +- **Reward Model / Reward Model Plugin** +- **多轮 Rollout 调度机制**(`multi_turn_scheduler`):实现多轮对话策略优化 +- **优势估计器**(`advantage_estimator`):支持更复杂的策略梯度估计方法 +- **KL 散度计入奖励**(`kl_in_reward`) +- **虚拟流水线并行**(VPP) +- **参考模型同步更新**(`sync_ref_model`) +- **Async Generate** (`async_generate`) +- **num_iterations** +- **日志同步 SwanLab** + +⚠️ 注意:以下参数在 Megatron GRPO 中不生效: + +- **`use_vllm`**:Megatron GRPO 暂不支持使用 PTEngine 进行 Rollout 推理。 +- **`move_model_batches`**:该参数专用于 DeepSpeed ZeRO-3 优化,在 Megatron 架构下无效。 + +与 ms-swift GRPO 相同,Megatron GRPO batch size 相关的参数均以 **completion-level** 为单位,即表示模型生成的 completion 数量,而非 prompt 数量。 + +#### 参数对比 + +下表对比了 ms-swift 和 Megatron-SWIFT 中批量相关参数的对应关系: + +| ms-swift 参数 | Megatron-SWIFT 参数 | 说明 | +|---------------|---------------------|------| +| `per_device_train_batch_size` | `micro_batch_size` | 每张 GPU 的训练批次大小(completion-level) | +| `gradient_accumulation_steps` | - | 梯度累积步数,在 Megatron-SWIFT 中已包含在 `global_batch_size` 的计算中 | +| - | `global_batch_size` | 全局批次大小(completion-level)
**Megatron-SWIFT**: `micro_batch_size × dp_size × gradient_accumulation_steps`
**ms-swift**: `per_device_train_batch_size × world_size × gradient_accumulation_steps` | +| `num_generations` | `num_generations` | 每个 prompt 生成的 completion 数量 | +| `steps_per_generation` | `steps_per_generation` | Rollout 批次大小相对于训练批次大小的倍数
**注意**:在 ms-swift 中需为 `gradient_accumulation_steps` 的整数倍 | +| `generation_batch_size` | `generation_batch_size` | Rollout 阶段的批次大小(completion-level),需为 `global_batch_size` 的整数倍 | + +以下公式用于计算 Megatron GRPO 中的批量: + +- **数据并行大小**:`dp_size = world_size / (TP × PP × CP)` +- **全局批次大小**:`global_batch_size = micro_batch_size × dp_size × gradient_accumulation_steps` +- **生成批次大小**:`generation_batch_size = global_batch_size × steps_per_generation` +- **Rollout Prompt 数量**:`num_rollout_prompts = generation_batch_size / num_generations` +- **训练 Prompt 数量**:`num_train_prompts = global_batch_size / num_generations` +- **每个 DP group 的训练 Prompt 数量**:`num_prompts_per_dp_group = global_batch_size / num_generations / dp_size` + +注意:在 Megatron GRPO 中,每个 DP group 的训练 Prompt 数量须满足 `num_prompts_per_dp_group` 是 `micro_batch_size`的整数倍,以确保训练批次能够正确分配。 + +更多参数请参考[命令行文档](./Command-line-parameters.md#grpo参数) + +训练脚本请参考[Megatron GRPO 脚本](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo) diff --git a/docs/source/Megatron-SWIFT/Multimodal-Model.md b/docs/source/Megatron-SWIFT/Multimodal-Model.md index 9cc51732f7..8f51213211 100644 --- a/docs/source/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source/Megatron-SWIFT/Multimodal-Model.md @@ -1,6 +1,6 @@ # 多模态模型 -ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/DPO/KTO/RM。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/Supported-models-and-datasets.md)。 +ms-swift引入了Megatron的并行技术来加速多模态大模型的训练。目前支持Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL等模型的CPT/SFT/GRPO/DPO/KTO/RM。完整支持的模型可以参考[支持的模型与数据集文档](../Instruction/Supported-models-and-datasets.md)。 环境准备请参考Megatron-SWIFT的[快速开始文档](./Quick-start.md)。 diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md index 8c92e2b6b9..faff26ecec 100644 --- a/docs/source/Megatron-SWIFT/Quick-start.md +++ b/docs/source/Megatron-SWIFT/Quick-start.md @@ -8,6 +8,7 @@ ms-swift引入了Megatron的并行技术来加速大模型的训练,包括数 | ------ | ------ | ---- | ----- | ----- | | 预训练| ✅ | ✅| ✅ | ✅ | | 指令监督微调 | ✅ | ✅| ✅ | ✅ | +| GRPO | ✅ | ✅| ✅ | ✅ | | DPO | ✅ | ✅| ✅ | ✅ | | KTO | ✅ | ✅| ✅ | ✅ | | RM | ✅ | ✅| ✅ | ✅ | diff --git a/docs/source/index.rst b/docs/source/index.rst index c5a5fc08c8..f70a8a05c9 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -42,6 +42,7 @@ Swift DOCUMENTATION Megatron-SWIFT/LoRA-Training.md Megatron-SWIFT/Multimodal-Model.md Megatron-SWIFT/Mcore-Bridge.md + Megatron-SWIFT/GRPO.md .. toctree:: diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 74634419e3..c7dd3d865a 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -577,9 +577,9 @@ The meanings of the following parameters can be referenced [here](https://huggin - use_vllm: Whether to use vLLM as the infer_backend for GRPO generation, default is False. - vllm_mode: Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or `colocate` - vllm_mode server parameter - - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " "and `vllm_server_port` are ignored. Default is None. - vllm_server_host: The host address of the vLLM server. Default is None. - vllm_server_port: The service port of the vLLM server. Default is 8000. + - vllm_server_base_url: Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " "and `vllm_server_port` are ignored. Default is None. - vllm_server_timeout: The connection timeout for the vLLM server. Default is 240 seconds. - vllm_server_pass_dataset: pass additional dataset information through to the vLLM server for multi-turn training. - async_generate: Use async rollout to improve train speed. Note that rollout will use the model updated in the previous round when enabled. Multi-turn scenarios are not supported. Default is `false`. @@ -592,7 +592,7 @@ The meanings of the following parameters can be referenced [here](https://huggin - vllm_enable_prefix_caching: A pass-through parameter for vLLM, default is True. - vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1. - vllm_enable_lora: Enable the vLLM engine to load LoRA adapters; defaults to False. Used to accelerate weight synchronization during LoRA training. See the [documentation](./GRPO/GetStarted/GRPO.md#weight-sync-acceleration) for details. - - sleep_level: make vllm sleep when model is training. Options are 0 or 1, default is 0, no sleep + - sleep_level: make vllm sleep when model is training. Options are 0/1/2, default is 0, no sleep - offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`. - offload_model: Whether to offload the model during inference with vLLM. The default is `False`. - completion_length_limit_scope: Specifies the scope of the `max_completion_length` limit in multi-turn conversations. @@ -607,7 +607,7 @@ The meanings of the following parameters can be referenced [here](https://huggin - overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False. The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions). - delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). -- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. +- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. - advantage_estimator: Advantage estimator. Default is `grpo` (group-relative advantage). Options: `grpo`, [`rloo`](./GRPO/AdvancedResearch/RLOO.md), [`reinforce_plus_plus`](./GRPO/AdvancedResearch/REINFORCEPP.md). - kl_in_reward: Controls where the KL regularization is applied. `false`: KL is a separate loss term. `true`: KL is subtracted from the reward. The default is bound to `advantage_estimator`: `false` for `grpo`, and `true` for `rloo` and `reinforce_plus_plus`. - scale_rewards: Specifies the reward scaling strategy. Options: `group` (scale by intra-group std), `batch` (scale by batch-wide std), `none` (no scaling). In ms-swift < 3.10, this was a boolean where `true` corresponds to `group` and `false` to `none`. The default is bound to `advantage_estimator`: `group` for `grpo`, `none` for `rloo`, and `batch` for `reinforce_plus_plus`. @@ -621,6 +621,8 @@ The hyperparameters for the reward function can be found in the [Built-in Reward - log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics). +##### Reward function parameters +Refer to the [documentation](./GRPO/DeveloperGuide/reward_function.md) for built-in reward functions. cosine reward function arguments - cosine_min_len_value_wrong (default: -0.5): Reward value corresponding to the minimum length when the answer is incorrect. diff --git a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md index 2f8ec7ae54..03c67b3c6e 100644 --- a/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md +++ b/docs/source_en/Instruction/GRPO/AdvancedResearch/GSPO.md @@ -1,8 +1,8 @@ # Group Sequence Policy Optimization -**Version Requirement**: ms-swift>=3.7 +**Version Requirement**: ms-swift>=3.8 -In [Group Sequence Policy Optimization](https://www.arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level. +In [Group Sequence Policy Optimization](https://arxiv.org/abs/2507.18071), it is pointed out that GRPO computes importance sampling weights at the token level. However, this approach is problematic: since each token is only sampled once, it cannot realize effective distribution correction, and instead introduces high-variance noise during training, which can easily lead to unstable gradient estimates and even training collapse. Therefore, the paper argues that the unit of the objective function should be consistent with that of the reward. Since the reward is typically given at the sequence level (i.e., for the entire generated response), it is more reasonable to perform off-policy correction and optimization at the sequence level rather than the token level. Below are the three main strategies for computing importance sampling weights: diff --git a/docs/source_en/Instruction/Use-tuners.md b/docs/source_en/Instruction/Use-tuners.md index f960591893..d1b4f2cb1d 100644 --- a/docs/source_en/Instruction/Use-tuners.md +++ b/docs/source_en/Instruction/Use-tuners.md @@ -15,7 +15,7 @@ Tuners refer to additional structural components attached to a model, aimed at r - Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751) - Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119) - Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503) -- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | [Usage](ResTuning.md) > +- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) > - Tuners provided by [PEFT](https://github.com/huggingface/peft), such as AdaLoRA, DoRA, Fourierft, etc. ## Interface List diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 8e0ef3085a..446916c1f7 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -262,28 +262,6 @@ LoRA Training: - lora_bias: Default is `'none'`. Available options: `'none'`, `'all'`. If you want all biases to be set as trainable, set this to `'all'`. - use_rslora: Default is `False`. Whether to use `RS-LoRA`. -**DPO Parameters** -- ref_load: The loading path for the reference model. This must be provided when using DPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. -- ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`. -- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1. -- 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default. - - **Note**: In "ms-swift<3.8", the default value was `1.`. Starting from "ms-swift>=3.8", the default has been changed to `None`. -- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`. -- label_smoothing: Default is 0. -- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. -- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values. - -**KTO Parameters**: -- ref_load: same meaning as in DPO. -- ref_adapter_load: same meaning as in DPO. -- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`. -- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type. -- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. -- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. - -**RM Parameters**: -- center_rewards_coefficient: A coefficient used in reward model (RM) training to incentivize the model to output rewards with zero mean. See this [paper](https://huggingface.co/papers/2312.09244) for details. Recommended value: 0.01. - **Mcore-Bridge Parameters** - 🔥load_safetensors: Defaults to False. Whether to load weights directly from safetensors. @@ -291,7 +269,7 @@ LoRA Training: - model: The model_id or model_path of safetensors weights. Defaults to None. - model_type: Model type. For details, refer to [ms-swift command-line parameters documentation](../Instruction/Command-line-parameters.md). - adapters: adapter_id or adapter_path of LoRA incremental weights in safetensors format. Default is `[]`. -- ref_model: model_id or model_path of ref_model safetensors weights. Required when using DPO or KTO algorithms with full-parameter training. Default is None, set to `--model`. +- ref_model: model_id or model_path of ref_model safetensors weights. Required when using DPO/GRPO/KTO algorithms with full-parameter training. Default is None, set to `--model`. - ref_adapters: List of adapter_id or adapter_path of ref_adapters safetensors weights (currently only supports length of 1). Default is `[]`. - use_hf: Controls whether to use ModelScope or HuggingFace for model download, dataset download, and model push. Default is False, using ModelScope. - hub_token: Hub token. ModelScope hub token can be found [here](https://modelscope.cn/my/myaccesstoken). Default is None. @@ -313,7 +291,7 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - Typically used together with `--freeze_vit false` and `--freeze_aligner false`. - aligner_lr: Specifies the learning rate for the aligner module in multimodal models. Default is `None`, same as `learning_rate`. - gradient_checkpointing_kwargs: Arguments passed to `torch.utils.checkpoint`. For example: set `--gradient_checkpointing_kwargs '{"use_reentrant": false}'`. Defaults to `None`. This parameter only takes effect when `vit_gradient_checkpointing` is enabled. -- 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, DPO, KTO and RM. +- 🔥packing: Whether to use sequence packing to improve computational efficiency (achieving better load balancing across nodes and processes, and higher GPU utilization), at the cost of additional preprocessing time, while also stabilizing GPU memory usage. Defaults to `False`. Currently supported for CPT, SFT, GRPO, DPO, KTO and RM. - Note: **Sequences within the same batch remain mutually invisible**, except for Qwen3-Next. - Note: **Packing will reduce the number of dataset samples. Please adjust global_batch_size and learning rate accordingly**. - packing_length: the length to use for packing. Defaults to None, in which case it is set to max_length. @@ -337,11 +315,83 @@ Megatron training parameters are inherited from Megatron parameters and basic pa In addition to inheriting the training parameters, the following parameters are also supported: -- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'kto', and 'rm' are available. +- 🔥rlhf_type: Default is 'dpo'. Currently, 'dpo', 'grpo', 'kto', and 'rm' are available. - loss_scale: Overrides the `loss_scale` in [basic parameters](../Instruction/Command-line-parameters.md). Default is 'last_round'. - calculate_per_token_loss: Overrides the Megatron parameter. Default is False. +### DPO Parameters + +- ref_load: The loading path for the reference model. This must be provided when using DPO/GRPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. +- ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`. +- beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1. +- 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default. + - **Note**: In "ms-swift<3.8", the default value was `1.`. Starting from "ms-swift>=3.8", the default has been changed to `None`. +- reference_free: Whether to ignore the provided reference model and implicitly use a reference model that assigns equal probability to all responses. Default is `False`. +- label_smoothing: Default is 0. +- f_divergence_type: Default is `reverse_kl`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer) for possible values. +- loss_type: Default is `'sigmoid'`. See the [TRL documentation](https://huggingface.co/docs/trl/main/en/dpo_trainer#loss-functions) for possible values. + +### KTO Parameters + +- ref_load: same meaning as in DPO. +- ref_adapter_load: same meaning as in DPO. +- beta: parameter controlling the deviation from the ref_model. Higher `beta` means less deviation from the ref_model. Default is `0.1`. +- loss_type: default is `'kto'`. See possible values in the TRL docs: https://huggingface.co/docs/trl/main/en/kto_trainer#trl.KTOConfig.loss_type. +- desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. +- undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. + +### RM Parameters + +- center_rewards_coefficient: A coefficient used in reward model (RM) training to incentivize the model to output rewards with zero mean. See this [paper](https://huggingface.co/papers/2312.09244) for details. Recommended value: 0.01. + +### GRPO Parameters + +- ref_load: Same meaning as in DPO. +- ref_adapter_load: Same meaning as in DPO. +- beta: KL regularization coefficient, default is 0.04. When set to 0, the ref model is not loaded. +- micro_batch_size: Batch size per device, default is 1. +- global_batch_size: Total batch size, equivalent to `micro_batch_size * data parallel size * gradient accumulation steps`. Default is 16. +- steps_per_generation: Number of optimization steps per generation round, i.e., the ratio of sampling batch size to global_batch_size. Default is 1. +- generation_batch_size: Sampling batch size, must be a multiple of global_batch_size. Default equals global_batch_size * steps_per_generation. +- num_generations: Number of samples per prompt, the G value in the paper, default is 8. +- reward_funcs: GRPO algorithm reward functions. Options include `accuracy`, `format`, `cosine`, `repetition`, and `soft_overlong`. See swift/plugin/orm.py. You can also customize your own reward functions in the plugin. Default is `[]`. +- reward_weights: Weights for each reward function. Must match the total number of reward functions and reward models. Default is None, meaning all rewards have equal weights of `1.0`. + - Tip: If GRPO training includes `--reward_model`, it is added at the end of the reward functions. +- loss_type: Loss normalization type. Options are `['grpo', 'bnpo', 'dr_grpo']`. Default is `'grpo'`. See this [PR](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348) for details. +- log_completions: Whether to log model-generated content during training. Default is False. +- vllm_mode: vLLM integration mode. Options are `server` and `colocate`. Server mode uses the vLLM server launched by `swift rollout` for sampling, while colocate mode deploys vLLM within the program. When using server mode: +- vllm_mode server parameters: + - vllm_server_host: vLLM server host address. Default is None. + - vllm_server_port: vLLM server port. Default is 8000. + - vllm_server_base_url: Base URL of the vLLM server (e.g., http://local_host:8000). Default is None. When set, host and port settings are ignored. + - vllm_server_timeout: Timeout for connecting to the vLLM server. Default is 240s. + - vllm_server_pass_dataset: Pass additional dataset information to the vLLM server for multi-round training. + - async_generate: Asynchronous rollout to improve training speed. Note: When enabled, sampling uses the model from the previous round update, and multi-round scenarios are not supported. Default is `false`. + - SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE: Environment variable for controlling the bucket size during weight synchronization. Applicable to full-parameter training in Server Mode. Unit is MB, default value is 512 MB. +- vllm_mode colocate parameters (for more parameter support, refer to [vLLM parameters](#vllm-parameters)): + - vllm_gpu_memory_utilization: vLLM passthrough parameter. Default is 0.9. + - vllm_max_model_len: vLLM passthrough parameter. Default is None. + - vllm_enforce_eager: vLLM passthrough parameter. Default is False. + - vllm_limit_mm_per_prompt: vLLM passthrough parameter. Default is None. + - vllm_enable_prefix_caching: vLLM passthrough parameter. Default is True. + - vllm_tensor_parallel_size: Tensor parallel size. Default is `1`. + - vllm_enable_lora: Support loading LoRA adapters in the vLLM Engine. Default is False. Used to accelerate weight synchronization in LoRA training. See [documentation](../Instruction/GRPO/GetStarted/GRPO.md#weight-synchronization-acceleration) for details. + - sleep_level: Release vLLM GPU memory during training. Options are `[0, 1, 2]`. Default is 0, meaning no release. + - offload_optimizer: Whether to offload optimizer parameters during vLLM inference. Default is False. + - offload_model: Whether to offload the model during vLLM inference. Default is False. +- num_iterations: Number of updates per data sample, the $\mu$ value in the [GRPO paper](https://arxiv.org/abs/2402.03300). Default is 1. +- epsilon: Clip coefficient. Default is 0.2. +- epsilon_high: Upper clip coefficient. Default is None. When set, together with epsilon, forms the clipping range `[epsilon, epsilon_high]`. +- dynamic_sample: Filter out data with zero reward standard deviation within groups and sample additional new data. Default is False. +- max_resample_times: Limit the number of resampling times under dynamic_sample setting. Default is 3. +- overlong_filter: Skip overlong truncated samples, which do not participate in loss calculation. Default is False. +- delta: Bilateral GRPO upper bound clipping value from the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). If set, it is recommended to be greater than 1 + epsilon. Default is None. +- importance_sampling_level: Controls importance sampling ratio calculation. Options are `token` and `sequence`. In `token` mode, the original log probability ratio for each token is preserved. In `sequence` mode, the log probability ratios of all valid tokens in the sequence are averaged. The [GSPO paper](https://arxiv.org/abs/2507.18071) uses sequence-level calculation to stabilize training. Default is `token`. +- scale_rewards: Specifies the reward scaling strategy. Options include `group` (scale by within-group standard deviation), `batch` (scale by batch-wide standard deviation), and `none` (no scaling). In ms-swift < 3.10, this parameter is boolean, where `true` corresponds to `group` and `false` corresponds to `none`. The default value is bound to `advantage_estimator`: `grpo` corresponds to `group`, `rloo` corresponds to `none`, and `reinforce_plus_plus` corresponds to `batch`. + +Built-in reward function parameters refer to the [documentation](../Instruction/Command-line-parameters.md#reward-function-parameters). + ## Export Parameters This section introduces the parameters for `megatron export` (requires "ms-swift>=3.10"). To use the `swift export` command for exporting, please refer to the [ms-swift Command Line Parameters Documentation](../Instruction/Command-line-parameters.md#export-arguments). Compared to `swift export`, `megatron export` supports distributed and multi-node exporting. Megatron export parameters inherit from Megatron parameters and basic parameters. diff --git a/docs/source_en/Megatron-SWIFT/GRPO.md b/docs/source_en/Megatron-SWIFT/GRPO.md new file mode 100644 index 0000000000..3fa9dfb58d --- /dev/null +++ b/docs/source_en/Megatron-SWIFT/GRPO.md @@ -0,0 +1,61 @@ +# Megatron GRPO + +**Version Requirement**: ms-swift >= 3.11 + +If you are new to GRPO, please refer to the [GRPO documentation](../Instruction/GRPO/GetStarted/GRPO.md) first. + +Megatron GRPO currently supports the following features: + +- **Training Modes**: Full parameter training and LoRA fine-tuning +- **Parallelism Strategies**: Context Parallelism (CP), Pipeline Parallelism (PP), Tensor Parallelism (TP), and Expert Parallelism (EP) +- **Inference Acceleration**: vLLM colocate mode and server mode +- **Model Support**: Compatible with LLMs and MLLMs (multimodal large models) in Megatron Swift +- **Algorithm Support**: Covers most features of Swift GRPO + +The following parameters or features will be gradually supported in future versions: + +- **Entropy-related Configuration**: e.g., `top_entropy_quantile`, `log_entropy` +- **Reward Model / Reward Model Plugin** +- **Multi-turn Rollout Scheduling** (`multi_turn_scheduler`): Multi-turn conversation policy optimization +- **Advantage Estimator** (`advantage_estimator`): Support for more complex policy gradient estimation methods +- **KL Divergence in Reward** (`kl_in_reward`) +- **Virtual Pipeline Parallelism** (VPP) +- **Reference Model Synchronization** (`sync_ref_model`) +- **Async Generate** (`async_generate`) +- **num_iterations** +- **SwanLab Logging Integration** + +⚠️ **Note**: The following parameters are not effective in Megatron GRPO: + +- **`use_vllm`**: Megatron GRPO does not support using PTEngine for Rollout inference. +- **`move_model_batches`**: This parameter is specific to DeepSpeed ZeRO-3 optimization and is invalid in the Megatron architecture. + +Similar to ms-swift GRPO, all batch size-related parameters in Megatron GRPO are at the **completion-level**, meaning they represent the number of completions generated by the model, not the number of prompts. + +#### Parameter Comparison + +The following table compares the batch-related parameters between ms-swift and Megatron-SWIFT: + +| ms-swift Parameter | Megatron-SWIFT Parameter | Description | +|-------------------|--------------------------|-------------| +| `per_device_train_batch_size` | `micro_batch_size` | Training batch size per GPU (completion-level) | +| `gradient_accumulation_steps` | - | Gradient accumulation steps, already included in `global_batch_size` calculation in Megatron-SWIFT | +| - | `global_batch_size` | Global batch size (completion-level)
**Megatron-SWIFT**: `micro_batch_size × dp_size × gradient_accumulation_steps`
**ms-swift**: `per_device_train_batch_size × world_size × gradient_accumulation_steps` | +| `num_generations` | `num_generations` | Number of completions generated per prompt | +| `steps_per_generation` | `steps_per_generation` | Ratio of Rollout batch size to training batch size
**Note**: In ms-swift, must be an integer multiple of `gradient_accumulation_steps` | +| `generation_batch_size` | `generation_batch_size` | Batch size during Rollout phase (completion-level), must be an integer multiple of `global_batch_size` | + +The following formulas are used to calculate batch sizes in Megatron GRPO: + +- **Data Parallel Size**: `dp_size = world_size / (TP × PP × CP)` +- **Global Batch Size**: `global_batch_size = micro_batch_size × dp_size × gradient_accumulation_steps` +- **Generation Batch Size**: `generation_batch_size = global_batch_size × steps_per_generation` +- **Rollout Prompt Count**: `num_rollout_prompts = generation_batch_size / num_generations` +- **Training Prompt Count**: `num_train_prompts = global_batch_size / num_generations` +- **Training Prompt Count per DP Group**: `num_prompts_per_dp_group = global_batch_size / num_generations / dp_size` + +**Note**: In Megatron GRPO, the training prompt count per DP group must satisfy that `num_prompts_per_dp_group` is an integer multiple of `micro_batch_size` to ensure proper batch allocation during training. + +For more parameters, please refer to the [Command-line Parameters documentation](./Command-line-parameters.md#grpo-parameters). + +For training scripts, please refer to [Megatron GRPO Scripts](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/grpo). diff --git a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md index 9f339cc547..d3d96dde1f 100644 --- a/docs/source_en/Megatron-SWIFT/Multimodal-Model.md +++ b/docs/source_en/Megatron-SWIFT/Multimodal-Model.md @@ -1,6 +1,6 @@ # Multimodal Models -ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/DPO/KTO/RM for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). +ms-swift introduces Megatron's parallelization techniques to accelerate the training of large multimodal models. Currently, it supports CPT/SFT/GRPO/DPO/KTO/RM for models such as Qwen3-VL, Qwen3-Omni, Qwen2.5-VL, Qwen2.5-Omni, InternVL3.5, GLM4.5v, Kimi-VL. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). For environment setup, please refer to the Megatron-SWIFT [Quick Start guide](./Quick-start.md). diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index ed46f0471f..94123c8c4e 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -7,9 +7,10 @@ ms-swift incorporates Megatron's parallelization techniques to accelerate the tr | ---------------------------------- | -------------- | ---- | ---- | ---------- | | Pretraining | ✅ | ✅ | ✅ | ✅ | | Instruction-supervised fine-tuning | ✅ | ✅ | ✅ | ✅ | +| GRPO | ✅ | ✅ | ✅ | ✅ | | DPO | ✅ | ✅ | ✅ | ✅ | | KTO | ✅ | ✅ | ✅ | ✅ | -| RM | ✅ | ✅ | ✅ | ✅ | +| RM | ✅ | ✅ | ✅ | ✅ | | Classification tasks | ✅ | ✅ | ✅ | ✅ | ## Environment Setup diff --git a/docs/source_en/index.rst b/docs/source_en/index.rst index c5a5fc08c8..f70a8a05c9 100644 --- a/docs/source_en/index.rst +++ b/docs/source_en/index.rst @@ -42,6 +42,7 @@ Swift DOCUMENTATION Megatron-SWIFT/LoRA-Training.md Megatron-SWIFT/Multimodal-Model.md Megatron-SWIFT/Mcore-Bridge.md + Megatron-SWIFT/GRPO.md .. toctree:: diff --git a/examples/megatron/grpo/dense_colocate.sh b/examples/megatron/grpo/dense_colocate.sh new file mode 100644 index 0000000000..4cbd7cafbb --- /dev/null +++ b/examples/megatron/grpo/dense_colocate.sh @@ -0,0 +1,65 @@ +# DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size) +# = 8 // (1 * 1 * 1) = 8 + +# NOTE: global_batch_size and micro_batch_size are completion-level +# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (128) +# generation_batch_size = global_batch_size * steps_per_generation (128 * 4 = 512) +# num_of_prompt_to_rollout = generation_batch_size / num_generations (512 / 8 = 64) +# num_of_prompt_to_train = generation_batch_size / num_generations (128 / 8 = 16) + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +MAX_PIXELS=602112 \ +MASTER_PORT=29600 \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 1 \ + --tensor_model_parallel_size 1 \ + --pipeline_model_parallel_size 1 \ + --dataset AI-ModelScope/clevr_cogen_a_train#10000 \ + --max_epochs 1 \ + --global_batch_size 128 \ + --micro_batch_size 4 \ + --steps_per_generation 4 \ + --num_generations 8 \ + --external_plugins examples/train/grpo/plugin/plugin.py \ + --reward_funcs external_r1v_acc format \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.7 \ + --vllm_max_model_len 10240 \ + --max_length 8192 \ + --max_completion_length 2048 \ + --train_type full \ + --lr 1e-6 \ + --bf16 true \ + --beta 0.001 \ + --importance_sampling_level token \ + --epsilon 0.2 \ + --epsilon_high 0.2 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --log_interval 1 \ + --recompute_granularity selective \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --system examples/train/grpo/prompt.txt \ + --padding_free true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name megatron_grpo \ + --train_iters 100 \ + --eval_interval 1000 \ + --save_interval 1000 diff --git a/examples/megatron/grpo/dense_server.sh b/examples/megatron/grpo/dense_server.sh new file mode 100644 index 0000000000..ee702800e2 --- /dev/null +++ b/examples/megatron/grpo/dense_server.sh @@ -0,0 +1,72 @@ +# MAX_PIXELS=602112 \ +# CUDA_VISIBLE_DEVICES=6,7 \ +# swift rollout \ +# --model Qwen/Qwen2.5-VL-3B-Instruct \ +# --vllm_data_parallel_size 2 \ +# --vllm_max_model_len 10240 + +# DP size = world_size // (context_parallel_size * tensor_model_parallel_size * pipeline_model_parallel_size) +# = 6 // (1 * 1 * 1) = 6 + +# NOTE: global_batch_size and micro_batch_size are completion-level +# global_batch_size = micro_batch_size * DP size * gradient_accumulation_steps (96) +# generation_batch_size = global_batch_size * steps_per_generation (96 * 4 = 384) +# num_of_prompt_to_rollout = generation_batch_size / num_generations (384 / 8 = 48) +# num_of_prompt_to_train = generation_batch_size / num_generations (96 / 8 = 12) + +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 \ +NPROC_PER_NODE=6 \ +MAX_PIXELS=602112 \ +MASTER_PORT=29600 \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 1 \ + --tensor_model_parallel_size 1 \ + --pipeline_model_parallel_size 1 \ + --dataset AI-ModelScope/clevr_cogen_a_train#10000 \ + --max_epochs 1 \ + --global_batch_size 96 \ + --micro_batch_size 4 \ + --steps_per_generation 4 \ + --num_generations 8 \ + --external_plugins examples/train/grpo/plugin/plugin.py \ + --reward_funcs external_r1v_acc format \ + --use_vllm true \ + --vllm_mode server \ + --vllm_server_host 127.0.0.1 \ + --vllm_server_port 8000 \ + --max_length 8192 \ + --max_completion_length 2048 \ + --train_type full \ + --lr 1e-6 \ + --bf16 true \ + --beta 0.001 \ + --importance_sampling_level token \ + --epsilon 0.2 \ + --epsilon_high 0.2 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --log_interval 1 \ + --recompute_granularity selective \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --system examples/train/grpo/prompt.txt \ + --padding_free true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name megatron_grpo \ + --train_iters 100 \ + --eval_interval 1000 \ + --save_interval 1000 diff --git a/examples/megatron/grpo/moe_colocate_full.sh b/examples/megatron/grpo/moe_colocate_full.sh new file mode 100644 index 0000000000..7b66688fd9 --- /dev/null +++ b/examples/megatron/grpo/moe_colocate_full.sh @@ -0,0 +1,55 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 1 \ + --tensor_model_parallel_size 4 \ + --expert_model_parallel_size 4 \ + --pipeline_model_parallel_size 2 \ + --dataset open-r1/DAPO-Math-17k-Processed \ + --max_epochs 1 \ + --global_batch_size 8 \ + --micro_batch_size 1 \ + --steps_per_generation 1 \ + --num_generations 8 \ + --reward_funcs accuracy format \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.4 \ + --vllm_tensor_parallel_size 8 \ + --vllm_max_model_len 16384 \ + --max_length 8192 \ + --max_completion_length 8192 \ + --train_type full \ + --lr 1e-6 \ + --bf16 true \ + --beta 0.00 \ + --importance_sampling_level sequence \ + --epsilon 3e-4 \ + --epsilon_high 4e-4 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --optimizer_cpu_offload true \ + --use_precision_aware_optimizer \ + --log_interval 1 \ + --recompute_granularity selective \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --padding_free true \ + --sequence_parallel true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name megatron_grpo \ diff --git a/examples/megatron/grpo/moe_colocate_lora.sh b/examples/megatron/grpo/moe_colocate_lora.sh new file mode 100644 index 0000000000..361a233e6c --- /dev/null +++ b/examples/megatron/grpo/moe_colocate_lora.sh @@ -0,0 +1,53 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +NPROC_PER_NODE=8 \ +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +megatron rlhf \ + --rlhf_type grpo \ + --model Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --load_safetensors true \ + --save_safetensors true \ + --context_parallel_size 2 \ + --tensor_model_parallel_size 2 \ + --expert_model_parallel_size 4 \ + --pipeline_model_parallel_size 2 \ + --dataset open-r1/DAPO-Math-17k-Processed \ + --max_epochs 1 \ + --global_batch_size 64 \ + --micro_batch_size 2 \ + --steps_per_generation 2 \ + --num_generations 8 \ + --reward_funcs accuracy format \ + --use_vllm true \ + --vllm_mode colocate \ + --vllm_gpu_memory_utilization 0.3 \ + --vllm_tensor_parallel_size 4 \ + --vllm_max_model_len 16384 \ + --max_length 8192 \ + --max_completion_length 8192 \ + --train_type lora \ + --lr 5e-5 \ + --bf16 true \ + --beta 0.00 \ + --importance_sampling_level sequence \ + --epsilon 3e-4 \ + --epsilon_high 4e-4 \ + --dynamic_sample false \ + --overlong_filter true \ + --loss_type grpo \ + --sleep_level 2 \ + --offload_model true \ + --offload_optimizer true \ + --log_interval 1 \ + --recompute_granularity selective \ + --finetune \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim \ + --no_save_rng \ + --attention_backend flash \ + --temperature 1.0 \ + --padding_free true \ + --sequence_parallel true \ + --log_completions true \ + --wandb_project megatron_swift \ + --wandb_exp_name megatron_grpo \ diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 2cba486e3f..588a4ee621 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -925,3 +925,10 @@ def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: ], dataset_name='self-cognition', tags=['chat', 'self-cognition', '🔥'])) + +register_dataset( + DatasetMeta( + ms_dataset_id='open-r1/DAPO-Math-17k-Processed', + hf_dataset_id='open-r1/DAPO-Math-17k-Processed', + subsets=['all'], + tags=['math', 'rlvr'])) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index a402e56af6..676158ff83 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1280,6 +1280,8 @@ def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None: cp_size = self.sequence_parallel_size if not self.use_megatron or cp_size == 1: return + if self.mode == 'vllm': # skip for megatron grpo rollout + return input_ids = encoded['input_ids'] padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids) input_ids += [self.tokenizer.pad_token_id] * padding_len diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 556a631f36..4609df7948 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -17,7 +17,7 @@ @dataclass class RLHFMegatronArgumentsMixin: - rlhf_type: Literal['dpo', 'kto', 'rm'] = None + rlhf_type: Literal['dpo', 'kto', 'grpo', 'rm'] = None ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ -36,6 +36,97 @@ class RLHFMegatronArgumentsMixin: # rm center_rewards_coefficient: Optional[float] = None + # grpo + generation_batch_size: Optional[int] = None + steps_per_generation: Optional[int] = None + num_generations: int = 8 + max_completion_length: int = 512 + # GSPO https://arxiv.org/abs/2507.18071 + importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' + + epsilon: float = 0.2 + epsilon_high: Optional[float] = None + delta: Optional[float] = None + top_k: int = 50 + top_p: float = 0.9 + repetition_penalty: float = 1. + use_vllm: bool = True + vllm_mode: Literal['server', 'colocate'] = 'colocate' + + vllm_enable_prefix_caching: bool = True + vllm_gpu_memory_utilization: float = 0.9 + vllm_tensor_parallel_size: int = 1 + vllm_max_model_len: Optional[int] = None + vllm_enforce_eager: bool = False + vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}' + vllm_disable_cascade_attn: bool = False + sleep_level: Literal[0, 1, 2] = 0 + offload_optimizer: bool = False + offload_model: bool = False + + vllm_server_base_url: Optional[List[str]] = None + vllm_server_host: Optional[List[str]] = None + vllm_server_port: List[int] = field(default_factory=lambda: [8000]) + vllm_server_timeout: float = 240.0 + + reward_funcs: List[str] = field(default_factory=list) + reward_weights: List[float] = None + # see details in swift/plugin/orm.py + # cosine reward, https://arxiv.org/abs/2502.03373 + cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length. + cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length. + cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length. + cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length. + cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length + # repetition penalty, https://arxiv.org/abs/2502.03373 + repetition_n_grams: int = 3 + repetition_max_penalty: float = -1.0 + # soft_overlong, https://arxiv.org/abs/2503.14476 + soft_max_length: Optional[int] = None + soft_cache_length: Optional[int] = None + # DAPO, https://arxiv.org/abs/2503.14476 + dynamic_sample: bool = False + max_resample_times: int = 3 + overlong_filter: bool = False + + # Dr. GRPO, https://arxiv.org/abs/2503.20783 + scale_rewards: Literal['none', 'group', 'batch'] = 'group' + + wandb_log_unique_prompts: Optional[bool] = None + log_completions: bool = False + + # ─────────────────────────── Not Supported Yet ─────────────────────────── + # RLOO / REINFORCE++ + advantage_estimator: Literal['grpo', 'rloo', 'reinforce_plus_plus'] = 'grpo' + kl_in_reward: bool = False + # reward model + reward_model: Optional[List[str]] = None + reward_model_plugin: Optional[List[str]] = None + # sync ref model + sync_ref_model: bool = False + ref_model_sync_steps: int = 512 + ref_model_mixup_alpha: float = 0.6 + + async_generate: bool = False + + move_model_batches: Optional[int] = None + + # multi turn + multi_turn_scheduler: Optional[str] = None + max_turns: Optional[int] = None + completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round' + vllm_server_pass_dataset: bool = False + + # entropy + log_entropy: bool = False + # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 + top_entropy_quantile: float = 1.0 + + num_iterations: int = 1 + + # dataset + dataset_shuffle: Optional[bool] = True + def _init_kto(self): if self.calculate_KL is None: # Not all losses require a KL calculation @@ -46,11 +137,104 @@ def _init_kto(self): def __post_init__(self): if self.rlhf_type is None: return - default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid'} + default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid', 'grpo': 'grpo'} if self.loss_type is None: self.loss_type = default_loss_type.get(self.rlhf_type) if self.rlhf_type == 'kto': self._init_kto() + if self.rlhf_type == 'grpo': + self._init_grpo() + + def _init_grpo(self): + + def _check_not_supported(): + if self.async_generate: + raise ValueError('async_generate is not supported for Megatron GRPO right now') + if self.sync_ref_model: + raise ValueError('sync_ref_model is not supported for Megatron GRPO right now') + if not self.dataset_shuffle: + raise ValueError('dataset_shuffle false is not supported for Megatron GRPO') + if self.multi_turn_scheduler: + raise ValueError('multi_turn_scheduler is not supported for Megatron GRPO right now') + if self.log_entropy: + raise ValueError('log_entropy is not supported for Megatron GRPO right now') + if self.top_entropy_quantile < 1: + raise ValueError('top_entropy_quantile < 1 is not supported for Megatron GRPO right now') + if self.num_iterations > 1: + raise ValueError('num_iterations > 1 is not supported for Megatron GRPO right now') + if self.kl_in_reward: + raise ValueError('kl_in_reward is not supported for Megatron GRPO right now') + if self.advantage_estimator != 'grpo': + raise ValueError('advantage_estimator must be grpo for Megatron GRPO right now') + + def _check_batch_params(): + # Set default values if both are None + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = 1 + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + # Both configured - error + elif self.generation_batch_size is not None and self.steps_per_generation is not None: + raise ValueError("'generation_batch_size' and 'steps_per_generation' cannot be both configured") + # Only generation_batch_size configured + elif self.generation_batch_size is not None: + if self.generation_batch_size % self.global_batch_size != 0: + raise ValueError(f'generation_batch_size ({self.generation_batch_size}) ' + f'must be divisible by global_batch_size ({self.global_batch_size})') + self.steps_per_generation = self.generation_batch_size // self.global_batch_size + # Only steps_per_generation configured + else: + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + + world_size = torch.distributed.get_world_size() + dp_size = world_size // ( + self.pipeline_model_parallel_size * self.tensor_model_parallel_size * self.context_parallel_size) + num_rollout_prompt = self.generation_batch_size // self.num_generations + if num_rollout_prompt % dp_size != 0: + raise ValueError(f'num_rollout_prompt ({num_rollout_prompt}) = generation_batch_size ' + f'({self.generation_batch_size}) // num_generations ({self.num_generations}) ' + f'must be divisible by dp_size ({dp_size}). ' + f'Please adjust generation_batch_size/steps_per_generation/num_generations.') + + per_device_num_rollout_prompt = num_rollout_prompt // dp_size + + if per_device_num_rollout_prompt % self.micro_batch_size != 0: + raise ValueError(f'Per-device rollout prompt count ({per_device_num_rollout_prompt}) = ' + f'(generation_batch_size ({self.generation_batch_size}) // ' + f'num_generations ({self.num_generations})) // dp_size ({dp_size}) ' + f'must be divisible by micro_batch_size ({self.micro_batch_size}). ' + f'Please adjust arguments to satisfy: ' + f'(generation_batch_size // num_generations) // dp_size % ' + f'micro_batch_size == 0') + + self.per_device_generation_batch_size = self.generation_batch_size // world_size + + _check_not_supported() + _check_batch_params() + # default loss_type if no loss_type is provided + assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \ + f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}' + self.remove_unused_columns = False + logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}') + if self.truncation_strategy is None: + self.truncation_strategy = 'left' + assert self.truncation_strategy in ['left', 'delete' + ], ("GRPO requires `truncation_strategy 'left' or 'delete'`, " + f"Current value: `truncation_strategy='{self.truncation_strategy}'`." + ) # noqa + if self.beta is None: + self.beta = 0.04 # https://arxiv.org/abs/2402.03300 + if self.async_generate: + logger.info('Using async mode. This is a approximate version which ' + 'will use the old weights to generate responses to accelerate. ' + 'This will ignore the `CLIP` of advantages, if you found the training ' + 'is unstable, you may consider using --async_generate false.') + if 'soft_overlong' in self.reward_funcs: + assert self.soft_cache_length is not None, \ + 'The soft_cache_length must be set when using soft overlong rewards.' + if self.soft_max_length is None: + self.soft_max_length = self.max_completion_length + logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}') + assert self.use_vllm, 'use_vllm must be True for Megatron GRPO' @dataclass diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py index 127f928ef1..5106175e2f 100644 --- a/swift/megatron/argument/rlhf_args.py +++ b/swift/megatron/argument/rlhf_args.py @@ -7,7 +7,7 @@ @dataclass class MegatronRLHFArguments(MegatronTrainArguments): - rlhf_type: Literal['dpo', 'kto', 'rm'] = 'dpo' + rlhf_type: Literal['dpo', 'kto', 'grpo', 'rm'] = 'dpo' loss_scale: str = 'last_round' calculate_per_token_loss: bool = False diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index 3a25bda9fe..d27f5aabb2 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -2,9 +2,10 @@ from typing import List, Optional, Union from swift.llm.train.kto import prepare_kto_dataset -from swift.utils import get_logger +from swift.trainers.rlhf_trainer.utils import identity_data_collator +from swift.utils import get_current_device, get_logger, is_last_rank from ..argument import MegatronRLHFArguments -from ..trainers import MegatronDPOTrainer, MegatronKTOTrainer, MegatronRewardTrainer +from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer, MegatronKTOTrainer, MegatronRewardTrainer from .sft import MegatronSft logger = get_logger() @@ -16,18 +17,29 @@ class MegatronRLHF(MegatronSft): def prepare_trainer(self): args = self.args - trainer_mapping = {'dpo': MegatronDPOTrainer, 'kto': MegatronKTOTrainer, 'rm': MegatronRewardTrainer} + trainer_mapping = { + 'dpo': MegatronDPOTrainer, + 'grpo': MegatronGRPOTrainer, + 'kto': MegatronKTOTrainer, + 'rm': MegatronRewardTrainer + } trainer_cls = trainer_mapping.get(args.rlhf_type) if trainer_cls is None: raise ValueError(f'The current Megatron-SWIFT does not support rlhf_type: {args.rlhf_type}.') - return trainer_cls(args, self.template) + kwargs = {} + if args.rlhf_type == 'grpo': + kwargs['vllm_client'] = self._prepare_vllm_client() + return trainer_cls(args, self.template, **kwargs) def _prepare_template(self) -> None: super()._prepare_template() - if self.args.rlhf_type == 'kto': - self.template.set_mode('kto') - else: - self.template.set_mode('rlhf') + model_mapping = {'grpo': 'train', 'kto': 'kto'} + self.template.set_mode(model_mapping.get(self.args.rlhf_type, 'rlhf')) + + def _get_data_collator(self): + if self.args.rlhf_type == 'grpo': + return identity_data_collator + return super()._get_data_collator() def _get_dataset(self): args = self.args @@ -36,6 +48,23 @@ def _get_dataset(self): train_dataset, val_dataset = prepare_kto_dataset(args, train_dataset, val_dataset) return train_dataset, val_dataset + def _prepare_vllm_client(self): + if self.args.rlhf_type != 'grpo' or (self.args.vllm_mode != 'server'): + return + from swift.trainers.rlhf_trainer.vllm_client import VLLMClient + vllm_client = None + if is_last_rank(): + logger.info('Start connecting to vLLM server') + vllm_client = VLLMClient( + base_urls=self.args.vllm_server_base_url, + hosts=self.args.vllm_server_host, + server_ports=self.args.vllm_server_port, + connection_timeout=self.args.vllm_server_timeout) + vllm_client.close_communicator() + vllm_client.init_communicator(device=get_current_device()) + logger.info('Connected to vLLM server') + return vllm_client + def megatron_rlhf_main(args: Optional[Union[List[str], MegatronRLHFArguments]] = None): return MegatronRLHF(args).main() diff --git a/swift/megatron/trainers/__init__.py b/swift/megatron/trainers/__init__.py index a70e1d7f10..80cf16fe22 100644 --- a/swift/megatron/trainers/__init__.py +++ b/swift/megatron/trainers/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .dpo_trainer import MegatronDPOTrainer +from .grpo_trainer import MegatronGRPOTrainer from .kto_trainer import MegatronKTOTrainer from .reward_trainer import MegatronRewardTrainer from .trainer import MegatronTrainer diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 164fe0ee0a..2b6d938cc4 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -31,7 +31,7 @@ from packaging import version from tqdm.auto import tqdm -from swift.llm import dynamic_gradient_checkpointing +from swift.llm import Template, dynamic_gradient_checkpointing from swift.plugin import MeanMetric from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger @@ -50,11 +50,12 @@ class BaseMegatronTrainer(ABC): - def __init__(self, args, template): + def __init__(self, args, template: Template): self.args = args self.template = template self.stimer = StragglerDetector() self.unwrapped_models = [] + self.wrapped_models = [] self.peft_models = [] self._bridge = None logging_path = os.path.join(args.save, 'logging.jsonl') @@ -86,9 +87,11 @@ def initialize_megatron(*_args, **kwargs): args = get_args() data_parallel_size = mpu.get_data_parallel_world_size() step_batch_size = args.micro_batch_size * data_parallel_size + num_generations = args.num_generations if hasattr(args, 'num_generations') else 1 if args.train_iters is None and args.max_epochs is not None: if hasattr(train_dataset, '__len__'): dataset_sample = len(train_dataset) // step_batch_size * step_batch_size + dataset_sample = dataset_sample * num_generations args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size else: raise ValueError( @@ -98,6 +101,7 @@ def initialize_megatron(*_args, **kwargs): args.eval_iters = 0 elif hasattr(val_dataset, '__len__'): dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + dataset_sample = dataset_sample * num_generations args.eval_iters = max(dataset_sample // args.global_batch_size, 1) else: raise ValueError( @@ -419,6 +423,7 @@ def new_model_provider_func(*_args, **kwargs): with self._patch_load_state_dict(self._load_base_checkpoint), self._patch_get_param_groups(): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( new_model_provider_func, model_type, *_args, **kwargs) + self.wrapped_models = model if args.initialize_embedding: for m in self.unwrapped_models: self._initialize_embedding(m) @@ -937,6 +942,7 @@ def _patch_megatron(self): # support max_epochs self._origin_train_step = training.train_step training.train_step = self.train_step + self._origin_cyclic_iter = training.cyclic_iter training.cyclic_iter = self.new_cyclic_iter # patch training_log self._origin_training_log = training.training_log diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py new file mode 100644 index 0000000000..d3253d4b39 --- /dev/null +++ b/swift/megatron/trainers/grpo_trainer.py @@ -0,0 +1,1405 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import base64 +import gc +import inspect +import os +import uuid +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from copy import copy, deepcopy +from functools import partial +from typing import Any, Dict, List, Tuple, Union + +import json +import pandas as pd +import torch +import torch.nn as nn +from accelerate.utils import broadcast_object_list +from dacite import from_dict +from megatron.core import mpu +from megatron.core.rerun_state_machine import RerunDataIterator +from megatron.training import get_args, get_wandb_writer, training +from trl.trainer.grpo_trainer import nanstd +from vllm.distributed import parallel_state as vllm_ps + +from swift.llm import RequestConfig, RolloutInferRequest, RowPreprocessor, Template, to_device +from swift.llm.infer.protocol import RolloutOutput +from swift.llm.template.template_inputs import TemplateInputs +from swift.plugin import MultiTurnScheduler, multi_turns, orms +from swift.trainers.rlhf_trainer.grpo_trainer import DataType +from swift.trainers.rlhf_trainer.utils import (FlattenedTensorBucket, aggressive_empty_cache, + replace_assistant_response_with_ids, set_expandable_segments) +from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, is_wandb_available, + remove_response) +from ..argument import MegatronArguments, MegatronRLHFArguments +from ..utils import forward_step_helper +from .rlhf_mixin import MegatronRLHFTrainer +from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu, + load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer, + profiling_context, profiling_decorator) + +if is_wandb_available(): + import wandb + +logger = get_logger() + + +class MegatronGRPOTrainer(MegatronRLHFTrainer): + + def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): + self.vllm_client = kwargs.pop('vllm_client') + super().__init__(args, template) + self.args = args + self.hf_model_dir = args.model_info.model_dir + self.processing_class = self.template.processor + self._prepare_metrics() + self._prepare_template_data_collator() + self._init_grpo_params() + self._prepare_rewards() + self._prepare_scheduler() # TODO + self._prepare_rollout_engine() + + def train(self, train_dataset, val_dataset, data_collator): + # Store dataset provider for lazy resample iterator initialization + if self.dynamic_sample: + self._train_valid_test_dataset_provider = get_swift_datasets_provider(train_dataset, val_dataset) + self._train_valid_test_dataset_provider.is_distributed = True + super().train(train_dataset, val_dataset, data_collator) + + def _prepare_template_data_collator(self): + template = self.template + args = self.args + data_collator = template.data_collator + padding_to = None + if args.tensor_model_parallel_size > 1 and args.sequence_parallel: + padding_to = args.tensor_model_parallel_size + if args.context_parallel_size > 1: + padding_to = (padding_to or 1) * args.context_parallel_size + if args.fp8_format: + padding_to = max((padding_to or 1) * 8, 16) + logger.info(f'padding_to: {padding_to}') + data_collator = partial(data_collator, padding_to=padding_to) + template.data_collator = data_collator + + def _init_grpo_params(self): + args: MegatronArguments = self.args + # distributed params + self.world_size = torch.distributed.get_world_size() + self.process_index = torch.distributed.get_rank() + self.is_main_process = is_last_rank() + self.device = get_current_device() + # algorithm params + self.num_generations = args.num_generations # G in the GRPO paper + self.beta = args.beta + self.temperature = args.temperature + self.loss_type = args.loss_type + self.max_completion_length = args.max_completion_length + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + self.top_entropy_quantile = args.top_entropy_quantile + self.importance_sampling_level = args.importance_sampling_level + self.enable_offload = False + + # DAPO, https://arxiv.org/abs/2503.14476 + self.dynamic_sample = args.dynamic_sample + self.max_resample_times = args.max_resample_times + self.overlong_filter = args.overlong_filter + + # Dr. GRPO / RLOO / REINFORCE++ + self.scale_rewards = args.scale_rewards + self.advantage_estimator = args.advantage_estimator # TODO + self.kl_in_reward = args.kl_in_reward # TODO + + # Entropy mask settings, TODO + self.log_entropy = args.log_entropy + self.compute_entropy = self.log_entropy or self.top_entropy_quantile < 1.0 + + # batch size (completion-level) + self.generation_batch_size = args.generation_batch_size + self.steps_per_generation = args.steps_per_generation + self.global_batch_size = args.global_batch_size + self.micro_batch_size = args.micro_batch_size + self.per_device_generation_batch_size = args.per_device_generation_batch_size + + # sampling params + self.request_config = RequestConfig( + n=1, + max_tokens=args.max_completion_length, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + stop=args.stop_words, + return_details=True) + + self._step = 0 + self._last_loaded_step = -1 + self._rollout_group = None # Will be lazily initialized + + def _prepare_rollout_engine(self): + args = self.args + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.use_vllm = args.use_vllm + self.async_generate = args.async_generate # TODO + self.vllm_use_async_engine = False + self.enable_offload = False + self.use_gym_env = False + self.enable_server_multi_turn = False # TODO + # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs + assert self.use_vllm + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. ' + 'Please install vLLM with `pip install vllm -U` to use it.') + if self.vllm_mode == 'server': + pass + elif self.vllm_mode == 'colocate': + if not self.world_size % self.vllm_tensor_parallel_size == 0: + raise ValueError(f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' + f'({self.world_size}) evenly.') + + self.enable_offload = self.args.offload_model or self.args.offload_optimizer + context = self.offload_context if self.enable_offload else nullcontext + + with context(): + set_expandable_segments(False) + self.engine = self.prepare_vllm() + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + set_expandable_segments(True) + else: + raise ValueError(f'Invalid vllm_mode: {self.vllm_mode}') + + def prepare_vllm(self): + from swift.llm.infer.infer_engine import GRPOVllmEngine + args = self.args + max_num_seqs = self.per_device_generation_batch_size * self.vllm_tensor_parallel_size + vllm_template = copy(self.template) + vllm_template.padding_free = False + engine = GRPOVllmEngine( + self.hf_model_dir, + args.torch_dtype, + model_type=args.model_type, + use_async_engine=False, + tensor_parallel_size=self.vllm_tensor_parallel_size, + gpu_memory_utilization=self.vllm_gpu_memory_utilization, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_num_seqs=max_num_seqs, + enforce_eager=self.args.vllm_enforce_eager, + limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, + enable_sleep_mode=self.args.sleep_level > 0, + max_model_len=self.args.vllm_max_model_len, + seed=self.process_index // self.vllm_tensor_parallel_size, + disable_cascade_attn=self.args.vllm_disable_cascade_attn, + load_format='dummy', + template=vllm_template, + distributed_executor_backend='external_launcher', + ) + if self.vllm_tensor_parallel_size > 1: + self.vllm_tp_group = vllm_ps.get_tp_group().device_group + self._buffered_inputs = None + return engine + + @profiling_decorator + def _move_model_to_vllm(self): + # Handle LoRA: merge adapters before exporting weights + is_lora_training = self.args.train_type == 'lora' + + try: + if is_lora_training: + self.merge_lora_adapters() + + # Export and load weights incrementally to avoid memory spikes + self._export_and_load_weights() + + finally: + # Unmerge adapters to restore training state + if is_lora_training: + self.unmerge_lora_adapters() + + # Reset prefix cache + if self.vllm_mode == 'server' and self.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == 'colocate': + self.engine.engine.reset_prefix_cache() + + @property + def bridge(self): + if self._bridge is None: + self._bridge = self.args.megatron_model_meta.bridge_cls(disable_tqmd=True) + return self._bridge + + def _export_and_load_weights(self): + """ + Export weights from Megatron models and load to vLLM incrementally. + + For colocate mode: llm_model.load_weights accepts an iterator, so pass it directly. + For server mode: Process weights in buckets to avoid memory spikes. + """ + # Export weights returns an iterator + with profiling_context(self, 'export_weights'): + weight_iterator = self.bridge.export_weights(self.unwrapped_models) + + if self.vllm_mode == 'colocate': + # Colocate mode: load_weights supports iterator, pass directly + llm_model = self.engine.inner_model + llm_model.load_weights(weight_iterator) + elif self.vllm_mode == 'server' and self.is_main_process: + # Server mode: process in buckets and sync with flattened tensors + self._load_weights_to_server_in_buckets(weight_iterator) + + def _load_weights_to_server_in_buckets(self, weight_iterator): + """ + Load weights to vLLM server in buckets using FlattenedTensorBucket. + + Args: + weight_iterator: Iterator of (name, tensor) tuples from export_weights + """ + # Get bucket size from environment or use default + bucket_size_mb = int(os.environ.get('SWIFT_UPDATE_WEIGHTS_BUCKET_SIZE', 512)) + bucket_size_bytes = bucket_size_mb * 1024 * 1024 + + current_bucket = [] + current_size = 0 + + for name, param in weight_iterator: + param_size = param.numel() * param.element_size() + current_bucket.append((name, param)) + current_size += param_size + + # If adding this param would exceed bucket size, process current bucket first + if current_size > bucket_size_bytes and current_bucket: + self._sync_bucket_to_server(current_bucket) + current_bucket = [] + current_size = 0 + + # Process remaining parameters in the last bucket + if current_bucket: + self._sync_bucket_to_server(current_bucket) + + def _sync_bucket_to_server(self, bucket_params: List[Tuple[str, torch.Tensor]]): + """ + Synchronize a bucket of parameters to vLLM server using flattened tensors. + + Args: + bucket_params: List of (name, tensor) tuples to sync + """ + if not bucket_params: + return + + # Create FlattenedTensorBucket for efficient transfer + bucket = FlattenedTensorBucket(named_tensors=bucket_params) + metadatas = bucket.get_metadata() + flattened_tensor = bucket.get_flattened_tensor() + + # Directly call vllm_client to update weights + self.vllm_client.update_flattened_params(metadatas, flattened_tensor) + + # Clean up to free memory immediately + del bucket, metadatas, flattened_tensor + + def _prepare_rewards(self): + # TODO: reward model + args = self.args + reward_funcs = args.reward_funcs + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + + # initilize reward functions + if reward_funcs: + for i, reward_func in enumerate(reward_funcs): + if reward_func in orms: + reward_func_class = orms[reward_func] + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) + elif not callable(reward_func): + raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin') + + # get reward name for logging + self.reward_funcs = reward_funcs + self.reward_func_names = [] + for reward_func in reward_funcs: + if inspect.isfunction(reward_func): + reward_func_name = reward_func.__name__ + else: + reward_func_name = reward_func.__class__.__name__ + self.reward_func_names.append(reward_func_name) + + # set reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward ' + f'functions ({len(reward_funcs)})') + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32).to(self.device) + else: + self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(self.device) + + # TODO: reward models + self.reward_model_plugins = [None] * len(self.reward_funcs) + + assert self.reward_funcs, 'reward_funcs is not set' + + def _prepare_scheduler(self): + """Prepare multi-turn scheduler""" + args = self.args + + self.multi_turn_scheduler = None + if not hasattr(args, 'multi_turn_scheduler'): + return + + if args.multi_turn_scheduler: + if isinstance(args.multi_turn_scheduler, str): + assert args.multi_turn_scheduler in multi_turns + multi_turn_scheduler = multi_turns[args.multi_turn_scheduler](max_turns=args.max_turns) + self.multi_turn_scheduler: MultiTurnScheduler = multi_turn_scheduler + else: + assert isinstance(args.multi_turn_scheduler, MultiTurnScheduler) + self.multi_turn_scheduler: MultiTurnScheduler = args.multi_turn_scheduler + + def _get_rollout_group(self): + """ + Get or create the rollout process group (TP×PP×CP). + + The rollout group is used for: + 1. Data slicing: distributing rollout data across all model parallel ranks (including CP) + 2. Gather operations: collecting results from all model parallel ranks (including CP) + + Note: MODEL_PARALLEL_GROUP only includes TP×PP, but we need TP×PP×CP for correct + data distribution during rollout phase. + + Key insight: ranks with the same DP index but different TP/PP/CP indices should be + in the same rollout group. These ranks will: + - During rollout: each process different data slices + - During training: TP/PP ranks process same data (model split), CP ranks process same data (sequence split) + - During gather: collect all data from TP×PP×CP ranks for training + """ + if self._rollout_group is not None: + return self._rollout_group + + cp_size = mpu.get_context_parallel_world_size() + if cp_size == 1: + # No CP, use the standard MODEL_PARALLEL_GROUP + self._rollout_group = mpu.get_model_parallel_group() + return self._rollout_group + + # Get parallel dimensions + tp_size = mpu.get_tensor_model_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + global_rank = torch.distributed.get_rank() + + # Calculate rollout group size + rollout_group_size = tp_size * pp_size * cp_size + + # Simple and reliable method: assume ranks are organized in contiguous blocks per DP group + # This is typically true for the default order (tp-cp-ep-dp-pp) + # Each DP group has rollout_group_size consecutive ranks + ranks_per_dp_group = rollout_group_size + my_dp_block_index = global_rank // ranks_per_dp_group + + # Calculate the rank range for my rollout group + group_start = my_dp_block_index * ranks_per_dp_group + + # Create all rollout groups (must be done on all ranks) + if not hasattr(self, '_rollout_groups_created'): + for dp_idx in range(dp_size): + group_start = dp_idx * ranks_per_dp_group + group_ranks = list(range(group_start, min(group_start + ranks_per_dp_group, self.world_size))) + group = torch.distributed.new_group(ranks=group_ranks, group_desc='ROLLOUT_GROUP') + if global_rank in group_ranks: + self._rollout_group = group + self._rollout_groups_created = True + + return self._rollout_group + + def _init_resample_data_iterator(self): + """ + Initialize an independent data iterator for dynamic resampling (lazy initialization). + + This method is called lazily during the first dynamic resampling, ensuring that + pretrain() has already called initialize_megatron() to properly set up all args. + Uses a different seed (args.seed + 1) to avoid overlapping with training samples. + + Note: pretrain() will automatically reset the random seed back to args.seed + after this method completes, so we don't need manual state restoration. + + Args: + train_valid_test_dataset_provider: Dataset provider function + + Returns: + train_data_iterator: Independent data iterator with different random seed + """ + from megatron.training.training import build_train_valid_test_data_iterators + from megatron.training.initialize import _set_random_seed + from megatron.training import training + training.cyclic_iter = self._origin_cyclic_iter + args = get_args() + + train_valid_test_dataset_provider = self._train_valid_test_dataset_provider + # Use different seed for resample iterator (offset by 1 to avoid overlap) + resample_seed = getattr(args, 'seed', 42) + 1 + try: + # Set new seed for resample iterator creation + _set_random_seed( + resample_seed, + args.data_parallel_random_init, + args.te_rng_tracker, + args.inference_rng_tracker, + use_cudagraphable_rng=args.enable_cuda_graph, + ) + + # Build data iterators with new seed + # TODO: VPP (Virtual Pipeline Parallelism) + resample_data_iterator, _, _ = (build_train_valid_test_data_iterators(train_valid_test_dataset_provider)) + finally: + # Restore original random states to avoid affecting training + _set_random_seed( + args.seed, + args.data_parallel_random_init, + args.te_rng_tracker, + args.inference_rng_tracker, + use_cudagraphable_rng=args.enable_cuda_graph, + ) + return resample_data_iterator + + def _replace_data_iterator(self, data_iterator, model): + if self._step % self.steps_per_generation == 0: + num_iters_per_step = self.get_num_iters_per_step() + rollout_batch = [] + for _ in range(num_iters_per_step): + rollout_batch.extend(next(data_iterator)) + micro_batch_data = self._generate_and_score_completions(rollout_batch) + num_mini_batch = self.global_batch_size // (self.micro_batch_size * mpu.get_data_parallel_world_size()) + mini_batch_data = [ + micro_batch_data[i:i + num_mini_batch] for i in range(0, len(micro_batch_data), num_mini_batch) + ] + assert len(mini_batch_data) == self.steps_per_generation + self._buffered_inputs = mini_batch_data + self._step += 1 + inputs = self._buffered_inputs[self._step % self.steps_per_generation] + return RerunDataIterator(iter(inputs)) + + def _generate_and_score_completions(self, batch): + # Get or create the rollout group (TP×PP×CP) + rollout_group = self._get_rollout_group() + + rollout_batch = self.get_local_rollout_batch(batch) + + rollout_batch = self._generate_completions(rollout_batch) + + rewards_per_func = self._score_completions(rollout_batch) + + # Dynamic sampling for std=0 groups (DAPO) + if self.dynamic_sample: + rollout_batch, rewards_per_func = self._dynamic_sampling(rollout_batch, rewards_per_func) + + advantages = self._compute_advantages(rollout_batch, rewards_per_func) + + def _get_encoded_batch(rollout_batch, advantages): + template = self.template + with self._template_context(template): + encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] + encoded_batch = to_device(template.data_collator(encoded_batch), self.device) + labels = encoded_batch['labels'] + assert self.template.padding_free + position_ids = encoded_batch.get('text_position_ids') + if position_ids is None: + position_ids = encoded_batch.get('position_ids') + squeezed_position_ids = position_ids.squeeze() + assert squeezed_position_ids is not None + # Remove trailing padding zeros from position_ids to avoid interference + # Find the last non-zero position + last_nonzero_idx = (squeezed_position_ids != 0).nonzero(as_tuple=True)[0] + if len(last_nonzero_idx) > 0: + # Keep only up to the last non-zero position + 1 to include the last valid position + squeezed_position_ids = squeezed_position_ids[:last_nonzero_idx[-1] + 1] + + # Calculate lengths based on sequence boundaries (position_ids == 0) + lengths = torch.diff( + torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0], + torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)])) + advantages = torch.repeat_interleave(advantages, lengths) + truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch], + dtype=torch.bool, + device=self.device) + truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0) + padding_length = labels.shape[1] - truncated_mask.shape[1] + if padding_length > 0: + padding = torch.zeros((1, padding_length), device=truncated_mask.device, dtype=truncated_mask.dtype) + truncated_mask = torch.cat([truncated_mask, padding], dim=1) + # Pad advantages to match the original position_ids length + original_length = position_ids.shape[1] + if advantages.shape[0] < original_length: + padding_length = original_length - advantages.shape[0] + padding = torch.zeros(padding_length, device=advantages.device, dtype=advantages.dtype) + advantages = torch.cat([advantages, padding]) + + encoded_batch.update({ + 'completion_mask': labels != -100, + 'truncated_mask': truncated_mask, + 'advantages': advantages, + 'num_samples': len(rollout_batch), + }) + + return encoded_batch + + # Step2: ref/old logps + total_batch = gather_object(rollout_batch, group=rollout_group) + total_advantages = gather(advantages, group=rollout_group) + mini_batch_data = [] + for idx in range(0, len(total_batch), self.micro_batch_size): + micro_batch_data = total_batch[idx:idx + self.micro_batch_size] + micro_batch_data = self._maybe_replace_response_token(micro_batch_data) + micro_batch_advantages = total_advantages[idx:idx + self.micro_batch_size] + micro_batch_data = _get_encoded_batch(micro_batch_data, micro_batch_advantages) + with profiling_context(self, 'compute_ref_old_logps'): + micro_batch_data = self._maybe_compute_logps(micro_batch_data) + mini_batch_data.append(micro_batch_data) + + return mini_batch_data + + @profiling_decorator + def _generate_completions(self, batch): + """ + Generate completions for a batch of rollout data using vLLM engine. + + This method processes rollout data for the current process, generates completions + using the vLLM engine, and merges the results back into the original batch. + + Args: + batch: Rollout data assigned to the current process. + + Returns: + batch: The input batch with rollout completion results merged in. + """ + # TODO: server mode + # add prompt ids and system prompts + batch = self._preprocess_inputs(batch) + # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) + if self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping: + wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters + # Load weights only (faster and reduces memory peak) + kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + self.engine.engine.wake_up(**kwargs) + + # Step 2: Load model weights + if self._step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self._step + + context = self.offload_context if self.enable_offload else nullcontext + with context(): + if (self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping + and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): + aggressive_empty_cache() + set_expandable_segments(False) + self.engine.engine.wake_up(tags=['kv_cache']) + + # Step3: Rollout + outputs: List[RolloutOutput] = self._rollout(batch) + + # Step4: Sleep to release memory + if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: + self.engine.engine.reset_prefix_cache() + self.engine.engine.sleep(level=self.args.sleep_level) + aggressive_empty_cache() + set_expandable_segments(True) + batch = self.postprocess_rollout_data(batch, outputs) + + return batch + + def _rollout(self, batch) -> List[RolloutOutput]: + batch = self._set_inputs_system(batch) + request_config = self._get_request_config() + # TODO: server mode + if self.vllm_mode == 'server': + rollout_outputs = self._server_rollout(batch, request_config) + elif self.vllm_mode == 'colocate': + rollout_outputs = self._colocate_rollout(batch, request_config) + # log prompt and completions + messages = gather_object([data['messages'] for data in batch]) + completions = gather_object([data.response.choices[0].message.content for data in rollout_outputs]) + self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(messages)) + self._logs['completion'].extend(completions) + + return rollout_outputs + + def postprocess_rollout_data(self, batch, outputs): + """ + Post-process the raw vLLM generation outputs and merge them back into the + original input batch. + + Args: + batch (List[Dict[str, Any]]): + Original rollout samples. + outputs (List[RolloutOutput]): + outputs from vLLM from vLLM TP group + + Returns: + List[Dict[str, Any]]: + Updated samples with rollout results merged in. + """ + + def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput): + response = output.response + choice = response.choices[0] + + # Step 1: Update or append assistant message + if output.messages: + input_data['messages'] = output.messages # Override full message history + else: + # not provided, append + messages = input_data['messages'] + remove_response(messages) + messages.append({'role': 'assistant', 'content': choice.message.content}) + # Step 2: Add token IDs and loss mask + if output.response_token_ids: + input_data['response_token_ids'] = output.response_token_ids + if output.response_loss_mask: + input_data['response_loss_mask'] = output.response_loss_mask + else: + # for single turn, skip tokenizer response + input_data['response_token_ids'] = output.response.choices[0].token_ids + + # Step 3: Attach rollout extra info + if output.rollout_infos: + input_data['rollout_infos'] = output.rollout_infos + + # Step 4: Store finish reason (used for truncation filters etc.) + input_data['finish_reason'] = choice.finish_reason + input_data['is_truncated'] = choice.finish_reason == 'length' + + return input_data + + assert len(batch) == len(outputs) + return [merge_output_input_data(input_data, output) for input_data, output in zip(batch, outputs)] + + def _get_request_config(self) -> RequestConfig: + request_config = copy(self.request_config) + if self.args.vllm_mode == 'colocate' and self.vllm_tensor_parallel_size > 1: + # Set request_config.seed + # 1. Ensure that the seed for vLLM Engines within each TP (Tensor Parallelism) group is the same; + # otherwise, the program may hang. + # 2. Ensure that the seed for vLLM Engines across different TP groups is different; + # otherwise, identical completions will be generated. + batch_size = self.per_device_generation_batch_size + batch_size *= self.vllm_tensor_parallel_size + # Since the TP (Tensor Parallelism) group gathers the inputs, + # multiply the batch size by the TP parallel size. + request_config.seed = batch_size * (self.process_index // self.vllm_tensor_parallel_size) + + return request_config + + def _server_rollout(self, + inputs: DataType, + request_config: RequestConfig, + is_global_inputs: bool = False) -> List[RolloutOutput]: + # TODO: async generate + infer_requests = self.inputs2requests(inputs) + + if is_global_inputs: + per_device_size = len(infer_requests) // self.world_size + all_requests = infer_requests + all_requests_lengths = [per_device_size] + [0] * (self.world_size - 1) + else: + all_requests = gather_object(infer_requests) + all_requests_lengths = gather_object([len(infer_requests)]) + + if not any(requests for requests in all_requests): + return [] + + if self.is_main_process: + all_outputs: List[RolloutOutput] = self.vllm_client.infer( + infer_requests=all_requests, request_config=request_config) + assert len(all_outputs) == len(all_requests) # TODO: dynamic num of samples + else: + all_outputs = [None] * len(all_requests) + + if not is_global_inputs: + all_outputs = broadcast_object_list(all_outputs, from_process=self.world_size - 1) + start_idx = sum(all_requests_lengths[:self.process_index]) + end_idx = start_idx + all_requests_lengths[self.process_index] + outputs = all_outputs[start_idx:end_idx] + else: + outputs = all_outputs if self.is_main_process else [] + return outputs + + def _colocate_rollout(self, batch, request_config: RequestConfig): + if self.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + local_input_length = len(batch) + all_input_lengths = [None] * self.vllm_tensor_parallel_size + torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.vllm_tp_group) + + start_idx = sum(all_input_lengths[:local_rank_in_group]) + end_idx = start_idx + all_input_lengths[local_rank_in_group] + + gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group) + batch = [p for sublist in gathered_batch for p in sublist] + + outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) + + if self.vllm_tensor_parallel_size > 1: + outputs = outputs[start_idx:end_idx] + + return outputs + + @profiling_decorator + def _score_completions(self, inputs: DataType) -> torch.Tensor: + """Score completions using all reward functions. + + Args: + inputs: List of input examples, each containing a 'messages' list with conversation history + + Returns: + rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with local reward values + """ + # Compute rewards using reward functions + local_rewards_per_func = self._compute_rewards_per_func(inputs) + + return local_rewards_per_func + + def _compute_rewards_per_func(self, batch: DataType) -> torch.Tensor: + """Compute rewards using all reward functions""" + device = self.device + rewards_per_func = torch.zeros((len(batch), len(self.reward_funcs)), device=device) + completions = [inp['messages'][-1]['content'] for inp in batch] + reward_kwargs = {} # TODO: training step info + for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)): + with profiling_context(self, reward_func_name): + # reward model + if isinstance(reward_func, nn.Module): + output_reward_func = reward_model_plugin(inputs=batch, **reward_kwargs) + # reward function + else: + # Repeat all input columns (but "messages" and "completion") to match the number of generations + reward_kwargs.update(RowPreprocessor.rows_to_batched(batch)) + output_reward_func = reward_func(completions, **reward_kwargs) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} + row_reward_kwargs['completion'] = completions[nan_row_idx] + logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. ' + 'Please ensure that at least one reward function returns a valid reward.') + + return rewards_per_func + + def _compute_advantages(self, batch: DataType, rewards_per_func: torch.Tensor) -> torch.Tensor: + """Compute advantages for RL training.""" + + def normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) -> torch.Tensor: + """Normalize advantages if configured; otherwise, return as-is.""" + if self.scale_rewards != 'none': + return advantages / (rewards_std + 1e-4) + return advantages + + mode = 'train' if self.unwrapped_models[0].training else 'eval' + assert len(batch) == rewards_per_func.shape[0] + total_rewards_per_func = gather(rewards_per_func) + rewards = (total_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) + grouped_rewards = rewards.view(-1, self.num_generations) + + # Compute group statistics + group_rewards_mean = grouped_rewards.mean(dim=1) + + # Broadcast stats back to the original shape + group_rewards_mean = group_rewards_mean.repeat_interleave(self.num_generations) + + # Compute advantages relative to group mean + advantages = rewards - group_rewards_mean + + # Normalize advantages based on scale_rewards setting + if self.scale_rewards == 'batch': + # Global batch-level normalization + rewards_std = rewards.std().expand_as(rewards) + elif self.scale_rewards == 'group': + # Group-level normalization (default) + rewards_std = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations) + else: # 'none' + rewards_std = None + + if rewards_std is not None: + advantages = normalize_advantages(advantages, rewards_std) + + def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor): + """Log reward statistics for monitoring. Only log once per unique request_id.""" + # rewards: [prompt_batch_size, self.num_generations] + # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs] + group_rewards = rewards.view(-1, self.num_generations) + rewards_mean = group_rewards.mean(-1).mean().item() + # Compute std based on scale_rewards setting for logging + if self.scale_rewards in ['group', 'none']: + rewards_std = group_rewards.std(-1).mean().item() + elif self.scale_rewards == 'batch': + rewards_std = rewards.std().item() + is_std_zero = torch.isclose(group_rewards.std(dim=1), torch.zeros_like(group_rewards.std(dim=1))) + + self._metrics[mode]['reward'].append(rewards_mean) + self._metrics[mode]['reward_std'].append(rewards_std) + self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) + + # Log per-reward-function statistics using deduplicated rewards_per_func + for i, name in enumerate(self.reward_func_names): + col = rewards_per_func_for_metrics[:, i] + self._metrics[mode][f'rewards/{name}/mean'].append(torch.nanmean(col).item()) + self._metrics[mode][f'rewards/{name}/std'].append(nanstd(col).item()) + + log_rewards_metrics(rewards=grouped_rewards, rewards_per_func_for_metrics=total_rewards_per_func) + self._logs['advantages'].extend(advantages.tolist()) + for i, name in enumerate(self.reward_func_names): + self._logs['rewards'][name].extend(total_rewards_per_func[:, i].tolist()) + + slice_start = self.process_index * len(batch) + slice_end = slice_start + len(batch) + advantages = advantages[slice_start:slice_end] + + return advantages + + def _dynamic_sampling(self, rollout_batch: DataType, + rewards_per_func: torch.Tensor) -> Tuple[DataType, torch.Tensor]: + """ + Perform dynamic sampling to replace samples with zero-reward-variance groups. + + This method implements DAPO (https://arxiv.org/abs/2503.14476) by replacing + samples from groups with zero reward variance (std=0) through resampling. + + Args: + rollout_batch: local rollout data samples + rewards_per_func: reward per function for local data samples + rollout_group: rollout communication group + + Returns: + tuple: (rollout_batch, rewards_per_func) with zero-variance groups replaced by resampled data + """ + resample_count = 0 + valid_samples = [] + valid_rewards_per_func = [] + origin_data = (rollout_batch, rewards_per_func) + + while resample_count < self.max_resample_times: + # Gather all samples and rewards across rollout group first + global_rollout_batch = gather_object(rollout_batch) + global_rewards_per_func = gather(rewards_per_func) + + # Compute reward std for the entire global batch + # We need to compute std on the gathered data to get a global mask + global_rewards = (global_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) + grouped_rewards = global_rewards.view(-1, self.num_generations) + group_rewards_std = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations) + global_valid_mask = (group_rewards_std > 0) + + # Filter valid samples based on std > 0 + valid_samples.extend([sample for sample, mask in zip(global_rollout_batch, global_valid_mask) if mask]) + valid_rewards_per_func.append(global_rewards_per_func[global_valid_mask]) + + if len(valid_samples) >= self.generation_batch_size: + break + + # Lazy initialization of resample_data_iterator + # Only initialize when needed, after pretrain() has set up args + if not hasattr(self, 'resample_data_iterator') or self.resample_data_iterator is None: + self.resample_data_iterator = self._init_resample_data_iterator() + num_iters_per_step = self.get_num_iters_per_step() + next_rollout_prompt_batch = [] + for _ in range(num_iters_per_step): + next_rollout_prompt_batch.extend(next(self.resample_data_iterator)) + + # Repeat num_generations times and get local slice + rollout_batch = self.get_local_rollout_batch(next_rollout_prompt_batch) + + # Generate and score new completions + rollout_batch = self._generate_completions(rollout_batch) + rewards_per_func = self._score_completions(rollout_batch) + resample_count += 1 + + if len(valid_samples) >= self.generation_batch_size: + # Get local slice of valid samples + rank = self.process_index + per_device_batch_size = self.per_device_generation_batch_size + data_slice = slice(rank * per_device_batch_size, (rank + 1) * per_device_batch_size) + rollout_batch = valid_samples[:self.generation_batch_size][data_slice] + rewards_per_func = torch.cat(valid_rewards_per_func)[:self.generation_batch_size][data_slice] + else: + logger.warning(f'There are still std=0 groups present after {self.max_resample_times} retries.') + rollout_batch, rewards_per_func = origin_data + + return rollout_batch, rewards_per_func + + def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: + # TODO: entropy + inputs = {k: v for k, v in batch.items() if k not in ['completion_mask', 'advantages', 'truncated_mask']} + if self.beta != 0.0: + with torch.no_grad(), self.null_ref_context() as ref_models: + assert len(ref_models) == 1, 'GRPO currently does not support VPP.' + ref_model = ref_models[0] + batch['ref_per_token_logps'] = self.model_forward( + ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] + + if not self.on_policy: + batch['old_per_token_logps'] = self.model_forward( + self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] + return batch + + @contextmanager + def _disable_maxlength_template_context(self, template: Template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + template.max_length = None + try: + yield + finally: + template.max_length = max_length + + def _maybe_replace_response_token(self, batch): + # maybe replace the response token with the response token ids to avoid repetitive tokenize + + for data in batch: + if 'response_token_ids' in data and data['response_token_ids']: + loss_mask = None + if 'response_loss_mask' in data and data['response_loss_mask']: + loss_mask = data['response_loss_mask'] + # token in token out + data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'], + loss_mask) + return batch + + @property + def on_policy(self): + return self.steps_per_generation == 1 + + @contextmanager + def patch_megatron_data_collator(self, data_collator): + """ + Context manager that temporarily patches Megatron's data-loader factory so each + prompt-level micro-batch size equals (original micro-batch size // num_generations), + required by GRPO. Restores the original size and loader on exit. + """ + origin_build_pretraining_data_loader = training.build_pretraining_data_loader + + def build_pretraining_data_loader(*_args, **kwargs): + args = get_args() + org_micro_batch_size = args.micro_batch_size + # args.micro_batch_size = org_micro_batch_size // self.num_generations + res = origin_build_pretraining_data_loader(*_args, **kwargs) + args.micro_batch_size = org_micro_batch_size + if res is not None and args.dataloader_type != 'external': + res.collate_fn = data_collator + return res + + training.build_pretraining_data_loader = build_pretraining_data_loader + try: + yield + finally: + training.build_pretraining_data_loader = origin_build_pretraining_data_loader + + @profiling_decorator + def forward_step(self, data_iterator, model): + # train_batch_size + # return: output_tensor, loss_func + data = self.get_batch(data_iterator) + data.pop('loss_scale', None) + inputs = { + k: v + for k, v in data.items() if k not in + ['completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', 'truncated_mask'] + } + + with self.stimer: + output_tensor = model(**inputs) + return output_tensor, partial(self.loss_func, data=data) + + @profiling_decorator + def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): + advantages = data['advantages'] + labels = data['labels'] + completion_mask = data['completion_mask'] + packed_seq_params = data['packed_seq_params'] + truncated_mask = data['truncated_mask'] + micro_batch_size = self.micro_batch_size + # Use full sequence lengths directly (get_logps returns full sequences in CP mode) + lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size + + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size] + lengths_with_padding = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] + + # get_logps with per_token=True now returns full sequences (all_gather in CP mode) + per_token_logps = self.get_logps( + output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True) + + if self.args.overlong_filter and truncated_mask.any(): + completion_mask = completion_mask & (~truncated_mask) + if not completion_mask.any(): + logger.warning('All completions are truncated in this batch. Loss and grad_norm will be 0. ' + 'Consider increasing max_completion_length') + + if self.beta != 0.0: + ref_per_token_logps = data.get('ref_per_token_logps') + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) + + old_per_token_logps = ( + per_token_logps.detach() if data.get('old_per_token_logps') is None else data['old_per_token_logps']) + log_ratio = per_token_logps - old_per_token_logps + + if self.importance_sampling_level == 'token': + log_importance_weights = log_ratio + elif self.importance_sampling_level in ['sequence', 'sequence_token']: + log_ratio_list = torch.split(log_ratio.squeeze(0), lengths_with_padding.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) + # Optimized: compute weighted sum for each sequence (avoid list comprehension overhead) + # Use torch.stack on results instead of intermediate lists + seq_weights = torch.stack([(lr * m).sum() / m.sum().clamp(min=1.0) + for lr, m in zip(log_ratio_list, mask_list)]) + seq_level_log_weights = seq_weights.to(log_ratio.dtype).unsqueeze(-1) + if self.importance_sampling_level == 'sequence': + log_importance_weights = seq_level_log_weights + else: + seq_level_log_weight = seq_level_log_weights.detach() + # Vectorized: use repeat_interleave with tensor directly + seq_level_log_weight = torch.repeat_interleave( + seq_level_log_weight.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0) + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + ",'sequence' and 'sequence_token'.") + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + if self.template.padding_free: + # In padding_free + sequence mode, coef_1 is [num_samples, 1] + # We need to expand to [1, total_tokens] for token-level loss computation + if self.importance_sampling_level == 'sequence': + # Vectorized: expand sequence-level weights to token-level without gradient + coef_1 = torch.repeat_interleave(coef_1.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0) + coef_2 = torch.repeat_interleave(coef_2.squeeze(-1), lengths_with_padding, dim=0).unsqueeze(0) + + advantages = advantages[-coef_1.shape[1]:] + per_token_loss1 = coef_1 * advantages.unsqueeze(0) + per_token_loss2 = coef_2 * advantages.unsqueeze(0) + else: + raise NotImplementedError + # per_token_loss1 = coef_1 * advantages.unsqueeze(1) + # per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == 'grpo': + loss_list = torch.split(per_token_loss.squeeze(0), lengths_with_padding.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) + + sample_loss = torch.stack([(loss * mask).sum() / mask.sum().clamp(min=1.0) + for loss, mask in zip(loss_list[:micro_batch_size], mask_list[:micro_batch_size]) + ]) + loss = sample_loss.mean() + elif self.loss_type == 'bnpo': + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == 'dr_grpo': + loss = (per_token_loss * completion_mask).sum() / (micro_batch_size * self.max_completion_length) + else: + raise ValueError(f'Unknown loss type: {self.loss_type}') + + avg_metric = { + 'loss': loss.clone().detach(), + } + custom_metrics = {} + total_lengths = gather(lengths, group=mpu.get_data_parallel_group(with_context_parallel=True)) + custom_metrics = { + 'completions/mean_length': total_lengths.float().mean(), + 'completions/max_length': total_lengths.float().max(), + 'completions/min_length': total_lengths.float().min(), + } + + if self.beta != 0.0: + # Unified processing (no CP-specific logic needed) + kl_value = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + avg_metric['kl'] = kl_value.clone().detach() + + mode = 'train' if self.unwrapped_models[0].training else 'eval' + if self._metrics[mode]: + addition_metrics = { + key: torch.tensor(sum(val) / len(val), device=loss.device) + for key, val in self._metrics[mode].items() + } + avg_metric.update(addition_metrics) + + avg_metric = self._all_reduce_metric(avg_metric) + + reporting_metric = {**avg_metric, **custom_metrics} + + # log_completions + if self.log_completions and self.is_main_process and self._step % self.steps_per_generation == 0: + table = { + 'gen_step': [self._step] * len(self._logs['prompt']), + 'prompt': list(self._logs['prompt']), + 'completion': list(self._logs['completion']), + **{k: list(v) + for k, v in self._logs['rewards'].items()}, + 'advantages': list(self._logs['advantages']), + } + self.jsonl_writer.append(table) + wandb_writer = get_wandb_writer() + if wandb_writer: + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=['prompt']) + # if not self.init_custom_metric: + # wandb_writer.define_metric('completions', step_metric='gen_step') + # self.init_custom_metric = True + wandb_writer.log({'completions': wandb.Table(dataframe=df)}) + + return loss, reporting_metric + + def model_forward(self, model, data_iterator, no_grad=True, per_token=False): + # used to calculate model forward (logps) in GRPO + with self.stimer(bdata=True): + data = self.get_batch(data_iterator) + data.pop('loss_scale', None) + labels = data.get('labels') + context = torch.no_grad() if no_grad else nullcontext() + with context: + output_tensor = forward_step_helper(model, data) + packed_seq_params = data['packed_seq_params'] + data['logps'] = None if labels is None else self.get_logps( + output_tensor, labels, data['packed_seq_params'], packed_seq_params.num_samples, per_token=per_token) + return data + + @contextmanager + def offload_context(self): + if self.args.offload_model: + offload_megatron_model_to_cpu(self.wrapped_models) + if hasattr(self, 'ref_models') and self.ref_models: + offload_megatron_model_to_cpu(self.ref_models) + if getattr(self, 'optimizer', None) and self.args.offload_optimizer: + offload_megatron_optimizer(self.optimizer) + + try: + yield + finally: + # reload (load back) model when exiting context + if self.args.offload_model: + load_megatron_model_to_gpu(self.wrapped_models) + if hasattr(self, 'ref_models') and self.ref_models: + load_megatron_model_to_gpu(self.ref_models) + if getattr(self, 'optimizer', None) and self.args.offload_optimizer: + load_megatron_optimizer(self.optimizer) + + def inputs2requests(self, inputs: DataType) -> List[RolloutInferRequest]: + """Convert raw input data into RolloutInferRequest objects""" + + def _process_image_data(image_data: Union[dict, str]) -> str: + if isinstance(image_data, dict): + if image_data.get('bytes'): + return base64.b64encode(image_data['bytes']).decode('utf-8') + if image_data.get('path'): + return image_data['path'] + return image_data + + if not inputs: + return [] + args = self.args + + REQUEST_METADATA_FIELDS = ['messages', 'images', 'audios', 'videos', 'tools', 'objects', 'uuid'] + requests_dicts = [] + + for data in inputs: + request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data and data[key] is not None} + if 'uuid' not in request_data: + request_data['uuid'] = data['request_id'] + if hasattr(args, 'vllm_server_pass_dataset') and args.vllm_server_pass_dataset: + extra_fields = { + k: v + for k, v in data.items() if k not in REQUEST_METADATA_FIELDS and data[k] is not None + } + if extra_fields: + request_data['data_dict'] = extra_fields + elif self.multi_turn_scheduler: + base_data_dict = {} + if 'data_dict' in data: + if isinstance(data['data_dict'], dict): + base_data_dict = data['data_dict'] + else: + raise ValueError('data_dict exists but is not a dictionary') + extra_data = { + k: v + for k, v in data.items() + if k not in REQUEST_METADATA_FIELDS and k != 'data_dict' and data[k] is not None + } + final_data_dict = {**extra_data, **base_data_dict} + request_data['data_dict'] = final_data_dict if final_data_dict else {} + + requests_dicts.append(request_data) + + for request in requests_dicts: + if 'images' in request and request['images']: + request['images'] = ([_process_image_data(img) for img in request['images']] if isinstance( + request['images'], list) else _process_image_data(request['images'])) + + return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts] + + def _preprocess_inputs(self, inputs: DataType) -> DataType: + """Preprocess inputs before inference""" + processed_inputs = self._add_prompt_id_to_inputs(inputs) + for input_item in processed_inputs: + remove_response(input_item['messages']) + return processed_inputs + + def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: + """Add unique prompt_id and request_id to each input""" + if not inputs: + return inputs + + all_messages = gather_object([inp['messages'] for inp in inputs]) + messages_to_prompt_id = {} + prompt_id_counter = 0 + + for messages in all_messages: + key = json.dumps(messages) + if key not in messages_to_prompt_id: + messages_to_prompt_id[key] = f'prompt_{prompt_id_counter}' + prompt_id_counter += 1 + + for input_item in inputs: + messages = input_item.get('messages') + input_item['prompt_id'] = messages_to_prompt_id[json.dumps(messages)] + input_item['request_id'] = f'chatcmpl-{str(uuid.uuid4().hex)}' + + return inputs + + def get_num_iters_per_step(self): + if hasattr(self, '_num_iters_per_step'): + return self._num_iters_per_step + # each rollout DP group will generate generation_batch_size / dp_size completions + dp_size = mpu.get_data_parallel_world_size() + completions_to_rollout = self.generation_batch_size // dp_size + # completions will be repeated num_generations times after + # so we need to divide num_iters_per_step by num_generations to get prompt batch size + prompts_to_rollout = completions_to_rollout // self.num_generations + # every iter will generate micro_batch_size prompts + num_iters_per_step = prompts_to_rollout // self.micro_batch_size + assert num_iters_per_step > 0, ( + f'num_iters_per_step={num_iters_per_step} <= 0. ' + f'This means no prompts will be generated' + f'generation_batch_size={self.generation_batch_size}, ' + f'data_parallel_world_size={mpu.get_data_parallel_world_size()}, ' + f'num_generations={self.num_generations}, ' + f'micro_batch_size={self.micro_batch_size}. ' + 'Please adjust these parameters so that ' + 'generation_batch_size // data_parallel_world_size // num_generations // micro_batch_size >= 1.') + self._num_iters_per_step = num_iters_per_step + return num_iters_per_step + + def get_local_rollout_batch(self, batch): + # repeat num_generations times + rollout_group = self._get_rollout_group() + global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] + # get local rollout data + rollout_rank = torch.distributed.get_rank(group=rollout_group) + rollout_group_size = torch.distributed.get_world_size(group=rollout_group) + + per_device_batch_size = self.per_device_generation_batch_size + assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) + data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) + rollout_batch = global_rollout_batch[data_slice] + return rollout_batch + + @contextmanager + def _template_context(self, template: Template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + template.max_length = None + try: + yield + finally: + template.max_length = max_length + + def _prepare_metrics(self): + args = self.args + from swift.utils import JsonlWriter + from collections import deque + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.jsonl_writer = JsonlWriter(os.path.join(args.save, 'completions.jsonl')) + self.init_custom_metric = False + self._logs = { + 'prompt': deque(maxlen=args.generation_batch_size), + 'completion': deque(maxlen=args.generation_batch_size), + 'rewards': defaultdict(lambda: deque(maxlen=args.generation_batch_size)), + 'advantages': deque(maxlen=args.generation_batch_size), + } + if is_wandb_available(): + # when log profiling, the step is different from the step in the training loop + # here patch wandb log to pop the step argument + from wandb.sdk.wandb_run import Run + origin_log = Run.log + from functools import wraps + + @wraps(origin_log) + def log(self, data: dict[str, Any], step: int | None = None, commit: bool | None = None): + return origin_log(self, data, None, commit) + + Run.log = log + + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} + + def _apply_chat_template_to_messages_list(self, messages_list: DataType): + prompts_text = [] + for messages in messages_list: + remove_response(messages) + template_inputs = TemplateInputs.from_dict({'messages': messages}) + res = self.template.encode(template_inputs) + prompts_text.append(self.template.safe_decode(res['input_ids'])) + return prompts_text + + def _set_inputs_system(self, batch: DataType) -> DataType: + """ + Ensures the system message is consistently set for all conversations in the batch. + + The template handles the user-defined system message. However, in server mode, + tokenization occurs on the rollout side. To prevent a mismatch where the system + message is set only during training but missing during rollout, this method + injects the default system message into each conversation if not already present. + + Args: + batch: A list of data items, each containing a 'messages' list. + + Returns: + The updated batch with the default system message inserted at the beginning + of each conversation that lacks one. + """ + + if self.vllm_mode != 'server': + return batch + + # Return early if no default system message is defined + if not self.template.template_meta.default_system: + return batch + + # Return early if all conversations already start with a system message + if all(data['messages'][0]['role'] == 'system' for data in batch): + return batch + + # Insert the default system message at the beginning of each conversation + # that doesn't already have one + for data in batch: + messages = data['messages'] + if messages[0]['role'] != 'system': + messages.insert(0, {'role': 'system', 'content': self.template.template_meta.default_system}) + + return batch diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index c004d5f91b..1c4efce1c9 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -1,11 +1,13 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from contextlib import contextmanager +import torch +import torch.distributed as dist from megatron.core import mpu from megatron.training import get_args, get_model from megatron.training.checkpointing import load_checkpoint from megatron.training.utils import unwrap_model -from torch.distributed.nn import all_reduce +from torch.distributed.nn import all_gather, all_reduce from transformers.utils import ContextManagers from swift.utils import get_logger @@ -54,11 +56,18 @@ def null_ref_context(self): for m in self.peft_models: m.set_adapter('default') - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): + def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, per_token=False): args = get_args() per_token_logps = -output_tensor loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask + if per_token: + # In CP mode, all_gather and reconstruct full sequence + if args.context_parallel_size > 1: + per_token_logps = self._postprocess_packed_tensor_cp(per_token_logps, packed_seq_params, num_samples + or packed_seq_params.num_samples) + return per_token_logps + if num_samples is None: num_samples = packed_seq_params.num_samples * 2 cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size @@ -69,3 +78,59 @@ def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): if args.context_parallel_size > 1: all_logps = all_reduce(all_logps, group=mpu.get_context_parallel_group()) return all_logps + + def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples): + """ + Generic method: In CP mode, all_gather and reconstruct full tensor sequences. + Works for both logps (float) and masks (bool/int). + + Args: + tensor: [1, packed_len/cp_size] - CP-split tensor (any dtype) + packed_seq_params: PackedSeqParams object + num_samples: Number of samples in the batch + + Returns: + output_full: [1, packed_len] - Full sequence tensor + """ + args = get_args() + cp_size = args.context_parallel_size + cp_rank = mpu.get_context_parallel_rank() + + # All-gather across CP ranks + output_list = [torch.empty_like(tensor) for _ in range(cp_size)] + torch.distributed.all_gather(output_list, tensor.contiguous(), group=mpu.get_context_parallel_group()) + output_list[cp_rank] = tensor + + # Reconstruct full sequence + # Shape: [1, packed_len/cp_size] -> [1, packed_len] + cu_seqlens_full = packed_seq_params.cu_seqlens_q + cu_seqlens_cp = cu_seqlens_full // cp_size + + # Calculate total packed length + total_packed_len = cu_seqlens_full[num_samples].item() + output_full = tensor.new_zeros(1, total_packed_len) + + # Reconstruct each sequence + for i in range(num_samples): + start_full = cu_seqlens_full[i].item() + end_full = cu_seqlens_full[i + 1].item() + seq_len = end_full - start_full + + # Length of each chunk after CP split + chunk_len = seq_len // cp_size + half_chunk = chunk_len // 2 + + # Concatenate from each CP rank's output (load-balanced split) + for j in range(cp_size): + o = output_list[j][0] + start_cp = cu_seqlens_cp[i].item() + + # Get two half chunks (CP's load-balanced split) + o0 = o[start_cp:start_cp + half_chunk] + o1 = o[start_cp + half_chunk:start_cp + chunk_len] + + # Place back to full sequence + output_full[0, start_full + j * half_chunk:start_full + (j + 1) * half_chunk] = o0 + output_full[0, end_full - (j + 1) * half_chunk:end_full - j * half_chunk] = o1 + + return output_full diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 6879fe23bf..594561cdd8 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,16 +1,26 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict +import functools +import gc +import time +from contextlib import contextmanager +from typing import Any, Dict, Optional import megatron.core import torch +from accelerate.utils import gather as hf_gather +from accelerate.utils import gather_object as hf_gather_object from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.optimizer import ChainedOptimizer from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank -from megatron.training import get_args +from megatron.training import get_args, get_wandb_writer from packaging import version from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device +from swift.utils import get_logger +from swift.utils.torch_utils import empty_cache, get_current_device mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') @@ -105,6 +115,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): keys.append('decoder_input') else: keys.append('input_ids') + packed_seq_params = batch.get('packed_seq_params') if packed_seq_params is None: return mcore_get_batch_on_this_cp_rank(batch) @@ -117,3 +128,245 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1) return batch + + +@contextmanager +def profiling_context(trainer, name: str): + start_time = time.perf_counter() + yield + end_time = time.perf_counter() + duration = end_time - start_time + + profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration} + wandb_writer = get_wandb_writer() + if wandb_writer and trainer.is_main_process: + wandb_writer.log(profiling_metrics) + + +def profiling_decorator(func): + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with profiling_context(self, func.__name__): + return func(self, *args, **kwargs) + + return wrapper + + +def gather(tensor, group: Optional[torch.distributed.ProcessGroup] = None): + if group is None: + return hf_gather(tensor) + size = torch.distributed.get_world_size(group=group) + output = [torch.empty_like(tensor) for _ in range(size)] + torch.distributed.all_gather(output, tensor, group=group, async_op=False) + + return torch.cat(output, dim=0) + + +def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = None): + if group is None: + return hf_gather_object(object) + size = torch.distributed.get_world_size(group=group) + output_objects = [None for _ in range(size)] + torch.distributed.all_gather_object(output_objects, object, group=group) + return [x for y in output_objects for x in y] + + +# code borrowed from verl +@torch.no_grad() +def load_megatron_model_to_gpu(models, load_grad=True): + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # sometimes, we don't want to load grad for pure inference + if load_grad: + buffer.grad_data.storage().resize_(buffer.grad_data_size) + buffer.grad_data.zero_() + + if buffer.param_data.storage().size() == 0: + buffer.param_data.storage().resize_(buffer.param_data_size) + # copy data from cpu to cuda + buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) + else: + # we need this for ref module + device_id = get_current_device() + for _, param in model_chunk.named_parameters(): + param.data = param.data.to(device_id, non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to(device_id, non_blocking=True) + gc.collect() + empty_cache() + + +@torch.no_grad() +def offload_megatron_model_to_cpu(models): + """ + In megatron, the model and optimizer storage are: + - bf16 parameter data chunked in model parallel group + - fp32 grad chunked in model parallel group + - fp32 main_parameter chunked in model and dp group + - fp32 optimizer state chunked in model and dp group + """ + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # offload parameters + if buffer.param_data.storage().size() > 0: + buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory() + buffer.param_data_size = buffer.param_data.storage().size() + buffer.param_data.storage().resize_(0) + + assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size() + + if buffer.grad_data.storage().size() > 0: + # if the grad_data size is already zero, we assume that it is already offloaded + buffer.grad_data_size = buffer.grad_data.storage().size() + buffer.grad_data.storage().resize_(0) + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + param.data = param.data.to('cpu', non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to('cpu', non_blocking=True) + gc.collect() + empty_cache() + + +@torch.no_grad() +def load_megatron_copy_params(optimizers): + """ + Load optimizer parameters back to GPU. Handles ChainedOptimizer. + + Args: + optimizers: Optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def load_tensor_to_gpu(tensor): + if tensor is None: + return + device_id = get_current_device() + tensor.data = tensor.data.to(device_id, non_blocking=True) + + def load_group_to_gpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + load_tensor_to_gpu(param) + else: + load_tensor_to_gpu(param_group) + else: + load_tensor_to_gpu(group) + + # Load all parameter groups to GPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, 'shard_fp32_from_float16_groups'): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def offload_megatron_copy_params(optimizers): + """ + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. + + Args: + optimizers: The optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def offload_tensor_to_cpu(tensor): + if tensor is None: + return + tensor.data = tensor.data.to('cpu', non_blocking=True) + + def offload_group_to_cpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + offload_tensor_to_cpu(param) + else: + offload_tensor_to_cpu(param_group) + else: + offload_tensor_to_cpu(group) + + # Offload all parameter groups to CPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, 'shard_fp32_from_float16_groups'): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def load_megatron_optimizer(optimizers): + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu + if hasattr(_opt.optimizer, '_move_new_state_to_right_device'): + _opt.optimizer._move_new_state_to_right_device() + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if 'exp_avg' in v: + v['exp_avg'] = v['exp_avg'].to(get_current_device(), non_blocking=True) + if 'exp_avg_sq' in v: + v['exp_avg_sq'] = v['exp_avg_sq'].to(get_current_device(), non_blocking=True) + gc.collect() + empty_cache() + + +@torch.no_grad() +def offload_megatron_optimizer(optimizers): + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if 'exp_avg' in v: + v['exp_avg'] = v['exp_avg'].to('cpu', non_blocking=True) + if 'exp_avg_sq' in v: + v['exp_avg_sq'] = v['exp_avg_sq'].to('cpu', non_blocking=True) + gc.collect() + empty_cache() + + +def log_gpu_memory(prefix: str = '', info_once: bool = False): + logger = get_logger() + log_msg = (f'{prefix} GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' + f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') + if info_once: + logger.info_once(log_msg, hash_id=prefix) + else: + logger.info(log_msg) diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index 815fa63d5c..23fcd2b107 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -428,6 +428,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N def unmerge(self) -> None: """ Unmerge all merged adapter weights from the base weights. + This method reverses the merge operation by subtracting the LoRA delta weights from the base layer weights, restoring the original base weights. """ diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index 968f5cedf5..42f6afdcdd 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -324,7 +324,7 @@ class GRPOArgumentsMixin(RolloutTrainerArgumentsMixin): # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 top_entropy_quantile: float = 1.0 - # GSPO https://www.arxiv.org/abs/2507.18071 + # GSPO https://arxiv.org/abs/2507.18071 importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' # RLOO, REINFORCE++ diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py index 8830dbac20..829dba091b 100644 --- a/swift/trainers/rlhf_trainer/__init__.py +++ b/swift/trainers/rlhf_trainer/__init__.py @@ -14,6 +14,7 @@ from .gkd_trainer import GKDTrainer from .rlhf_mixin import RLHFTrainerMixin from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection + from .vllm_client import VLLMClient else: _import_structure = { 'cpo_trainer': ['CPOTrainer'], @@ -26,6 +27,7 @@ 'gkd_trainer': ['GKDTrainer'], 'rlhf_mixin': ['RLHFTrainerMixin'], 'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'], + 'vllm_client': ['VLLMClient'], } import sys diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 1b81ffaeb0..53cc4b5c99 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1861,7 +1861,7 @@ def _prepare_algorithm_params(self): # Entropy Mask, https://arxiv.org/abs/2506.01939 self.top_entropy_quantile = args.top_entropy_quantile - # GSPO, https://www.arxiv.org/abs/2507.18071 + # GSPO, https://arxiv.org/abs/2507.18071 self.importance_sampling_level = args.importance_sampling_level # RLOO, diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index 17d8210021..3cb154a94c 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -637,6 +637,7 @@ def _fast_infer(self, inputs: DataType) -> DataType: if self.engine.inner_model_executor.is_sleeping: wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + aggressive_empty_cache() self.engine.engine.wake_up(**kwargs) if self.state.global_step != self._last_loaded_step: diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 81e614a89c..2de38550d6 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -133,9 +133,14 @@ def infer( results = [None] * self.num_servers errors = [None] * self.num_servers + if isinstance(request_config, RequestConfig): + request_config = asdict(request_config) def process_chunk(i, chunk): try: + if len(chunk) > 0 and isinstance(chunk[0], RolloutInferRequest): + chunk = [asdict(req) for req in chunk] + response = self.sessions[i].post( f'{self.base_urls[i]}/infer/', json={ @@ -208,7 +213,7 @@ def init_communicator(self, device: Union[int, str] = 0): pg = StatelessProcessGroup.create( host=self.hosts[i], port=self.group_ports[i], rank=rank, world_size=world_size) - comm = PyNcclCommunicator(pg, device=0) + comm = PyNcclCommunicator(pg, device=device) self.pynccl_comms.append(comm) atexit.register(self.close_communicator) From a041aaa9baf2743eaa12dab1f75c4ea3da110bf1 Mon Sep 17 00:00:00 2001 From: Jintao Date: Sun, 16 Nov 2025 01:16:54 +0800 Subject: [PATCH 11/29] [bugfix] fix megatron train_iters (#6611) --- swift/megatron/trainers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 2b6d938cc4..999745acbf 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -87,7 +87,7 @@ def initialize_megatron(*_args, **kwargs): args = get_args() data_parallel_size = mpu.get_data_parallel_world_size() step_batch_size = args.micro_batch_size * data_parallel_size - num_generations = args.num_generations if hasattr(args, 'num_generations') else 1 + num_generations = args.num_generations if args.rlhf_type == 'grpo' else 1 if args.train_iters is None and args.max_epochs is not None: if hasattr(train_dataset, '__len__'): dataset_sample = len(train_dataset) // step_batch_size * step_batch_size From 36658fee9fdc11628a94c0800bedc9776ec72478 Mon Sep 17 00:00:00 2001 From: Jintao Date: Sun, 16 Nov 2025 13:01:03 +0800 Subject: [PATCH 12/29] [bugfix] fix modelscope patch_hub (#6612) --- swift/megatron/init.py | 1 + 1 file changed, 1 insertion(+) diff --git a/swift/megatron/init.py b/swift/megatron/init.py index fcf602ed00..d2b7b7cb95 100644 --- a/swift/megatron/init.py +++ b/swift/megatron/init.py @@ -673,6 +673,7 @@ def _apply_rotary_pos_emb_thd( def _patch_megatron(): + os.environ.pop('VLLM_USE_MODELSCOPE', None) logging_level = logging.root.level _patch_flash_attn() _patch_transformer_engine() From 663d296da3767d50942cce199d82b82445665c07 Mon Sep 17 00:00:00 2001 From: Jintao Date: Sun, 16 Nov 2025 13:56:33 +0800 Subject: [PATCH 13/29] [template] support add_eos (#6613) --- swift/llm/template/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 676158ff83..5993dfee0c 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1139,7 +1139,11 @@ def _swift_encode(self, inputs: StdTemplateInputs): if isinstance(stop_word, str)) # self.is_training needed because we may want to continue generation from # the current response - if (self.is_training or self.task_type != 'causal_lm') and not sep_token and not endswith_stop_words: + add_eos = inputs.extra_kwargs.get('add_eos') + if add_eos is None: + add_eos = (self.is_training + or self.task_type != 'causal_lm') and not sep_token and not endswith_stop_words + if add_eos: extra_context_list = template_meta.suffix extra_context_type = ContextType.SUFFIX elif template_meta.response_prefix: From 04d012edf4e086a5289e4ac4c5c5ded8e0b01de9 Mon Sep 17 00:00:00 2001 From: Jintao Date: Sun, 16 Nov 2025 21:35:24 +0800 Subject: [PATCH 14/29] [dataset] refactor cached_dataset (#6561) --- .../Instruction/Command-line-parameters.md | 10 +-- .../Megatron-SWIFT/Command-line-parameters.md | 2 - .../Instruction/Command-line-parameters.md | 10 +-- .../Megatron-SWIFT/Command-line-parameters.md | 2 - examples/export/cached_dataset/dpo.sh | 69 +++++++++++++++++++ examples/export/cached_dataset/mcore.sh | 20 ++++-- examples/export/cached_dataset/pretrained.sh | 3 +- examples/export/cached_dataset/sft.sh | 3 +- examples/export/cached_dataset/vlm.sh | 9 +-- examples/megatron/moe/qwen3_moe_offload.sh | 4 +- examples/megatron/rlhf/dpo/moe.sh | 1 - swift/llm/argument/base_args/base_args.py | 11 ++- swift/llm/argument/base_args/data_args.py | 14 ++++ swift/llm/argument/deploy_args.py | 3 - swift/llm/argument/export_args.py | 1 + swift/llm/argument/infer_args.py | 12 +--- swift/llm/argument/train_args.py | 2 +- swift/llm/dataset/__init__.py | 3 +- swift/llm/dataset/utils.py | 17 +++-- swift/llm/export/cached_dataset.py | 9 ++- swift/llm/infer/__init__.py | 4 +- swift/llm/infer/infer.py | 14 +++- swift/llm/infer/utils.py | 22 ++++++ swift/llm/template/base.py | 29 +++++--- swift/llm/train/rlhf.py | 4 +- swift/llm/train/sft.py | 31 +++------ swift/megatron/argument/megatron_args.py | 2 - swift/megatron/argument/train_args.py | 2 - swift/megatron/model/mm_gpt_model.py | 7 +- 29 files changed, 215 insertions(+), 105 deletions(-) create mode 100644 examples/export/cached_dataset/dpo.sh diff --git a/docs/source/Instruction/Command-line-parameters.md b/docs/source/Instruction/Command-line-parameters.md index 2a2f9db2d4..5af5b0a894 100644 --- a/docs/source/Instruction/Command-line-parameters.md +++ b/docs/source/Instruction/Command-line-parameters.md @@ -56,6 +56,9 @@ - 子数据集: 该参数只有当dataset为ID或者文件夹时生效。若注册时指定了subsets,且只有一个子数据集,则默认选择注册时指定的子数据集,否则默认为'default'。你可以使用`/`来选择多个子数据集,例如:`:subset1/subset2`。你也可以使用'all'来选择注册时指定的所有子数据集,例如:`:all`。注册例子可以参考[这里](https://modelscope.cn/datasets/swift/garbage_competition)。 - 采样数量: 默认使用完整的数据集。你可以通过设置`#采样数`对选择的数据集进行采样。若采样数少于数据样本总数,则进行随机选择(不重复采样)。若采样数高于数据样本总数,则只额外随机采样`采样数%数据样本总数`的样本,数据样本重复采样`采样数//数据样本总数`次。注意:流式数据集(`--streaming true`)只进行顺序采样。若设置`--dataset_shuffle false`,则非流式数据集也进行顺序采样。 - 🔥val_dataset: 验证集id或路径的list。默认为`[]`。 +- 🔥cached_dataset: 使用缓存数据集(使用`swift export --to_cached_dataset true ...`命令产生),避免大型数据集训练/推理时,tokenize过程占用gpu时间。默认为`[]`。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。 + - 提示:在"ms-swift>=3.11",cached_dataset只会在数据集中额外存储length字段(为避免存储压力),并过滤掉会报错的数据样本。在训练/推理时,支持`--max_length`参数进行超长数据过滤/裁剪以及`--packing`参数。数据实际预处理过程将在训练时同步进行,该过程和训练是重叠的,并不会影响训练速度。 + - cached_dataset在`ms-swift`和`Megatron-SWIFT`之间是通用的,且支持pt/sft/infer/rlhf(需"ms-swift>=3.11")。 - 🔥split_dataset_ratio: 不指定val_dataset时从训练集拆分验证集的比例,默认为0.,即不从训练集切分验证集。 - 注意:该参数在"ms-swift<3.6"的默认值为0.01。 - data_seed: 数据集随机种子,默认为42。 @@ -450,8 +453,6 @@ Vera使用`target_modules`、`target_regex`、`modules_to_save`三个参数, - packing_num_proc: packing的进程数,默认为1。需要注意的是,不同的`packing_num_proc`,最终形成的packed数据集是不同的。(该参数在流式packing时不生效) - lazy_tokenize: 是否使用lazy_tokenize。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(多模态模型则包括从磁盘中读取图片)。该参数默认为None,在LLM训练中默认为False,而MLLM训练默认为True,节约内存。 - 注意:若你要进行图像的数据增强,你需要将lazy_tokenize(或streaming)设置为True,并修改Template类中的encode方法。 -- cached_dataset: 训练中使用缓存数据集(使用`swift export --to_cached_dataset true ...`命令产生),避免大型数据集训练时,tokenize过程占用gpu时间。默认为`[]`。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。 - - 注意:cached_dataset支持`--packing`,但不支持`--lazy_tokenize`和`--streaming`。 - use_logits_to_keep: 通过在`forward`中根据labels传入logits_to_keep,减少无效logits的计算与存储,从而减少显存占用并加快训练速度。默认为None,进行自动选择。 - acc_strategy: 训练和验证时计算acc的策略。可选为`seq`和`token`级别的acc,默认为`token`。 - max_new_tokens: 覆盖生成参数。predict_with_generate=True时的最大生成token数量,默认64。 @@ -700,8 +701,9 @@ App参数继承于[部署参数](#部署参数), [Web-UI参数](#Web-UI参数) - max_length: 校准集的max_length, 默认值2048。 - quant_batch_size: 量化batch_size,默认为1。 - group_size: 量化group大小,默认为128。 -- to_cached_dataset: 提前对数据集进行tokenize并导出,默认为False。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。 - - 注意:数据packing在训练时进行,而不在此处。 +- to_cached_dataset: 提前对数据集进行tokenize并导出,默认为False。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。更多介绍请查看`cached_dataset`。 + - 提示:cached_dataset需提前区分好训练集和验证集。你可以通过`--split_dataset_ratio`或者`--val_dataset`指定验证集内容。 +- template_mode: 用于支持对`swift rlhf`训练的`cached_dataset`功能。该参数只在`--to_cached_dataset true`时生效。可选项包括: 'train'、'rlhf'和'kto'。其中`swift pt/sft`使用'train',`swift rlhf --rlhf_type kto`使用'kto',其他rlhf算法使用'rlhf'。注意:当前'gkd', 'ppo', 'grpo'算法不支持`cached_dataset`功能。默认为'train'。 - to_ollama: 产生ollama所需的Modelfile文件。默认为False。 - 🔥to_mcore: HF格式权重转成Megatron格式。默认为False。 - to_hf: Megatron格式权重转成HF格式。默认为False。 diff --git a/docs/source/Megatron-SWIFT/Command-line-parameters.md b/docs/source/Megatron-SWIFT/Command-line-parameters.md index 5c75aa28c0..f80c6e1812 100644 --- a/docs/source/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source/Megatron-SWIFT/Command-line-parameters.md @@ -282,8 +282,6 @@ Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用 - 注意:因为流式数据集无法获得其长度,因此需要设置`--train_iters`参数。设置`max_epochs`参数确保训练到对应epochs时退出训练,并对权重进行验证和保存。 - 注意:流式数据集可以跳过预处理等待,将预处理时间与训练时间重叠。流式数据集的预处理只在rank0上进行,并通过数据分发的方式同步到其他进程,**其通常效率不如非流式数据集采用的数据分片读取方式**。当训练的world_size较大时,预处理和数据分发将成为训练瓶颈。 - lazy_tokenize: 是否使用lazy_tokenize。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(多模态模型则包括从磁盘中读取图片)。该参数默认为None,在LLM训练中默认为False,而MLLM训练默认为True,节约内存。 -- cached_dataset: 训练中使用缓存数据集(使用`swift export --to_cached_dataset true ...`命令产生),避免大型数据集训练时,tokenize过程占用gpu时间。默认为`[]`。例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset)。 - - 注意:cached_dataset支持`--packing`,但不支持`--lazy_tokenize`和`--streaming`。cached_dataset暂不支持CP。 - enable_dft_loss: 是否在SFT训练中使用[DFT](https://arxiv.org/abs/2508.05629) (Dynamic Fine-Tuning) loss,默认为False。 - enable_channel_loss: 启用channel loss,默认为`False`。你需要在数据集中准备"channel"字段,ms-swift会根据该字段分组统计loss(若未准备"channel"字段,则归为默认`None` channel)。数据集格式参考[channel loss](../Customization/Custom-dataset.md#channel-loss)。channel loss兼容packing/padding_free/loss_scale等技术。 - new_special_tokens: 需要新增的特殊tokens。默认为`[]`。例子参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/lora/new_special_tokens.sh)。 diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index c7dd3d865a..570ffd9717 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -55,6 +55,9 @@ The command-line arguments will be introduced in four categories: basic argument - Subset: This parameter is only effective when the dataset is a dataset ID or a folder. If subsets were specified during registration and only one exists, that subset is selected by default; otherwise, the default subset `'default'` is used. You can select multiple subsets using `/`, e.g., `:subset1/subset2`. You can also use `'all'` to select all registered subsets, e.g., `:all`. See an example of registration [here](https://modelscope.cn/datasets/swift/garbage_competition). - Sampling count: By default, the full dataset is used. You can sample the dataset by specifying `#sample_count`. If the sample count is less than the total number of samples, random sampling without replacement is performed. If the sample count exceeds the total, the dataset is repeated `sample_count // total_samples` times, with an additional `sample_count % total_samples` samples randomly sampled. Note: For streaming datasets (`--streaming true`), only sequential sampling is performed. If `--dataset_shuffle false` is set, non-streaming datasets also use sequential sampling. - 🔥val_dataset: A list of validation dataset IDs or paths. Default is `[]`. +- 🔥cached_dataset: Use cached dataset (generated using `swift export --to_cached_dataset true ...` command) to avoid GPU time consumed by the tokenization process during large dataset training/inference. Default is `[]`. For examples, refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). + - Note: In "ms-swift>=3.11", cached_dataset only stores an additional length field in the dataset (to avoid storage pressure) and filters out data samples that would cause errors. During training/inference, the `--max_length` parameter is supported for filtering/truncating excessively long data and the `--packing` parameter is supported. The actual data preprocessing process occurs synchronously during training and overlaps with the training process, which does not affect training speed. + - cached_dataset is compatible between `ms-swift` and `Megatron-SWIFT`, and supports pt/sft/infer/rlhf (requires "ms-swift>=3.11"). - 🔥split_dataset_ratio: The ratio for splitting a validation set from the training set when `val_dataset` is not specified. Default is `0.`, meaning no splitting occurs. - Note: In "ms-swift<3.6", the default value was `0.01`. - data_seed: Random seed for dataset operations. Default is `42`. @@ -458,8 +461,6 @@ Training arguments include the [base arguments](#base-arguments), [Seq2SeqTraine - packing_num_proc: Number of processes for packing, default is 1. Note that different values of `packing_num_proc` will result in different packed datasets. (This parameter does not take effect during streaming packing) - lazy_tokenize: Whether to use lazy tokenization. If set to `False`, all dataset samples will be tokenized (and for multimodal models, images will be loaded from disk) before training begins. Default is `None`: in LLM training, it defaults to `False`; in MLLM training, it defaults to `True` to save memory. - Note: If you want to perform image data augmentation, you need to set `lazy_tokenize` (or `streaming`) to True and modify the `encode` method in the Template class. -- cached_dataset: Use a cached dataset (generated with `swift export --to_cached_dataset true ...`) during training to avoid GPU time spent on tokenizing large datasets. Default is `[]`. Example: [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). - - Note: cached_dataset supports `--packing` but does not support `--lazy_tokenize` or `--streaming`. - use_logits_to_keep: Pass `logits_to_keep` in the `forward` method based on labels to reduce the computation and storage of unnecessary logits, thereby reducing memory usage and accelerating training. The default is `None`, which enables automatic selection. - acc_strategy: Strategy for calculating accuracy during training and validation. Options are `seq`-level and `token`-level accuracy, with `token` as the default. - max_new_tokens: Generation parameter override. The maximum number of tokens to generate when `predict_with_generate=True`, defaulting to 64. @@ -718,8 +719,9 @@ Export Arguments include the [basic arguments](#base-arguments) and [merge argum - max_length: Max length for the calibration set, default value is 2048. - quant_batch_size: Quantization batch size, default is 1. - group_size: Group size for quantization, default is 128. -- to_cached_dataset: pre-tokenize the dataset and export it in advance, default is False. See the example [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). - - Note: data packing is performed during training, not in this step. +- to_cached_dataset: pre-tokenize the dataset and export it in advance, default is False. See the example [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). For more information, please refer to cached_dataset. + - Note: cached_dataset requires the training set and validation set to be distinguished in advance. You can specify the validation set content through `--split_dataset_ratio` or `--val_dataset`. +- template_mode: Used to support the `cached_dataset` feature for `swift rlhf` training. This parameter only takes effect when `--to_cached_dataset true` is set. Available options include: 'train', 'rlhf', and 'kto'. Among them, `swift pt/sft` uses 'train', `swift rlhf --rlhf_type kto` uses 'kto', and other rlhf algorithms use 'rlhf'. Note: Currently, 'gkd', 'ppo', and 'grpo' algorithms do not support the `cached_dataset` feature. Default is 'train'. - to_ollama: Generate the Modelfile required by Ollama. Default is False. - 🔥to_mcore: Convert weights from HF format to Megatron format. Default is False. - to_hf: Convert weights from Megatron format to HF format. Default is False. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 446916c1f7..8094db9320 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -300,8 +300,6 @@ Megatron training parameters are inherited from Megatron parameters and basic pa - Note: Since the length of a streaming dataset cannot be determined, the `--train_iters` parameter must be set. Also set the `max_epochs` parameter to ensure training exits after the specified number of epochs, and to validate and save the model weights accordingly. - Note: Streaming datasets can skip preprocessing wait time by overlapping preprocessing with training. Preprocessing for streaming datasets is performed only on rank 0 and then synchronized to other processes via data distribution. **This is generally less efficient than the data sharding approach used in non-streaming datasets.** When the training world_size is large, preprocessing and data distribution can become a training bottleneck. - lazy_tokenize: Whether to use lazy tokenization. If set to `False`, all dataset samples will be tokenized (and for multimodal models, images will be loaded from disk) before training begins. Default is `None`: in LLM training, it defaults to `False`; in MLLM training, it defaults to `True` to save memory. -- cached_dataset: Use a cached dataset (generated with `swift export --to_cached_dataset true ...`) during training to avoid GPU time spent on tokenizing large datasets. Default is `[]`. Example: [here](https://github.com/modelscope/ms-swift/tree/main/examples/export/cached_dataset). - - Note: cached_dataset supports `--packing` but does not support `--lazy_tokenize` or `--streaming`. Cached dataset is currently not supported for CP. - enable_dft_loss: Whether to use [DFT](https://arxiv.org/abs/2508.05629) (Dynamic Fine-Tuning) loss in SFT training, default is False. - enable_channel_loss: Enable channel-based loss. Default is `False`. Requires a `"channel"` field in the dataset. ms-swift groups and computes loss by this field (samples without `"channel"` are grouped into the default `None` channel). Dataset format reference: [channel loss](../Customization/Custom-dataset.md#channel-loss). Channel loss is compatible with packing, padding_free, and loss_scale techniques. - new_special_tokens: List of additional special tokens to be added. Default is `[]`. Example usage can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/lora/new_special_tokens.sh). diff --git a/examples/export/cached_dataset/dpo.sh b/examples/export/cached_dataset/dpo.sh new file mode 100644 index 0000000000..8964e3bfb5 --- /dev/null +++ b/examples/export/cached_dataset/dpo.sh @@ -0,0 +1,69 @@ +# ms-swift>=3.11 +OMP_NUM_THREADS=14 \ +IMAGE_MAX_TOKEN_NUM=1024 \ +VIDEO_MAX_TOKEN_NUM=128 \ +FPS_MAX_FRAMES=16 \ +swift export \ + --model Qwen/Qwen3-VL-30B-A3B-Instruct \ + --dataset swift/RLAIF-V-Dataset \ + --split_dataset_ratio 0.01 \ + --dataset_num_proc 8 \ + --to_cached_dataset true \ + --template_mode rlhf \ + --output_dir ./qwen3_vl_cached_dataset + + +# 16s/it; 8 * 65GiB +PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ +NPROC_PER_NODE=8 \ +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ +IMAGE_MAX_TOKEN_NUM=1024 \ +VIDEO_MAX_TOKEN_NUM=128 \ +FPS_MAX_FRAMES=16 \ +megatron rlhf \ + --rlhf_type dpo \ + --model Qwen/Qwen3-VL-30B-A3B-Instruct \ + --load_safetensors true \ + --save_safetensors true \ + --cached_dataset qwen3_vl_cached_dataset \ + --load_from_cache_file true \ + --train_type full \ + --tensor_model_parallel_size 4 \ + --expert_tensor_parallel_size 1 \ + --pipeline_model_parallel_size 2 \ + --expert_model_parallel_size 4 \ + --moe_permute_fusion true \ + --moe_grouped_gemm true \ + --moe_shared_expert_overlap true \ + --moe_aux_loss_coeff 1e-6 \ + --micro_batch_size 1 \ + --global_batch_size 4 \ + --packing true \ + --recompute_granularity full \ + --recompute_method uniform \ + --recompute_num_layers 1 \ + --finetune true \ + --cross_entropy_loss_fusion true \ + --lr 1e-5 \ + --lr_warmup_fraction 0.05 \ + --min_lr 1e-6 \ + --save megatron_output/Qwen3-VL-30B-A3B-Instruct \ + --eval_interval 500 \ + --save_interval 500 \ + --max_length 8192 \ + --max_epochs 1 \ + --num_workers 8 \ + --dataset_num_proc 8 \ + --no_save_optim true \ + --no_save_rng true \ + --sequence_parallel true \ + --freeze_llm false \ + --freeze_vit true \ + --freeze_aligner true \ + --optimizer_cpu_offload true \ + --use_precision_aware_optimizer true \ + --optimizer_offload_fraction 0.4 \ + --attention_backend flash \ + --rpo_alpha 0.1 \ + --beta 0.1 \ + --loss_type sigmoid diff --git a/examples/export/cached_dataset/mcore.sh b/examples/export/cached_dataset/mcore.sh index d78d2d4f2c..bcbf1f1e33 100644 --- a/examples/export/cached_dataset/mcore.sh +++ b/examples/export/cached_dataset/mcore.sh @@ -1,8 +1,7 @@ -# Note: cached_dataset does not support CP temporarily. +# ms-swift>=3.11 swift export \ --model Qwen/Qwen3-30B-A3B-Base \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT' \ - --max_length 8192 \ --split_dataset_ratio 0.01 \ --dataset_num_proc 64 \ --to_cached_dataset true \ @@ -14,18 +13,20 @@ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 \ megatron sft \ - --load Qwen3-30B-A3B-Base-mcore \ + --model Qwen/Qwen3-30B-A3B-Base \ + --load_safetensors true \ + --save_safetensors true \ + --merge_lora false \ --cached_dataset './qwen3_cached_dataset' \ --train_type lora \ --lora_rank 32 \ --lora_alpha 64 \ --target_modules all-linear \ - --split_dataset_ratio 0.01 \ --moe_permute_fusion true \ --expert_model_parallel_size 4 \ --moe_grouped_gemm true \ --moe_shared_expert_overlap true \ - --moe_aux_loss_coeff 1e-3 \ + --moe_aux_loss_coeff 1e-6 \ --micro_batch_size 1 \ --global_batch_size 16 \ --recompute_granularity full \ @@ -48,3 +49,12 @@ megatron sft \ --no_save_rng true \ --sequence_parallel true \ --attention_backend flash + + +CUDA_VISIBLE_DEVICES=0 \ +swift infer \ + --adapters megatron_output/Qwen3-30B-A3B-Base/vx-xxx/checkpoint-xxx \ + --load_data_args true \ + --attn_impl flash_attn \ + --stream true \ + --max_new_tokens 512 diff --git a/examples/export/cached_dataset/pretrained.sh b/examples/export/cached_dataset/pretrained.sh index 284cad980d..1f22a0030d 100644 --- a/examples/export/cached_dataset/pretrained.sh +++ b/examples/export/cached_dataset/pretrained.sh @@ -1,7 +1,7 @@ +# ms-swift>=3.11 swift export \ --model Qwen/Qwen2.5-7B \ --dataset 'AI-ModelScope/ruozhiba:all' \ - --max_length 8192 \ --dataset_num_proc 64 \ --to_cached_dataset true \ --split_dataset_ratio 0.01 \ @@ -17,7 +17,6 @@ swift pt \ --train_type full \ --cached_dataset './pretrain_cached_dataset' \ --num_train_epochs 3 \ - --split_dataset_ratio 0.01 \ --torch_dtype bfloat16 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ diff --git a/examples/export/cached_dataset/sft.sh b/examples/export/cached_dataset/sft.sh index c4928ebde7..d6ae46fda7 100644 --- a/examples/export/cached_dataset/sft.sh +++ b/examples/export/cached_dataset/sft.sh @@ -1,7 +1,7 @@ +# ms-swift>=3.11 swift export \ --model Qwen/Qwen2.5-7B \ --dataset 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT' \ - --max_length 8192 \ --dataset_num_proc 64 \ --split_dataset_ratio 0.01 \ --to_cached_dataset true \ @@ -16,7 +16,6 @@ swift sft \ --train_type full \ --cached_dataset './qwen2_5_cached_dataset' \ --num_train_epochs 3 \ - --split_dataset_ratio 0.01 \ --torch_dtype bfloat16 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ diff --git a/examples/export/cached_dataset/vlm.sh b/examples/export/cached_dataset/vlm.sh index 9f00199ed9..20b4851f93 100644 --- a/examples/export/cached_dataset/vlm.sh +++ b/examples/export/cached_dataset/vlm.sh @@ -1,3 +1,4 @@ +# ms-swift>=3.11 OMP_NUM_THREADS=14 \ MAX_PIXELS=1003520 \ VIDEO_MAX_PIXELS=50176 \ @@ -7,11 +8,9 @@ swift export \ --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#10000' \ 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ 'speech_asr/speech_asr_aishell1_trainsets:validation#5000' \ - --max_length 4096 \ --split_dataset_ratio 0.01 \ --dataset_num_proc 16 \ --to_cached_dataset true \ - --lazy_tokenize false \ --output_dir ./qwen2_5_omni_cached_dataset # 4 * 70GiB @@ -27,7 +26,6 @@ swift sft \ --train_type full \ --cached_dataset './qwen2_5_omni_cached_dataset' \ --num_train_epochs 1 \ - --split_dataset_ratio 0.01 \ --torch_dtype bfloat16 \ --per_device_train_batch_size 1 \ --per_device_eval_batch_size 1 \ @@ -59,11 +57,8 @@ FPS_MAX_FRAMES=12 \ ENABLE_AUDIO_OUTPUT=0 \ swift infer \ --model output/Qwen2.5-Omni-7B/vx-xxx/checkpoint-xxx \ - --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#10000' \ - 'AI-ModelScope/LaTeX_OCR:human_handwrite#5000' \ - 'speech_asr/speech_asr_aishell1_trainsets:validation#5000' \ + --load_data_args true \ --max_length 4096 \ - --split_dataset_ratio 0.01 \ --attn_impl flash_attn \ --stream true \ --temperature 0 \ diff --git a/examples/megatron/moe/qwen3_moe_offload.sh b/examples/megatron/moe/qwen3_moe_offload.sh index 30e43bba9a..97b81fc130 100644 --- a/examples/megatron/moe/qwen3_moe_offload.sh +++ b/examples/megatron/moe/qwen3_moe_offload.sh @@ -3,7 +3,9 @@ PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=4 \ CUDA_VISIBLE_DEVICES=0,1,2,3 \ megatron sft \ - --load Qwen3-30B-A3B-Base-mcore \ + --model Qwen/Qwen3-30B-A3B-Base \ + --load_safetensors true \ + --save_safetensors true \ --dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \ --load_from_cache_file true \ --split_dataset_ratio 0.01 \ diff --git a/examples/megatron/rlhf/dpo/moe.sh b/examples/megatron/rlhf/dpo/moe.sh index 318fee9c3e..d82575da05 100644 --- a/examples/megatron/rlhf/dpo/moe.sh +++ b/examples/megatron/rlhf/dpo/moe.sh @@ -1,5 +1,4 @@ # 8 * 46GiB; 13s/it -# Note: "ms-swift<3.8" does not support DPO packing; please remove --packing true. PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \ NPROC_PER_NODE=8 \ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py index 0695f4d68c..1054497e4a 100644 --- a/swift/llm/argument/base_args/base_args.py +++ b/swift/llm/argument/base_args/base_args.py @@ -86,7 +86,6 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat packing_length: Optional[int] = None packing_num_proc: int = 1 lazy_tokenize: Optional[bool] = None - cached_dataset: List[str] = field(default_factory=list) custom_register_path: List[str] = field(default_factory=list) # .py # hub use_hf: bool = False @@ -106,8 +105,10 @@ def _prepare_training_args(self, training_args: Dict[str, Any]) -> None: def _init_lazy_tokenize(self): if self.lazy_tokenize is None: - if (self.model_meta is not None and self.model_meta.is_multimodal and not self.streaming - and not self.packing): + if self.cached_dataset: + self.lazy_tokenize = False + elif (self.model_meta is not None and self.model_meta.is_multimodal and not self.streaming + and not self.packing): self.lazy_tokenize = True else: self.lazy_tokenize = False @@ -162,7 +163,6 @@ def __post_init__(self): self._init_custom_register() self._import_external_plugins() self._init_model_kwargs() - self._init_stream() # The Seq2SeqTrainingArguments has a property called world_size, which cannot be assigned a value. self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting() logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, ' @@ -176,12 +176,11 @@ def __post_init__(self): TemplateArguments.__post_init__(self) DataArguments.__post_init__(self) RayArguments.__post_init__(self) + self._init_stream() if self.max_length is None and self.model_info is not None: self.max_length = self.model_info.max_model_len if self.packing and self.packing_length is None: self.packing_length = self.max_length - if isinstance(self.cached_dataset, str): - self.cached_dataset = [self.cached_dataset] self._init_lazy_tokenize() self.hub = get_hub(self.use_hf) if self.hub.try_login(self.hub_token): diff --git a/swift/llm/argument/base_args/data_args.py b/swift/llm/argument/base_args/data_args.py index 2e706203ac..c03c15f576 100644 --- a/swift/llm/argument/base_args/data_args.py +++ b/swift/llm/argument/base_args/data_args.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os from dataclasses import dataclass, field from typing import List, Literal, Optional, Union @@ -29,6 +30,7 @@ class DataArguments: # dataset_id or dataset_dir or dataset_path dataset: List[str] = field(default_factory=list) val_dataset: List[str] = field(default_factory=list) + cached_dataset: List[str] = field(default_factory=list) split_dataset_ratio: float = 0. data_seed: int = 42 @@ -68,6 +70,18 @@ def __post_init__(self): msg = 'args.streaming is True' logger.info(f'Because {msg}, setting split_dataset_ratio: {self.split_dataset_ratio}') self._init_custom_dataset_info() + if isinstance(self.cached_dataset, str): + self.cached_dataset = [self.cached_dataset] + self._init_val_dataset_exists() + + def _init_val_dataset_exists(self): + exists = self.dataset and self.split_dataset_ratio > 0 or self.val_dataset + if not exists and self.cached_dataset: + for dataset in self.cached_dataset: + if os.path.exists(os.path.join(dataset, 'val')): + exists = True + break + self._val_dataset_exists = exists def get_dataset_kwargs(self): return { diff --git a/swift/llm/argument/deploy_args.py b/swift/llm/argument/deploy_args.py index 81c12b5775..2cff455662 100644 --- a/swift/llm/argument/deploy_args.py +++ b/swift/llm/argument/deploy_args.py @@ -70,9 +70,6 @@ def _init_ckpt_dir(self, adapters=None): def _init_stream(self): return BaseArguments._init_stream(self) - def _init_eval_human(self): - pass - def _init_result_path(self, folder_name: str) -> None: if folder_name == 'infer_result': folder_name = 'deploy_result' diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index f4cf8d5fc1..921163e4ae 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -43,6 +43,7 @@ class ExportArguments(MergeArguments, BaseArguments): # cached_dataset to_cached_dataset: bool = False + template_mode: Literal['train', 'rlhf', 'kto'] = 'train' # ollama to_ollama: bool = False diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py index a7c00174f1..cbf746d4dc 100644 --- a/swift/llm/argument/infer_args.py +++ b/swift/llm/argument/infer_args.py @@ -127,7 +127,8 @@ def _init_result_path(self, folder_name: str) -> None: logger.info(f'args.result_path: {self.result_path}') def _init_stream(self): - self.eval_human = not (self.dataset and self.split_dataset_ratio > 0 or self.val_dataset) + self.eval_human = not self._val_dataset_exists + logger.info(f'Setting args.eval_human: {self.eval_human}') if self.stream is None: self.stream = self.eval_human if self.stream and self.num_beams != 1: @@ -148,13 +149,4 @@ def __post_init__(self) -> None: BaseArguments.__post_init__(self) VllmArguments.__post_init__(self) self._init_result_path('infer_result') - self._init_eval_human() self._init_ddp() - - def _init_eval_human(self): - if len(self.dataset) == 0 and len(self.val_dataset) == 0: - eval_human = True - else: - eval_human = False - self.eval_human = eval_human - logger.info(f'Setting args.eval_human: {self.eval_human}') diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py index 776a58091b..bcd439b8ce 100644 --- a/swift/llm/argument/train_args.py +++ b/swift/llm/argument/train_args.py @@ -170,7 +170,7 @@ def __post_init__(self) -> None: if getattr(self, 'accelerator_config', None) is None: self.accelerator_config = {'dispatch_batches': False} - if self.split_dataset_ratio == 0 and not self.val_dataset and not self.eval_dataset: + if not (self.eval_dataset or self._val_dataset_exists): self.eval_strategy = 'no' self.training_args = TrainerFactory.get_training_args(self) self.training_args.remove_unused_columns = False diff --git a/swift/llm/dataset/__init__.py b/swift/llm/dataset/__init__.py index c397430b95..1c989f482f 100644 --- a/swift/llm/dataset/__init__.py +++ b/swift/llm/dataset/__init__.py @@ -9,7 +9,8 @@ from .preprocessor import (AlpacaPreprocessor, AutoPreprocessor, MessagesPreprocessor, ResponsePreprocessor, RowPreprocessor) from .register import DATASET_MAPPING, DatasetMeta, SubsetDataset, register_dataset, register_dataset_info -from .utils import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, sample_dataset +from .utils import (AddLengthPreprocessor, EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, + sample_dataset) datasets.fingerprint.get_temporary_cache_files_directory = get_temporary_cache_files_directory datasets.arrow_dataset.get_temporary_cache_files_directory = get_temporary_cache_files_directory diff --git a/swift/llm/dataset/utils.py b/swift/llm/dataset/utils.py index a0ccd8338b..159ccf8953 100644 --- a/swift/llm/dataset/utils.py +++ b/swift/llm/dataset/utils.py @@ -309,14 +309,17 @@ def __iter__(self): class EncodePreprocessor(RowPreprocessor): - def __init__(self, template: 'Template', pre_tokenize: bool = False): + def __init__(self, template: 'Template'): super().__init__() self.template = template - self.pre_tokenize = pre_tokenize def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: - encoded = self.template.encode(row, return_length=True) - if self.pre_tokenize: - row['length'] = encoded['length'] - encoded = row - return encoded + return self.template.encode(row, return_length=True) + + +class AddLengthPreprocessor(EncodePreprocessor): + + def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]: + encoded = super().preprocess(row) + row['length'] = encoded['length'] + return row diff --git a/swift/llm/export/cached_dataset.py b/swift/llm/export/cached_dataset.py index 2bb780c45f..a9ec0e3a62 100644 --- a/swift/llm/export/cached_dataset.py +++ b/swift/llm/export/cached_dataset.py @@ -17,6 +17,7 @@ class ExportCachedDataset(SwiftSft): def __init__(self, args: Optional[Union[List[str], ExportArguments]] = None) -> None: super(SwiftSft, self).__init__(args) + args = self.args self.train_msg = {} # dummy template_cls = TEMPLATE_MAPPING[args.template].template_cls if template_cls and template_cls.use_model: @@ -26,11 +27,13 @@ def __init__(self, args: Optional[Union[List[str], ExportArguments]] = None) -> with torch.device('meta'): self._prepare_model_tokenizer(**kwargs) self._prepare_template() + self.template.set_mode(args.template_mode) + + def _post_process_datasets(self, datasets: List) -> List: + return datasets def main(self): - train_dataset, val_dataset = self._get_dataset() - train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset) - self._show_dataset(train_dataset, val_dataset) + train_dataset, val_dataset = self._prepare_dataset() train_dataset.save_to_disk(os.path.join(self.args.output_dir, 'train')) if val_dataset is not None: val_dataset.save_to_disk(os.path.join(self.args.output_dir, 'val')) diff --git a/swift/llm/infer/__init__.py b/swift/llm/infer/__init__.py index dceb30bac9..b9e7bd0be3 100644 --- a/swift/llm/infer/__init__.py +++ b/swift/llm/infer/__init__.py @@ -8,7 +8,7 @@ from .rollout import rollout_main from .deploy import deploy_main, SwiftDeploy, run_deploy from .protocol import RequestConfig, Function - from .utils import prepare_model_template + from .utils import prepare_model_template, get_cached_dataset from .infer_engine import (InferEngine, VllmEngine, LmdeployEngine, SglangEngine, PtEngine, InferClient, prepare_generation_config, AdapterRequest, BaseInferEngine) else: @@ -17,7 +17,7 @@ 'infer': ['infer_main', 'SwiftInfer'], 'deploy': ['deploy_main', 'SwiftDeploy', 'run_deploy'], 'protocol': ['RequestConfig', 'Function'], - 'utils': ['prepare_model_template'], + 'utils': ['prepare_model_template', 'get_cached_dataset'], 'infer_engine': [ 'InferEngine', 'VllmEngine', 'LmdeployEngine', 'SglangEngine', 'PtEngine', 'InferClient', 'prepare_generation_config', 'AdapterRequest', 'BaseInferEngine' diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py index 639b76c846..c28deb877c 100644 --- a/swift/llm/infer/infer.py +++ b/swift/llm/infer/infer.py @@ -8,9 +8,10 @@ from swift.llm import InferArguments, InferRequest, SwiftPipeline, load_dataset, prepare_model_template, sample_dataset from swift.plugin import InferStats, MeanMetric, compute_rouge_bleu from swift.utils import JsonlWriter, get_dist_setting, get_logger, is_dist, is_master, read_from_jsonl +from ..dataset.loader import DatasetLoader from .infer_engine import AdapterRequest, PtEngine from .protocol import RequestConfig -from .utils import InferCliState +from .utils import InferCliState, get_cached_dataset logger = get_logger() @@ -178,16 +179,23 @@ def infer_cli(self) -> List[Dict[str, Any]]: def _prepare_val_dataset(self) -> HfDataset: args = self.args dataset_kwargs = args.get_dataset_kwargs() + if args.cached_dataset: + _, val_datasets = get_cached_dataset(self.args) + else: + val_datasets = [] if len(args.val_dataset) > 0: _, val_dataset = load_dataset( args.val_dataset, split_dataset_ratio=1.0, shuffle=args.val_dataset_shuffle, **dataset_kwargs) - else: + val_datasets.append(val_dataset) + elif args.dataset: _, val_dataset = load_dataset( args.dataset, split_dataset_ratio=args.split_dataset_ratio, shuffle=args.dataset_shuffle, **dataset_kwargs) - assert val_dataset is not None + val_datasets.append(val_dataset) + assert len(val_datasets) > 0 + val_dataset = DatasetLoader._concat_datasets(val_datasets) val_dataset = sample_dataset(val_dataset, args.val_dataset_sample, args.dataset_shuffle, self.random_state) return val_dataset diff --git a/swift/llm/infer/utils.py b/swift/llm/infer/utils.py index 49bc8b53a2..75659ed4ef 100644 --- a/swift/llm/infer/utils.py +++ b/swift/llm/infer/utils.py @@ -1,9 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os import re from copy import deepcopy from dataclasses import dataclass, field from typing import List, Literal, Optional +from datasets import load_from_disk + from swift.llm.utils import update_generation_config_eos_token from swift.plugin import extra_tuners from swift.tuners import Swift @@ -151,3 +154,22 @@ def prepare_model_template(args, **kwargs): model = prepare_adapter(args, model, adapters=adapters) update_generation_config_eos_token(model.generation_config, template) return model, template + + +def _select_dataset(dataset, max_length): + idxs = [i for i, length in enumerate(dataset['length']) if length <= max_length] + new_dataset = dataset.select(idxs) + if len(new_dataset) < len(dataset): + logger.info(f'Dataset filtered, origin length: {len(dataset)}, filtered dataset length: {len(new_dataset)}') + return new_dataset + + +def get_cached_dataset(args): + train_datasets, val_datasets = [], [] + for cached_dataset in args.cached_dataset: + train_path = os.path.join(cached_dataset, 'train') + val_path = os.path.join(cached_dataset, 'val') + train_datasets.append(_select_dataset(load_from_disk(train_path), args.max_length)) + if os.path.exists(val_path): + val_datasets.append(_select_dataset(load_from_disk(val_path), args.max_length)) + return train_datasets, val_datasets diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 5993dfee0c..acc295849d 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -91,7 +91,7 @@ def __init__( from .template_meta import TemplateMeta from swift.plugin import agent_templates, loss_scale_map self._processor_inited = False - self._version = 'v3' # Avoid compatibility issues caused by load_from_cache_file caching. + self._version = 'v4' # Avoid compatibility issues caused by load_from_cache_file caching. self.max_length = max_length self.model = None self.dummy_model = None @@ -1217,7 +1217,6 @@ def _encode_truncated(self, inputs: StdTemplateInputs): encoded[key] = value else: encoded = self._encode(inputs) - self._handle_megatron_cp(encoded) # TODO: fix cp_size & cached_dataset input_ids = encoded.get('input_ids') labels = encoded.get('labels') loss_scale = encoded.get('loss_scale') @@ -1280,18 +1279,25 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: encoded[k] = None return encoded - def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None: + def _get_megatron_cp_length(self, length) -> int: + cp_size = self.sequence_parallel_size + if not self.use_megatron or cp_size == 1: + return length + return math.ceil(length / (cp_size * 2)) * (cp_size * 2) + + def _handle_megatron_cp(self, batch: List[Dict[str, Any]]) -> None: cp_size = self.sequence_parallel_size if not self.use_megatron or cp_size == 1: return - if self.mode == 'vllm': # skip for megatron grpo rollout - return - input_ids = encoded['input_ids'] - padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids) - input_ids += [self.tokenizer.pad_token_id] * padding_len - encoded['labels'] += [-100] * padding_len - if encoded.get('loss_scale') is not None: - encoded['loss_scale'] += [0] * padding_len + for encoded in batch: + input_ids = encoded['input_ids'] + padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids) + input_ids += [self.tokenizer.pad_token_id] * padding_len + encoded['labels'] += [-100] * padding_len + if encoded.get('loss_scale') is not None: + encoded['loss_scale'] += [0] * padding_len + if encoded.get('length') is not None: + encoded['length'] += padding_len def debug_logger(self, inputs): if not strtobool(os.getenv('SWIFT_DEBUG', 'false')): @@ -1619,6 +1625,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in assert self.tokenizer.pad_token_id is not None padding_side = self.padding_side if self.is_training else 'left' padding_right = padding_side == 'right' + self._handle_megatron_cp(batch) if self.padding_free: batch[:] = [self.packing_row(batch)] assert 'position_ids' in batch[0], f'batch[0]: {batch[0]}' diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index d72514e458..5d16ebf4c7 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -177,8 +177,8 @@ def prepare_model(cls, args, model, *, template=None, train_dataset=None, task_t def _prepare_template(self) -> None: args = self.args super()._prepare_template() - model_mapping = {'kto': 'kto', 'gkd': 'gkd', 'ppo': 'pt', 'grpo': 'train'} - self.template.set_mode(model_mapping.get(args.rlhf_type, 'rlhf')) + mode_mapping = {'kto': 'kto', 'gkd': 'gkd', 'ppo': 'pt', 'grpo': 'train'} + self.template.set_mode(mode_mapping.get(args.rlhf_type, 'rlhf')) if args.rlhf_type == 'ppo': args.training_args.stop_token_id = self.template.template_meta.stop_token_id diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py index 6b224ef0e5..5c8e0c07fc 100644 --- a/swift/llm/train/sft.py +++ b/swift/llm/train/sft.py @@ -4,17 +4,17 @@ from typing import List, Optional, Union from datasets import Dataset as HfDataset -from datasets import load_from_disk -from swift.llm.dataset.loader import DatasetLoader from swift.plugin import extra_callbacks from swift.ray import RayHelper from swift.trainers import TrainerFactory from swift.utils import append_to_jsonl, get_logger, get_model_parameter_info, is_master, plot_images, stat_array from ..argument import TrainArguments from ..base import SwiftPipeline -from ..dataset import EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset, load_dataset -from ..infer import prepare_generation_config +from ..dataset import (AddLengthPreprocessor, EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, + PackingDataset, load_dataset) +from ..dataset.loader import DatasetLoader +from ..infer import get_cached_dataset, prepare_generation_config from .tuner import TunerMixin logger = get_logger() @@ -107,25 +107,14 @@ def _save_val_dataset(self, val_dataset): append_to_jsonl(val_dataset_path, val_dataset.to_list()) logger.info(f'The split dataset from the training set will be saved at: {val_dataset_path}.') - def _get_cached_dataset(self): - args = self.args - assert not args.streaming and not args.lazy_tokenize - train_datasets, val_datasets = [], [] - for cached_dataset in args.cached_dataset: - train_path = os.path.join(cached_dataset, 'train') - val_path = os.path.join(cached_dataset, 'val') - train_datasets.append(load_from_disk(train_path)) - if os.path.exists(val_path): - val_datasets.append(load_from_disk(val_path)) - return train_datasets, val_datasets - @RayHelper.function(group='default') def _prepare_dataset(self): args = self.args # Defer encoding to the training phase pre_process = not (hasattr(args, 'rlhf_type') and args.rlhf_type in ['grpo', 'gkd']) if args.cached_dataset: - train_datasets, val_datasets = self._get_cached_dataset() + assert not args.streaming, 'Cached dataset does not support streaming.' + train_datasets, val_datasets = get_cached_dataset(self.args) else: train_datasets, val_datasets = [], [] if args.dataset: @@ -139,7 +128,7 @@ def _prepare_dataset(self): if not pre_process: return datasets datasets = self._post_process_datasets(datasets) - + self._show_dataset(*datasets) return datasets def _post_process_datasets(self, datasets: List) -> List: @@ -153,7 +142,7 @@ def _post_process_datasets(self, datasets: List) -> List: if i == 1 and predict_with_generate: # val_dataset continue - if (args.model_meta.is_multimodal or args.lazy_tokenize) and not args.streaming: + if not args.streaming: dataset = LazyLLMDataset(dataset, template.encode, strict=args.strict, random_state=args.data_seed) if args.packing: packing_dataset_cls = IterablePackingDataset if args.streaming else PackingDataset @@ -173,7 +162,6 @@ def _post_process_datasets(self, datasets: List) -> List: load_from_cache_file=args.load_from_cache_file, strict=args.strict) datasets[i] = dataset - self._show_dataset(*datasets) return datasets @RayHelper.function(group='default') @@ -336,7 +324,8 @@ def _encode_dataset(self, train_dataset, val_dataset, pre_process=True): # val_dataset continue if not args.lazy_tokenize and not args.streaming: - preprocessor = EncodePreprocessor(template=template, pre_tokenize=args.model_meta.is_multimodal) + # Compatible with cached_dataset, only additionally write length here. + preprocessor = AddLengthPreprocessor(template=template) batch_size = 100 if args.model_meta.is_multimodal else 1000 dataset = preprocessor( dataset, diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index 4609df7948..358c06177b 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -638,8 +638,6 @@ def __post_init__(self): MegatronTunerMixin.__post_init__(self) os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1' self._set_default() - if self.optimizer_cpu_offload: - require_version('megatron-core>=0.13') self.model_info, self.model_meta = get_model_info_meta( self.model, model_type=self.model_type, use_hf=self.use_hf, hub_token=self.hub_token) self.model_type = self.model_info.model_type diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 8a100a9380..fd72353f33 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -51,5 +51,3 @@ def __post_init__(self): raise ValueError('You did not pass `--load` or `--load_safetensors true` to read directly ' 'from safetensors weights, so you need to set `--no_initialization false` ' 'to allow the model to initialize weights properly.') - if self.cached_dataset and self.context_parallel_size > 1: - raise ValueError('`cached_dataset` does not support context parallelism.') diff --git a/swift/megatron/model/mm_gpt_model.py b/swift/megatron/model/mm_gpt_model.py index d5347f556c..8be3c36744 100644 --- a/swift/megatron/model/mm_gpt_model.py +++ b/swift/megatron/model/mm_gpt_model.py @@ -1,6 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from contextlib import contextmanager +import megatron.core import torch from megatron.core import InferenceParams from megatron.core.packed_seq_params import PackedSeqParams @@ -9,9 +10,12 @@ from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig from megatron.training import get_args +from packaging import version from .gpt_model import GPTModel +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + class MultimodalGPTModel(MegatronModule): @@ -63,7 +67,8 @@ def forward(_self, input_): res = split_cp_inputs(res, packed_seq_params.cu_seqlens_q, 1) if reduce_scatter_embeddings: res = res.transpose(0, 1).contiguous() - res = scatter_to_sequence_parallel_region(res, group=_self.tp_group) + group_kwargs = {'group': _self.tp_group} if mcore_013 else {} + res = scatter_to_sequence_parallel_region(res, **group_kwargs) return res VocabParallelEmbedding.forward = forward From 78bb780571c35e11e0104a4b4d05fc77bfe1a464 Mon Sep 17 00:00:00 2001 From: jinghanhu Date: Mon, 17 Nov 2025 10:36:23 +0800 Subject: [PATCH 15/29] [bugfix]fix add_eos in gkd/grpo for truncated samples encoding (#6618) --- swift/trainers/rlhf_trainer/rollout_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/swift/trainers/rlhf_trainer/rollout_mixin.py b/swift/trainers/rlhf_trainer/rollout_mixin.py index 3cb154a94c..3a82e54944 100644 --- a/swift/trainers/rlhf_trainer/rollout_mixin.py +++ b/swift/trainers/rlhf_trainer/rollout_mixin.py @@ -844,7 +844,7 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out input_data['finish_reason'] = choice.finish_reason input_data['is_truncated'] = choice.finish_reason == 'length' - + input_data['add_eos'] = not choice.finish_reason == 'length' if output.rollout_infos: multi_modal_keys = ['images', 'videos', 'audios'] for key in multi_modal_keys: From 156f2bd724f4a8a64fc4f71069d82681ae7b26b6 Mon Sep 17 00:00:00 2001 From: jinghanhu Date: Mon, 17 Nov 2025 10:36:56 +0800 Subject: [PATCH 16/29] Support GKD Liger Kernel Loss (#6619) --- swift/trainers/rlhf_trainer/gkd_trainer.py | 163 ++++++++++++++++----- 1 file changed, 128 insertions(+), 35 deletions(-) diff --git a/swift/trainers/rlhf_trainer/gkd_trainer.py b/swift/trainers/rlhf_trainer/gkd_trainer.py index 2cf0947ee5..77a564c1a2 100644 --- a/swift/trainers/rlhf_trainer/gkd_trainer.py +++ b/swift/trainers/rlhf_trainer/gkd_trainer.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from accelerate.utils import gather_object +from accelerate.utils import gather_object, is_peft_model from transformers import PreTrainedModel from trl import GKDTrainer as HFGKDTrainer from trl import SFTTrainer as HFSFTTrainer @@ -22,6 +22,12 @@ from .rollout_mixin import DataType, RolloutTrainerMixin from .utils import identity_data_collator, patch_profiling_context, patch_profiling_decorator, prepare_deepspeed +try: + from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss + _liger_kernel_available = True +except ImportError: + _liger_kernel_available = False + del HFGKDTrainer.__init__ del HFSFTTrainer.__init__ @@ -50,6 +56,10 @@ def __init__(self, model: Optional[Union[PreTrainedModel, nn.Module, str]] = Non # Initialize logging components self._prepare_logging() + + # Initialize liger loss + self._prepare_liger_loss() + self.teacher_ds3_gather_for_generation = args.ds3_gather_for_generation # Initialize teacher model if self.is_deepspeed_enabled: @@ -124,42 +134,107 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if use_logits_to_keep: self.prepare_logits_to_keep(inputs) model_inputs['logits_to_keep'] = inputs['logits_to_keep'] - if self.args.sft_alpha > 0: - model_inputs['labels'] = inputs['labels'] - # compute student output - outputs_student = model(**model_inputs) - - model_inputs.pop('labels', None) - load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() - with torch.no_grad(), load_context: - outputs_teacher = self.teacher_model(**model_inputs) - - shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) - mask = shifted_labels != -100 - shifted_student_logits = outputs_student.logits[mask][None] - shifted_teacher_logits = outputs_teacher.logits[mask][None] - - # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct. - stu_dim = shifted_student_logits.shape[-1] - tea_dim = shifted_teacher_logits.shape[-1] - if stu_dim < tea_dim: - shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0) - shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:] - elif stu_dim > tea_dim: - shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0) - shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:] - - # compute loss - loss = self.generalized_jsd_loss( - student_logits=shifted_student_logits, - teacher_logits=shifted_teacher_logits, - beta=self.beta, - ) - if self.args.sft_alpha > 0: - loss = loss + self.args.sft_alpha * outputs_student.loss + + if self.use_liger_gkd_loss: + # Liger fused JSD loss for memory efficiency + # Get base models (exclude lm_head to save memory) + unwrapped_student = self.accelerator.unwrap_model(model) + if is_peft_model(unwrapped_student): + unwrapped_student = unwrapped_student.base_model.model + base_student = getattr(unwrapped_student, getattr(unwrapped_student, 'base_model_prefix', 'model'), + unwrapped_student) + + unwrapped_teacher = self.accelerator.unwrap_model(self.teacher_model) + base_teacher = getattr(unwrapped_teacher, getattr(unwrapped_teacher, 'base_model_prefix', 'model'), + unwrapped_teacher) + + # Forward through base models + student_outputs = base_student(**model_inputs, use_cache=False) + + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + with load_context: + with torch.no_grad(): + teacher_outputs = base_teacher(**model_inputs, use_cache=False) + + # Get hidden states (shifted) + student_hidden = student_outputs.last_hidden_state[:, :-1] + teacher_hidden = teacher_outputs.last_hidden_state[:, :-1] + + # Release full outputs to free memory + del student_outputs, teacher_outputs + + # Prepare labels (shifted) + labels_mask = inputs['labels'] != -100 + masked_input_ids = torch.where(labels_mask, inputs['input_ids'], + torch.full_like(inputs['input_ids'], -100)) + true_labels = masked_input_ids[:, 1:].contiguous() + + # Release intermediate tensors + del labels_mask, masked_input_ids + + # Get output heads + student_head = unwrapped_student.get_output_embeddings() + teacher_head = unwrapped_teacher.get_output_embeddings() + + # Compute liger fused JSD loss + loss = self.liger_jsd_loss( + student_input=student_hidden, + student_weight=student_head.weight, + teacher_input=teacher_hidden, + teacher_weight=teacher_head.weight, + true_labels=true_labels, + student_bias=getattr(student_head, 'bias', None), + teacher_bias=getattr(teacher_head, 'bias', None), + ) + + # Release hidden states after loss computation + del student_hidden, teacher_hidden, true_labels + else: + # Standard loss computation + if self.args.sft_alpha > 0: + model_inputs['labels'] = inputs['labels'] + # compute student output + outputs_student = model(**model_inputs) + + model_inputs.pop('labels', None) + load_context = self.load_teacher_model_context() if self.args.offload_teacher_model else nullcontext() + with torch.no_grad(), load_context: + outputs_teacher = self.teacher_model(**model_inputs) + + shifted_labels = torch.roll(inputs['labels'], shifts=-1, dims=1) + mask = shifted_labels != -100 + shifted_student_logits = outputs_student.logits[mask][None] + shifted_teacher_logits = outputs_teacher.logits[mask][None] + + # Fix the vocab_size mismatch between Qwen2.5-VL-3B-Instruct and Qwen2.5-VL-7B-Instruct. + stu_dim = shifted_student_logits.shape[-1] + tea_dim = shifted_teacher_logits.shape[-1] + if stu_dim < tea_dim: + shifted_student_logits = F.pad(shifted_student_logits, (0, tea_dim - stu_dim), 'constant', 0) + shifted_student_logits[..., stu_dim:] = shifted_teacher_logits[..., stu_dim:] + elif stu_dim > tea_dim: + shifted_teacher_logits = F.pad(shifted_teacher_logits, (0, stu_dim - tea_dim), 'constant', 0) + shifted_teacher_logits[..., tea_dim:] = shifted_student_logits[..., tea_dim:] + + # compute loss + loss = self.generalized_jsd_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + beta=self.beta, + ) + + # Add SFT loss if enabled (common for both paths) + if self.args.sft_alpha > 0: + loss = loss + self.args.sft_alpha * outputs_student.loss # Return loss - return (loss, outputs_student) if return_outputs else loss + if return_outputs: + if self.use_liger_gkd_loss: + # outputs has been released in liger loss computation to reduce peak memory + outputs_student = None + return (loss, outputs_student) + else: + return loss def _prepare_batch_inputs(self, inputs: list) -> Dict[str, torch.Tensor]: template = self.template @@ -298,6 +373,24 @@ def load_teacher_model_context(self): yield self.offload_model(self.accelerator.unwrap_model(self.teacher_model)) + def _prepare_liger_loss(self): + """Initialize liger loss if enabled.""" + args = self.args + self.use_liger_gkd_loss = False + if getattr(args, 'use_liger_kernel', False): + if not _liger_kernel_available: + raise ImportError( + 'Liger kernel is not installed. Please install liger-kernel by running: pip install liger-kernel') + assert self.args.sft_alpha == 0, 'SFT loss is not supported with liger loss' + + self.liger_jsd_loss = LigerFusedLinearJSDLoss( + beta=self.beta, + ignore_index=-100, + temperature=self.temperature, + compiled=False, + ) + self.use_liger_gkd_loss = True + def _prepare_logging(self): """Initialize logging components for on-policy rollout tracking.""" args = self.args From acbef0b0e55ac8b19ff61cc54a972153899847ce Mon Sep 17 00:00:00 2001 From: russwest404 <80997191+0russwest0@users.noreply.github.com> Date: Mon, 17 Nov 2025 11:43:02 +0800 Subject: [PATCH 17/29] Support generative reranker right pad (#6573) --- swift/plugin/loss.py | 27 +++++++++++++++++------- swift/trainers/mixin.py | 21 +++++++++++++++---- swift/trainers/trainers.py | 42 ++++++++++++++++++++------------------ swift/utils/__init__.py | 6 +++--- swift/utils/torch_utils.py | 36 ++++++++++++++++++++++++++++++++ 5 files changed, 98 insertions(+), 34 deletions(-) diff --git a/swift/plugin/loss.py b/swift/plugin/loss.py index 61b3eb4cda..c8139984d5 100755 --- a/swift/plugin/loss.py +++ b/swift/plugin/loss.py @@ -11,6 +11,8 @@ from torch.nn import CrossEntropyLoss, MSELoss from transformers.utils import strtobool +from swift.utils import get_last_valid_indices + def cross_entropy_loss_func(outputs, labels, num_items_in_batch=None, **kwargs): # You need to return a scalar representing the loss. @@ -534,6 +536,7 @@ def generative_reranker_loss(outputs, loss_scale=None, num_items_in_batch=None, trainer=None, + attention_mask=None, **kwargs) -> torch.Tensor: """ Generative reranker loss function. @@ -570,10 +573,14 @@ def generative_reranker_loss(outputs, raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. " f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}') - # Extract logits for positive and negative tokens directly from last position - # This avoids creating the large intermediate tensor last_logits - positive_logits = logits[:, -1, positive_token_id] # [batch_size] - negative_logits = logits[:, -1, negative_token_id] # [batch_size] + # Extract logits at the last valid (non-padding) token position for each sample + batch_size = logits.shape[0] + last_valid_indices = get_last_valid_indices(attention_mask) + batch_indices = torch.arange(batch_size, device=logits.device) + last_valid_logits = logits[batch_indices, last_valid_indices, :] + + positive_logits = last_valid_logits[:, positive_token_id] # [batch_size] + negative_logits = last_valid_logits[:, negative_token_id] # [batch_size] # Stack to create binary classification logits # Shape: [batch_size, 2] where dim=1 represents [negative, positive] @@ -683,6 +690,7 @@ def listwise_generative_reranker_loss(outputs, loss_scale=None, num_items_in_batch=None, trainer=None, + attention_mask=None, **kwargs) -> torch.Tensor: """ List-wise generative reranker loss function. @@ -733,9 +741,14 @@ def listwise_generative_reranker_loss(outputs, raise ValueError(f"Failed to convert tokens '{positive_token}'/'{negative_token}' to IDs. " f'Please check if these tokens exist in the tokenizer vocabulary. Error: {e}') - # Extract logits for positive and negative tokens from last position - positive_logits = logits[:, -1, positive_token_id] # [batch_size] - negative_logits = logits[:, -1, negative_token_id] # [batch_size] + # Extract logits at the last valid (non-padding) token position for each sample + batch_size = logits.shape[0] + last_valid_indices = get_last_valid_indices(attention_mask) + batch_indices = torch.arange(batch_size, device=logits.device) + last_valid_logits = logits[batch_indices, last_valid_indices, :] + + positive_logits = last_valid_logits[:, positive_token_id] # [batch_size] + negative_logits = last_valid_logits[:, negative_token_id] # [batch_size] logits = F.logsigmoid(positive_logits - negative_logits) diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py index d5f2f16848..81db14477f 100644 --- a/swift/trainers/mixin.py +++ b/swift/trainers/mixin.py @@ -42,7 +42,8 @@ from swift.llm.utils import update_generation_config_eos_token from swift.plugin import MeanMetric, compute_acc, extra_tuners, get_loss_func, get_metric from swift.tuners import SwiftModel -from swift.utils import get_current_device, get_logger, is_dist, is_mp, is_mp_ddp, ms_logger_context, seed_worker +from swift.utils import (get_current_device, get_last_valid_indices, get_logger, is_dist, is_mp, is_mp_ddp, + ms_logger_context, seed_worker) from ..llm.model.patcher import get_lm_head_model, revert_padding_free, transformers_seq_cls_forward from .arguments import TrainingArguments from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model @@ -907,7 +908,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int): else: super().create_optimizer_and_scheduler(num_training_steps=num_training_steps) - def _compute_acc(self, outputs, labels, cu_seqlens=None) -> None: + def _compute_acc(self, outputs, labels, cu_seqlens=None, attention_mask=None) -> None: args = self.args logits = outputs.logits metrics = None @@ -932,8 +933,20 @@ def _compute_acc(self, outputs, labels, cu_seqlens=None) -> None: if isinstance(positive_token_id, int) and isinstance(negative_token_id, int) \ and positive_token_id >= 0 and negative_token_id >= 0: - positive_logits = logits[:, -1, positive_token_id] - negative_logits = logits[:, -1, negative_token_id] + # Handle right padding by finding the last valid token position + if attention_mask is not None: + # Extract logits at the last valid (non-padding) token position for each sample + batch_size = logits.shape[0] + last_valid_indices = get_last_valid_indices(attention_mask) + batch_indices = torch.arange(batch_size, device=logits.device) + last_valid_logits = logits[batch_indices, last_valid_indices, :] + positive_logits = last_valid_logits[:, positive_token_id] + negative_logits = last_valid_logits[:, negative_token_id] + else: + # Fallback to original behavior if attention_mask is not available + positive_logits = logits[:, -1, positive_token_id] + negative_logits = logits[:, -1, negative_token_id] + binary_preds = (positive_logits > negative_logits).long() metrics = compute_acc( binary_preds, diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py index cf4d451870..d30e8dd484 100644 --- a/swift/trainers/trainers.py +++ b/swift/trainers/trainers.py @@ -108,6 +108,7 @@ class RerankerTrainer(Trainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.args.include_for_metrics = ['inputs'] self.compute_metrics = self.calculate_metric self.label_names = ['labels'] @@ -124,8 +125,6 @@ def _preprocess_generative_reranker_logits(self, logits, labels): Extract only the yes/no token logits at the last valid (non -100) timestep for each sample, avoiding padded timesteps created by multi-GPU gather. """ - import torch - import os # Get token IDs for positive and negative tokens positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes') @@ -146,21 +145,8 @@ def _preprocess_generative_reranker_logits(self, logits, labels): # Extract only the yes/no token logits from the last non -100 position per sample # Shapes: logits [batch, seq_len, vocab] if len(logits.shape) == 3: - batch_size, _, vocab_size = logits.shape - - # Identify padded rows whose entire vocab logits are -100 - row_is_pad = (logits == -100).all(dim=-1) # [batch, seq_len] - valid_mask = ~row_is_pad - lengths = valid_mask.long().sum(dim=1) - 1 - lengths = torch.clamp(lengths, min=0) - last_indices = lengths.to(device=logits.device) - - # Gather the logits at the last valid index for each sample: [batch, vocab] - gather_index = last_indices.view(batch_size, 1, 1).expand(batch_size, 1, vocab_size) - last_step_logits = torch.gather(logits, dim=1, index=gather_index).squeeze(1) - - positive_logits = last_step_logits[:, positive_token_id] - negative_logits = last_step_logits[:, negative_token_id] + positive_logits = logits[:, :, positive_token_id] + negative_logits = logits[:, :, negative_token_id] logits = positive_logits - negative_logits return logits else: @@ -173,8 +159,19 @@ def evaluation_loop(self, *args, **kwargs): return output def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]: + import numpy as np from swift.plugin.loss import calculate_reranker_metrics - return calculate_reranker_metrics(eval_prediction.predictions, eval_prediction.label_ids) + input_ids = eval_prediction.inputs + logits = eval_prediction.predictions + labels = eval_prediction.label_ids + + if logits.ndim == 2 and logits.shape[1] > 1: + pad_token_id = self.tokenizer.pad_token_id + valid_mask = (input_ids != pad_token_id) & (input_ids != -100) + last_valid_indices = valid_mask[:, ::-1].argmax(axis=1) + last_valid_indices = input_ids.shape[1] - 1 - last_valid_indices + logits = logits[np.arange(logits.shape[0]), last_valid_indices] + return calculate_reranker_metrics(logits, labels) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): # Check if we have a custom loss function @@ -188,7 +185,12 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N if labels is not None: # Call custom loss function - loss = self.compute_loss_func(outputs, labels, num_items_in_batch=num_items_in_batch, trainer=self) + loss = self.compute_loss_func( + outputs, + labels, + num_items_in_batch=num_items_in_batch, + trainer=self, + attention_mask=inputs['attention_mask']) else: # Fallback to model's loss loss = outputs.loss @@ -197,7 +199,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = loss / self.args.gradient_accumulation_steps if labels is not None: - self._compute_acc(outputs, labels) + self._compute_acc(outputs, labels, attention_mask=inputs.get('attention_mask')) return (loss, outputs) if return_outputs else loss else: diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index 4e24aa500b..7db8b3c9a8 100644 --- a/swift/utils/__init__.py +++ b/swift/utils/__init__.py @@ -12,9 +12,9 @@ from .torch_utils import (Serializer, activate_parameters, check_shared_disk, disable_safe_ddp_context_use_barrier, empty_cache, find_all_linears, find_embedding, find_layers, find_norm, freeze_parameters, gc_collect, get_cu_seqlens_from_position_ids, get_current_device, get_device, - get_device_count, get_model_parameter_info, get_n_params_grads, init_process_group, - safe_ddp_context, seed_worker, set_default_ddp_config, set_device, show_layers, - time_synchronize, unwrap_model_for_generation) + get_device_count, get_last_valid_indices, get_model_parameter_info, get_n_params_grads, + init_process_group, safe_ddp_context, seed_worker, set_default_ddp_config, set_device, + show_layers, time_synchronize, unwrap_model_for_generation) from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port, format_time, get_env_args, import_external_file, json_parse_to_dict, lower_bound, parse_args, patch_getattr, read_multi_line, remove_response, seed_everything, split_list, subprocess_run, diff --git a/swift/utils/torch_utils.py b/swift/utils/torch_utils.py index f376310df7..4b412e2dd9 100644 --- a/swift/utils/torch_utils.py +++ b/swift/utils/torch_utils.py @@ -390,6 +390,42 @@ def get_position_ids_from_cu_seqlens(cu_seqlens: torch.LongTensor): return position_ids.unsqueeze(0) +def get_last_valid_indices(attention_mask: torch.Tensor) -> torch.Tensor: + """ + Get the last valid (non-padding) token position indices for each sample. + + This function correctly handles sequences with different padding directions (left/right/none) + within the same batch by computing the last valid index for each sequence individually. + + Args: + attention_mask: Attention mask [batch_size, seq_len] where 1=valid, 0=padding + + Returns: + torch.Tensor: Indices of last valid positions [batch_size] + + Examples: + >>> # Right padding + >>> attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]) + >>> get_last_valid_indices(attention_mask) + tensor([2, 3]) + + >>> # Left padding + >>> attention_mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]]) + >>> get_last_valid_indices(attention_mask) + tensor([4, 4]) + """ + seq_len = attention_mask.shape[1] + + # Flip the mask horizontally to bring the last elements to the front. + # `argmax` will then find the index of the first '1', which corresponds to the last valid token. + last_valid_indices = torch.fliplr(attention_mask).argmax(dim=1) + + # Convert the index from the right-to-left frame to the original left-to-right frame. + indices = seq_len - 1 - last_valid_indices + + return indices + + class Serializer: @staticmethod From 228e699a84d9633a7c0d065498ee56abc43a6c61 Mon Sep 17 00:00:00 2001 From: Jintao Date: Mon, 17 Nov 2025 11:47:56 +0800 Subject: [PATCH 18/29] update swift image 3.10.1 (#6622) --- docs/source/GetStarted/SWIFT-installation.md | 13 +++++++++---- docs/source/Megatron-SWIFT/Quick-start.md | 6 +++--- docs/source_en/GetStarted/SWIFT-installation.md | 13 +++++++++---- docs/source_en/Megatron-SWIFT/Quick-start.md | 6 +++--- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/docs/source/GetStarted/SWIFT-installation.md b/docs/source/GetStarted/SWIFT-installation.md index c19bf42334..7750912ffe 100644 --- a/docs/source/GetStarted/SWIFT-installation.md +++ b/docs/source/GetStarted/SWIFT-installation.md @@ -40,6 +40,11 @@ pip install ms-swift==2.* docker可以查看[这里](https://github.com/modelscope/modelscope/blob/master/docker/build_image.py#L345)。 ``` +# swift3.10.1 +modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 + # swift3.9.3 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 @@ -49,7 +54,11 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.1.1-modelscope1.29.2-swift3.8.3 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.1.1-modelscope1.29.2-swift3.8.3 modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.1.1-modelscope1.29.2-swift3.8.3 +``` +
历史镜像 + +``` # swift3.7.2 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.0-modelscope1.28.2-swift3.7.2 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.0-modelscope1.28.2-swift3.7.2 @@ -59,11 +68,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.28.1-swift3.6.4 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.28.1-swift3.6.4 modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.28.1-swift3.6.4 -``` - -
历史镜像 -``` # swift3.5.3 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.27.1-swift3.5.3 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.27.1-swift3.5.3 diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md index faff26ecec..9a176965e9 100644 --- a/docs/source/Megatron-SWIFT/Quick-start.md +++ b/docs/source/Megatron-SWIFT/Quick-start.md @@ -54,9 +54,9 @@ MAX_JOBS=8 pip install "flash-attn<2.8.2" --no-build-isolation 或者你也可以使用镜像:(历史镜像查看[这里](../GetStarted/SWIFT-installation.md#镜像)) ``` -modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 -modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 -modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 +modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 ``` 推荐运行环境: diff --git a/docs/source_en/GetStarted/SWIFT-installation.md b/docs/source_en/GetStarted/SWIFT-installation.md index da512f4f84..826829d90c 100644 --- a/docs/source_en/GetStarted/SWIFT-installation.md +++ b/docs/source_en/GetStarted/SWIFT-installation.md @@ -41,6 +41,11 @@ pip install ms-swift==2.* You can check Docker [here](https://github.com/modelscope/modelscope/blob/master/docker/build_image.py#L345). ``` +# swift3.10.1 +modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 + # swift3.9.3 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 @@ -50,7 +55,11 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.1.1-modelscope1.29.2-swift3.8.3 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.1.1-modelscope1.29.2-swift3.8.3 modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.1.1-modelscope1.29.2-swift3.8.3 +``` +
Historical Mirrors + +``` # swift3.7.2 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.0-modelscope1.28.2-swift3.7.2 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.6.3-py311-torch2.7.1-vllm0.10.0-modelscope1.28.2-swift3.7.2 @@ -60,11 +69,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.28.1-swift3.6.4 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.28.1-swift3.6.4 modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.28.1-swift3.6.4 -``` - -
Historical Mirrors -``` # swift3.5.3 modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.27.1-swift3.5.3 modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py310-torch2.6.0-vllm0.8.5.post1-modelscope1.27.1-swift3.5.3 diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 94123c8c4e..215ee4daaa 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -53,9 +53,9 @@ MAX_JOBS=8 pip install "flash-attn<2.8.2" --no-build-isolation Alternatively, you can also use the image: (See historical images [here](../GetStarted/SWIFT-installation.md#mirror)) ``` -modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 -modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 -modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.9.3 +modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.cn-beijing.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 +modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.8.0-vllm0.11.0-modelscope1.31.0-swift3.10.1 ``` Recommended Operating Environment: From a98abf39f5199209293de3ef74a0d6fc2ccf2cdc Mon Sep 17 00:00:00 2001 From: Jintao Date: Mon, 17 Nov 2025 14:18:09 +0800 Subject: [PATCH 19/29] [model] support mistral 2506 (#6624) --- .../Supported-models-and-datasets.md | 1 + .../Supported-models-and-datasets.md | 1 + swift/llm/model/constant.py | 1 + swift/llm/model/model/mistral.py | 43 ++++++++-- swift/llm/model/model_arch.py | 1 - swift/llm/template/constant.py | 1 + swift/llm/template/template/llm.py | 23 ----- swift/llm/template/template/mistral.py | 84 +++++++++++++------ tests/test_align/test_template/test_vision.py | 12 ++- 9 files changed, 108 insertions(+), 59 deletions(-) diff --git a/docs/source/Instruction/Supported-models-and-datasets.md b/docs/source/Instruction/Supported-models-and-datasets.md index d953dc4e84..1b4e35058f 100644 --- a/docs/source/Instruction/Supported-models-and-datasets.md +++ b/docs/source/Instruction/Supported-models-and-datasets.md @@ -1024,6 +1024,7 @@ |[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)| |[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)| |[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)| +|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)| |[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)| |[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)| diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md index 13bbdb6ac9..808cec5a85 100644 --- a/docs/source_en/Instruction/Supported-models-and-datasets.md +++ b/docs/source_en/Instruction/Supported-models-and-datasets.md @@ -1024,6 +1024,7 @@ The table below introduces the models integrated with ms-swift: |[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)| |[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)| |[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)| +|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)| |[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)| |[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)| diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index 72c96fe343..dd252e0fd8 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -274,6 +274,7 @@ class MLLMModelType: gemma3_vision = 'gemma3_vision' gemma3n = 'gemma3n' mistral_2503 = 'mistral_2503' + mistral_2506 = 'mistral_2506' paddle_ocr = 'paddle_ocr' diff --git a/swift/llm/model/model/mistral.py b/swift/llm/model/model/mistral.py index 6d23b23b41..ceaa6e6ed8 100644 --- a/swift/llm/model/model/mistral.py +++ b/swift/llm/model/model/mistral.py @@ -1,8 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. - from typing import Any, Dict -from transformers import AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer from swift.llm import TemplateType from ..constant import LLMModelType, MLLMModelType @@ -130,12 +129,7 @@ def get_model_tokenizer_mistral_2503(model_dir: str, model_kwargs: Dict[str, Any], load_model: bool = True, **kwargs): - try: - from transformers import Mistral3ForConditionalGeneration - except ImportError: - raise ImportError('Please install Mistral3ForConditionalGeneration by running ' - '`pip install git+https://github.com/huggingface/transformers@v4.49.0-Mistral-3`') - + from transformers import Mistral3ForConditionalGeneration kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs) @@ -184,4 +178,35 @@ def get_model_tokenizer_devstral_2505(model_dir: str, architectures=['Mistral3ForConditionalGeneration'], model_arch=ModelArch.llava_hf, requires=['transformers>=4.49'], - ), ) + )) + + +def get_model_tokenizer_mistral_2506(model_dir: str, + model_info: ModelInfo, + model_kwargs: Dict[str, Any], + load_model: bool = True, + **kwargs): + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + from transformers import Mistral3ForConditionalGeneration + tokenizer_dir = safe_snapshot_download('mistralai/Mistral-Small-3.1-24B-Instruct-2503', download_model=False) + processor = AutoProcessor.from_pretrained(tokenizer_dir) + kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration + kwargs['tokenizer'] = processor.tokenizer + model, _ = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs) + return model, processor + + +register_model( + ModelMeta( + MLLMModelType.mistral_2506, + [ + ModelGroup([ + Model('mistralai/Mistral-Small-3.2-24B-Instruct-2506', 'mistralai/Mistral-Small-3.2-24B-Instruct-2506'), + ]), + ], + TemplateType.mistral_2506, + get_model_tokenizer_mistral_2506, + architectures=['Mistral3ForConditionalGeneration'], + model_arch=ModelArch.llava_hf, + requires=['transformers>=4.49'], + )) diff --git a/swift/llm/model/model_arch.py b/swift/llm/model/model_arch.py index 5c8e2f4051..fed4ff8055 100644 --- a/swift/llm/model/model_arch.py +++ b/swift/llm/model/model_arch.py @@ -81,7 +81,6 @@ class MLLMModelArch: megrez_omni = 'megrez_omni' valley = 'valley' gemma3n = 'gemma3n' - mistral_2503 = 'mistral_2503' keye_vl = 'keye_vl' midashenglm = 'midashenglm' diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py index 4a073f5442..82810452ad 100644 --- a/swift/llm/template/constant.py +++ b/swift/llm/template/constant.py @@ -229,6 +229,7 @@ class MLLMTemplateType: gemma3_vision = 'gemma3_vision' gemma3n = 'gemma3n' mistral_2503 = 'mistral_2503' + mistral_2506 = 'mistral_2506' paddle_ocr = 'paddle_ocr' diff --git a/swift/llm/template/template/llm.py b/swift/llm/template/template/llm.py index 67d9a1514e..dbd9aa2850 100644 --- a/swift/llm/template/template/llm.py +++ b/swift/llm/template/template/llm.py @@ -119,29 +119,6 @@ def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None: chat_sep=['[INST] '], suffix=[''])) -today = datetime.now().strftime('%Y-%m-%d') - -mistral_2501_system = ( - 'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup ' - 'headquartered in Paris.\n' - f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n' - "When you're not sure about some information, you say that you don't have the information and don't " - 'make up anything.\n' - "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer " - 'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. ' - '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "' - 'Where do you travel from?")') - -register_template( - TemplateMeta( - LLMTemplateType.mistral_2501, - prefix=[''], - prompt=['[INST]{{QUERY}}[/INST]'], - chat_sep=[''], - suffix=[''], - system_prefix=['[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'], - default_system=mistral_2501_system)) - register_template( TemplateMeta( LLMTemplateType.xverse, diff --git a/swift/llm/template/template/mistral.py b/swift/llm/template/template/mistral.py index 679599a6d8..ec2870df8f 100644 --- a/swift/llm/template/template/mistral.py +++ b/swift/llm/template/template/mistral.py @@ -1,14 +1,39 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import os +from dataclasses import dataclass, field +from datetime import datetime, timedelta from typing import Any, Dict, List, Literal, Optional -import torch - from ..base import Template -from ..constant import MLLMTemplateType +from ..constant import LLMTemplateType, MLLMTemplateType from ..register import TemplateMeta, register_template from ..template_inputs import StdTemplateInputs -from ..utils import Context, findall -from .llm import mistral_2501_system +from ..utils import Context, Prompt, findall + +today = datetime.now().strftime('%Y-%m-%d') + +mistral_2501_system = ( + 'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup ' + 'headquartered in Paris.\n' + f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n' + "When you're not sure about some information, you say that you don't have the information and don't " + 'make up anything.\n' + "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer " + 'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. ' + '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "' + 'Where do you travel from?")') + + +@dataclass +class Mistral3TemplateMeta(TemplateMeta): + prefix: Prompt = field(default_factory=lambda: ['']) + prompt: Prompt = field(default_factory=lambda: ['[INST]{{QUERY}}[/INST]']) + chat_sep: Optional[Prompt] = field(default_factory=lambda: ['']) + suffix: Prompt = field(default_factory=lambda: ['']) + system_prefix: Optional[Prompt] = field(default_factory=lambda: ['[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]']) + + +register_template(Mistral3TemplateMeta(LLMTemplateType.mistral_2501, default_system=mistral_2501_system)) class Mistral2503Template(Template): @@ -28,15 +53,16 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]: labels = encoded['labels'] loss_scale = encoded.get('loss_scale', None) idx_list = findall(input_ids, self.image_token) + patch_size = processor.patch_size * processor.spatial_merge_size if idx_list: - image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt') + image_inputs = processor.image_processor(images, patch_size=patch_size, return_tensors='pt') encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype) encoded['image_sizes'] = image_sizes = image_inputs['image_sizes'] def _get_new_tokens(i): height, width = image_sizes[i] - num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size) - num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size) + num_height_tokens = height // patch_size + num_width_tokens = width // patch_size replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token] ] * num_height_tokens # Flatten list @@ -52,15 +78,8 @@ def _get_new_tokens(i): register_template( - TemplateMeta( - MLLMTemplateType.mistral_2503, - prefix=[''], - prompt=['[INST]{{QUERY}}[/INST]'], - chat_sep=[''], - suffix=[''], - system_prefix=['[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'], - default_system=mistral_2501_system, - template_cls=Mistral2503Template)) + Mistral3TemplateMeta( + MLLMTemplateType.mistral_2503, default_system=mistral_2501_system, template_cls=Mistral2503Template)) devstral_small_2505_system = ( # from https://huggingface.co/mistralai/Devstral-Small-2505/blob/main/SYSTEM_PROMPT.txt 'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. ' @@ -122,12 +141,27 @@ def _get_new_tokens(i): 'executing a plan from the user, please don\'t try to directly work around it. Instead, propose a new ' 'plan and confirm with the user before proceeding.\n') +register_template(Mistral3TemplateMeta('devstral', default_system=devstral_small_2505_system)) + + +class Mistral2506Template(Mistral2503Template): + + def _get_mistral_system(self): + from swift.llm import get_model_name + model_dir = self.model_info.model_dir + model_name = get_model_name(model_dir) + file_path = os.path.join(model_dir, 'SYSTEM_PROMPT.txt') + with open(file_path, 'r') as file: + system_prompt = file.read() + today = datetime.today().strftime('%Y-%m-%d') + yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d') + return system_prompt.format(name=model_name, today=today, yesterday=yesterday) + + def _swift_encode(self, inputs: StdTemplateInputs): + if inputs.system is None: + inputs.system = self._get_mistral_system() + return super()._swift_encode(inputs) + + register_template( - TemplateMeta( - 'devstral', - prefix=[''], - prompt=['[INST]{{QUERY}}[/INST]'], # the user query - chat_sep=[''], - suffix=[''], - system_prefix=['[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'], # the system prompt - default_system=devstral_small_2505_system)) + Mistral3TemplateMeta(MLLMTemplateType.mistral_2506, default_system=None, template_cls=Mistral2506Template)) diff --git a/tests/test_align/test_template/test_vision.py b/tests/test_align/test_template/test_vision.py index 46f8f494f1..ab5d72baa2 100644 --- a/tests/test_align/test_template/test_vision.py +++ b/tests/test_align/test_template/test_vision.py @@ -1092,6 +1092,15 @@ def test_ernie_vl_thinking(): assert response == '\n\n' + response2 +def test_mistral_2506(): + pt_engine = PtEngine('mistralai/Mistral-Small-3.2-24B-Instruct-2506') + response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': 'describe the image.'}]) + assert response[:200] == ( + 'The image features a close-up of a kitten with striking blue eyes. The kitten has a soft, ' + 'fluffy coat with a mix of white, gray, and brown fur. Its fur pattern includes distinct ' + 'stripes, particularly ') + + if __name__ == '__main__': from swift.llm import PtEngine, RequestConfig from swift.utils import get_logger, seed_everything @@ -1168,4 +1177,5 @@ def test_ernie_vl_thinking(): # test_llava_onevision1_5() # test_paddle_ocr() # test_ernie_vl() - test_ernie_vl_thinking() + # test_ernie_vl_thinking() + test_mistral_2506() From 4bd007d3b9d8ada29fd67adaba2037a2b495a2c0 Mon Sep 17 00:00:00 2001 From: Jintao Date: Mon, 17 Nov 2025 14:44:40 +0800 Subject: [PATCH 20/29] update peft version (#6621) --- README.md | 2 +- README_CN.md | 2 +- docs/source/GetStarted/SWIFT-installation.md | 2 +- docs/source/Megatron-SWIFT/Quick-start.md | 2 +- docs/source_en/GetStarted/SWIFT-installation.md | 2 +- docs/source_en/Megatron-SWIFT/Quick-start.md | 2 +- requirements/framework.txt | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 455ff142d3..5f909279b2 100644 --- a/README.md +++ b/README.md @@ -134,7 +134,7 @@ Running Environment: | torch | >=2.0 | 2.8.0 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | -| peft | >=0.11,<0.18 | | | +| peft | >=0.11,<0.19 | | | | flash_attn | | 2.8.1/3.0.0b1 | | | trl | >=0.15,<0.25 | 0.23.1 | RLHF | | deepspeed | >=0.14 | 0.17.6 | Training | diff --git a/README_CN.md b/README_CN.md index 08a7f1b93d..7706c5b780 100644 --- a/README_CN.md +++ b/README_CN.md @@ -130,7 +130,7 @@ pip install -e . | torch | >=2.0 | 2.8.0 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | -| peft | >=0.11,<0.18 | | | +| peft | >=0.11,<0.19 | | | | flash_attn | | 2.8.1/3.0.0b1 | | | trl | >=0.15,<0.25 | 0.23.1 | RLHF | | deepspeed | >=0.14 | 0.17.6 | 训练 | diff --git a/docs/source/GetStarted/SWIFT-installation.md b/docs/source/GetStarted/SWIFT-installation.md index 7750912ffe..28e92cfc33 100644 --- a/docs/source/GetStarted/SWIFT-installation.md +++ b/docs/source/GetStarted/SWIFT-installation.md @@ -111,7 +111,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 | torch | >=2.0 | 2.8.0 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | -| peft | >=0.11,<0.18 | | | +| peft | >=0.11,<0.19 | | | | flash_attn | | 2.8.1/3.0.0b1 | | | trl | >=0.15,<0.25 | 0.23.1 | RLHF | | deepspeed | >=0.14 | 0.17.6 | 训练 | diff --git a/docs/source/Megatron-SWIFT/Quick-start.md b/docs/source/Megatron-SWIFT/Quick-start.md index 9a176965e9..7c977cb923 100644 --- a/docs/source/Megatron-SWIFT/Quick-start.md +++ b/docs/source/Megatron-SWIFT/Quick-start.md @@ -71,7 +71,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2 | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | -| peft | >=0.11,<0.18 | | LoRA | +| peft | >=0.11,<0.19 | | LoRA | | trl | >=0.15,<0.25 | | RLHF | diff --git a/docs/source_en/GetStarted/SWIFT-installation.md b/docs/source_en/GetStarted/SWIFT-installation.md index 826829d90c..8818d07af9 100644 --- a/docs/source_en/GetStarted/SWIFT-installation.md +++ b/docs/source_en/GetStarted/SWIFT-installation.md @@ -112,7 +112,7 @@ More images can be found [here](https://modelscope.cn/docs/intro/environment-set | torch | >=2.0 | 2.8.0 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | -| peft | >=0.11,<0.18 | | | +| peft | >=0.11,<0.19 | | | | flash_attn | | 2.8.1 /3.0.0b1 | | | trl | >=0.15,<0.25 | 0.23.1 | RLHF | | deepspeed | >=0.14 | 0.17.6 | Training | diff --git a/docs/source_en/Megatron-SWIFT/Quick-start.md b/docs/source_en/Megatron-SWIFT/Quick-start.md index 215ee4daaa..3e4f390680 100644 --- a/docs/source_en/Megatron-SWIFT/Quick-start.md +++ b/docs/source_en/Megatron-SWIFT/Quick-start.md @@ -71,7 +71,7 @@ Recommended Operating Environment: | flash_attn | | 2.8.1/3.0.0b1 | | | transformers | >=4.33 | 4.57.1 | | | modelscope | >=1.23 | | | -| peft | >=0.11,<0.18 | | LoRA | +| peft | >=0.11,<0.19 | | LoRA | | trl | >=0.15,<0.25 | | RLHF | diff --git a/requirements/framework.txt b/requirements/framework.txt index 61c7f1cddd..fbe52b3f12 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -21,7 +21,7 @@ omegaconf openai oss2 pandas -peft>=0.11,<0.18 +peft>=0.11,<0.19 pillow PyYAML>=5.4 requests From cdf2a110970dbee89cf1c1c44b0c7f70cbe565f7 Mon Sep 17 00:00:00 2001 From: Jintao Date: Mon, 17 Nov 2025 15:06:06 +0800 Subject: [PATCH 21/29] [bugfix] Fix multinode write conflict mcore-bridge (deepseek-v3) (#6626) --- swift/megatron/model/gpt_bridge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index b86c16e188..40aff04568 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -14,7 +14,7 @@ from transformers.modeling_utils import custom_object_save from swift.llm import deep_getattr, get_model_tokenizer, safe_snapshot_download, save_checkpoint -from swift.utils import disable_safe_ddp_context_use_barrier, get_logger, is_last_rank +from swift.utils import get_logger, is_last_rank from ..tuners import LoraParallelLinear from ..utils import SafetensorLazyLoader, StreamingSafetensorSaver @@ -65,7 +65,7 @@ def __init__(self, disable_tqmd: bool = False): self.ep_rank = mpu.get_expert_model_parallel_rank() def _init_meta_hf_model(self): - with torch.device('meta'), disable_safe_ddp_context_use_barrier(): + with torch.device('meta'): self.hf_model, self.processor = get_model_tokenizer( self.args.model_dir, model_type=self.args.hf_model_type, return_dummy_model=True) From 229f1c2d34b4098246e19fda550ecbdbb3c7f879 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E7=AB=A5?= <31961076+tongchen126@users.noreply.github.com> Date: Mon, 17 Nov 2025 17:54:56 +0800 Subject: [PATCH 22/29] [bugfix] Initialize chord dataset after accelerator setup in GRPOTrainer (#6638) The get_chord_sft_dataloader() method relies on GRPOTrainer.accelerator, but the function was previously called before the parent class (super().__init__) finished initializing the accelerator. As a result, the get_chord_sft_dataloader will raise exception regarding non-existent attribute GRPOTrainer.accelerator. --- swift/trainers/rlhf_trainer/grpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 53cc4b5c99..350154a569 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -81,6 +81,7 @@ def __init__(self, reward_templates = kwargs.pop('reward_template', None) self._prepare_algorithm_params() super().__init__(model, ref_model, *_args, **kwargs) + self._prepare_chord_dataset() self.prepare_rollout() self._prepare_rewards(reward_funcs, reward_model, reward_templates) @@ -1868,6 +1869,7 @@ def _prepare_algorithm_params(self): self.advantage_estimator = args.advantage_estimator self.kl_in_reward = args.kl_in_reward + def _prepare_chord_dataset(self): # CHORD, https://arxiv.org/abs/2508.11408 self.chord_sft_iterator = None if self.chord_sft_dataset: From 83d52f5b419a9f049780052cb1e35e51afa6bae3 Mon Sep 17 00:00:00 2001 From: jinghanhu Date: Tue, 18 Nov 2025 16:24:21 +0800 Subject: [PATCH 23/29] [bugfix] fix megatron grpo max_epochs (#6646) --- swift/megatron/trainers/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 999745acbf..f2e55224a2 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -135,7 +135,12 @@ def new_cyclic_iter(self, iterable): yield from x x = next_x logger.info(f'Training of {i + 1} epochs has been completed, the training has finished.') - x[0]['is_finished'] = True + if isinstance(x, list) and all(isinstance(item, dict) for item in x): + x[0]['is_finished'] = True + elif isinstance(x, list) and all(isinstance(item, list) for item in x): + # grpo + for item in x: + item[0]['is_finished'] = True yield from x else: for x in iterable: From 5ed41fc79a3dc27b0b06df80222461280476f1e6 Mon Sep 17 00:00:00 2001 From: jinghanhu Date: Tue, 18 Nov 2025 16:24:53 +0800 Subject: [PATCH 24/29] [bugfix] fix megatron grpo server mode sync weight (#6648) --- swift/megatron/trainers/grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index d3253d4b39..d6e0a68690 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -245,7 +245,7 @@ def _export_and_load_weights(self): # Colocate mode: load_weights supports iterator, pass directly llm_model = self.engine.inner_model llm_model.load_weights(weight_iterator) - elif self.vllm_mode == 'server' and self.is_main_process: + elif self.vllm_mode == 'server': # Server mode: process in buckets and sync with flattened tensors self._load_weights_to_server_in_buckets(weight_iterator) @@ -285,7 +285,7 @@ def _sync_bucket_to_server(self, bucket_params: List[Tuple[str, torch.Tensor]]): Args: bucket_params: List of (name, tensor) tuples to sync """ - if not bucket_params: + if not bucket_params or not self.is_main_process: return # Create FlattenedTensorBucket for efficient transfer From 290c6c31f93edb813adc8ca8d61ff555c8eb2b15 Mon Sep 17 00:00:00 2001 From: Jintao Date: Tue, 18 Nov 2025 20:07:07 +0800 Subject: [PATCH 25/29] [megatron] fix save barrier (#6653) --- swift/megatron/train/sft.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/swift/megatron/train/sft.py b/swift/megatron/train/sft.py index b369261482..5d9cb7a423 100644 --- a/swift/megatron/train/sft.py +++ b/swift/megatron/train/sft.py @@ -4,6 +4,7 @@ from typing import List, Optional, Union import torch +import torch.distributed as dist from swift.llm import TEMPLATE_MAPPING from swift.llm.train import SwiftSft @@ -60,6 +61,7 @@ def run(self): try: self.trainer.train(train_dataset, val_dataset, data_collator) + dist.barrier() # Ensure all weights are saved completely finally: # Visualization if is_last_rank(): From cd85a49c396e88a8d09f1c571a1d3f88c08ebd0a Mon Sep 17 00:00:00 2001 From: jinghanhu Date: Tue, 18 Nov 2025 21:50:24 +0800 Subject: [PATCH 26/29] [bugfix] fix megatron grpo rollout_group (#6655) --- swift/megatron/trainers/grpo_trainer.py | 83 +++++++++++-------------- swift/megatron/utils/utils.py | 2 +- 2 files changed, 39 insertions(+), 46 deletions(-) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py index d6e0a68690..e81e470994 100644 --- a/swift/megatron/trainers/grpo_trainer.py +++ b/swift/megatron/trainers/grpo_trainer.py @@ -32,7 +32,7 @@ from swift.utils import (get_current_device, get_logger, is_last_rank, is_vllm_available, is_wandb_available, remove_response) from ..argument import MegatronArguments, MegatronRLHFArguments -from ..utils import forward_step_helper +from ..utils import forward_step_helper, get_padding_to from .rlhf_mixin import MegatronRLHFTrainer from .utils import (gather, gather_object, get_swift_datasets_provider, load_megatron_model_to_gpu, load_megatron_optimizer, offload_megatron_model_to_cpu, offload_megatron_optimizer, @@ -53,7 +53,6 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs): self.hf_model_dir = args.model_info.model_dir self.processing_class = self.template.processor self._prepare_metrics() - self._prepare_template_data_collator() self._init_grpo_params() self._prepare_rewards() self._prepare_scheduler() # TODO @@ -66,21 +65,6 @@ def train(self, train_dataset, val_dataset, data_collator): self._train_valid_test_dataset_provider.is_distributed = True super().train(train_dataset, val_dataset, data_collator) - def _prepare_template_data_collator(self): - template = self.template - args = self.args - data_collator = template.data_collator - padding_to = None - if args.tensor_model_parallel_size > 1 and args.sequence_parallel: - padding_to = args.tensor_model_parallel_size - if args.context_parallel_size > 1: - padding_to = (padding_to or 1) * args.context_parallel_size - if args.fp8_format: - padding_to = max((padding_to or 1) * 8, 16) - logger.info(f'padding_to: {padding_to}') - data_collator = partial(data_collator, padding_to=padding_to) - template.data_collator = data_collator - def _init_grpo_params(self): args: MegatronArguments = self.args # distributed params @@ -368,17 +352,16 @@ def _get_rollout_group(self): Get or create the rollout process group (TP×PP×CP). The rollout group is used for: - 1. Data slicing: distributing rollout data across all model parallel ranks (including CP) - 2. Gather operations: collecting results from all model parallel ranks (including CP) + 1. Data slicing: distributing rollout data across ranks with same data samples + 2. Gather operations: collecting results from ranks with same data samples - Note: MODEL_PARALLEL_GROUP only includes TP×PP, but we need TP×PP×CP for correct - data distribution during rollout phase. + Note: Groups are created per data parallel index, containing TP×PP×CP ranks each. + This follows Megatron's data_iterator logic where same data_parallel_rank processes + identical data samples. - Key insight: ranks with the same DP index but different TP/PP/CP indices should be - in the same rollout group. These ranks will: - - During rollout: each process different data slices - - During training: TP/PP ranks process same data (model split), CP ranks process same data (sequence split) - - During gather: collect all data from TP×PP×CP ranks for training + Key insight: ranks with the SAME data parallel index process the SAME data samples + and must coordinate for rollout data distribution. + Megatron rank order: TP → CP → EP → DP → PP """ if self._rollout_group is not None: return self._rollout_group @@ -389,31 +372,38 @@ def _get_rollout_group(self): self._rollout_group = mpu.get_model_parallel_group() return self._rollout_group + # Use RankGenerator to create rollout groups following Megatron-LM logic + global_rank = torch.distributed.get_rank() + # Get parallel dimensions tp_size = mpu.get_tensor_model_parallel_world_size() pp_size = mpu.get_pipeline_model_parallel_world_size() dp_size = mpu.get_data_parallel_world_size() - global_rank = torch.distributed.get_rank() - - # Calculate rollout group size - rollout_group_size = tp_size * pp_size * cp_size - - # Simple and reliable method: assume ranks are organized in contiguous blocks per DP group - # This is typically true for the default order (tp-cp-ep-dp-pp) - # Each DP group has rollout_group_size consecutive ranks - ranks_per_dp_group = rollout_group_size - my_dp_block_index = global_rank // ranks_per_dp_group + cp_size = mpu.get_context_parallel_world_size() - # Calculate the rank range for my rollout group - group_start = my_dp_block_index * ranks_per_dp_group + # Create RankGenerator following Megatron-LM pattern + # Order: tp-cp-ep-dp-pp (default in Megatron-LM) + decoder_rank_generator = mpu.RankGenerator( + tp=tp_size, + ep=1, + dp=dp_size, + pp=pp_size, + cp=cp_size, + order='tp-cp-ep-dp-pp', + rank_offset=0, + ) - # Create all rollout groups (must be done on all ranks) + # Create rollout groups based on data consistency from data_iterator + # Same data_parallel_rank processes same data - group ranks with same DP index if not hasattr(self, '_rollout_groups_created'): - for dp_idx in range(dp_size): - group_start = dp_idx * ranks_per_dp_group - group_ranks = list(range(group_start, min(group_start + ranks_per_dp_group, self.world_size))) - group = torch.distributed.new_group(ranks=group_ranks, group_desc='ROLLOUT_GROUP') - if global_rank in group_ranks: + # Use 'tp-cp-ep-pp' to get groups with same DP index (DP is excluded from variation) + dp_groups = decoder_rank_generator.get_ranks('tp-cp-ep-pp') + for dp_group_ranks in dp_groups: + # Sort for consistency + dp_group_ranks = sorted(dp_group_ranks) + group = torch.distributed.new_group(ranks=dp_group_ranks, group_desc='ROLLOUT_GROUP') + + if global_rank in dp_group_ranks: self._rollout_group = group self._rollout_groups_created = True @@ -488,6 +478,8 @@ def _replace_data_iterator(self, data_iterator, model): def _generate_and_score_completions(self, batch): # Get or create the rollout group (TP×PP×CP) + args = get_args() + rollout_group = self._get_rollout_group() rollout_batch = self.get_local_rollout_batch(batch) @@ -506,7 +498,8 @@ def _get_encoded_batch(rollout_batch, advantages): template = self.template with self._template_context(template): encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] - encoded_batch = to_device(template.data_collator(encoded_batch), self.device) + encoded_batch = to_device( + template.data_collator(encoded_batch, padding_to=get_padding_to(args)), self.device) labels = encoded_batch['labels'] assert self.template.padding_free position_ids = encoded_batch.get('text_position_ids') diff --git a/swift/megatron/utils/utils.py b/swift/megatron/utils/utils.py index c7b15cf652..ba4af92f7b 100644 --- a/swift/megatron/utils/utils.py +++ b/swift/megatron/utils/utils.py @@ -279,7 +279,7 @@ def forward_step_helper(model, inputs, dtype=None): args = get_args() if mpu.is_pipeline_first_stage(): micro_batch_size = 1 # use qkv_format 'thd' - seq_length = inputs['input_ids'].shape[1] + seq_length = inputs['position_ids'].shape[-1] if args.sequence_parallel: seq_length //= mpu.get_tensor_model_parallel_world_size() recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], From 8406f889db827e755acc58b1f88dda9ce96b71b4 Mon Sep 17 00:00:00 2001 From: Jintao Date: Tue, 18 Nov 2025 22:00:44 +0800 Subject: [PATCH 27/29] [bugfix] fix chatml chat template (#6656) --- swift/llm/template/template/utils.py | 2 +- swift/llm/template/template_meta.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/swift/llm/template/template/utils.py b/swift/llm/template/template/utils.py index c55aaf0ae1..18c5d47079 100644 --- a/swift/llm/template/template/utils.py +++ b/swift/llm/template/template/utils.py @@ -15,7 +15,7 @@ class ChatmlTemplateMeta(TemplateMeta): prefix: Prompt = field(default_factory=list) prompt: Prompt = field(default_factory=lambda: ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']) chat_sep: Optional[Prompt] = field(default_factory=lambda: ['<|im_end|>\n']) - suffix: Prompt = field(default_factory=lambda: ['<|im_end|>']) + suffix: Prompt = field(default_factory=lambda: ['<|im_end|>\n']) system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']) auto_add_bos: bool = True diff --git a/swift/llm/template/template_meta.py b/swift/llm/template/template_meta.py index 7db56f705d..ed23c84de9 100644 --- a/swift/llm/template/template_meta.py +++ b/swift/llm/template/template_meta.py @@ -114,18 +114,20 @@ def init(self, tokenizer: PreTrainedTokenizerBase) -> None: value = self._token_attr_to_id(tokenizer, value) setattr(self, key, value) - if self.suffix and self.suffix[-1] not in self.stop_words: - self.stop_words.append(self.suffix[-1]) + suffix_stop = self.suffix[-1] if self.suffix else None + if isinstance(suffix_stop, str): + suffix_stop = suffix_stop.strip() + if suffix_stop and suffix_stop not in self.stop_words: + self.stop_words.append(suffix_stop) if tokenizer.eos_token not in self.stop_words: self.stop_words.append(tokenizer.eos_token) self.stop_token_id = tokenizer.eos_token_id - if self.suffix: - suffix_tokens = self.suffix[-1] - if isinstance(suffix_tokens, str): - stop_token_id = tokenizer.convert_tokens_to_ids(suffix_tokens) - elif isinstance(suffix_tokens, list) and len(suffix_tokens) == 1: - stop_token_id = suffix_tokens[0] + if suffix_stop: + if isinstance(suffix_stop, str): + stop_token_id = tokenizer.convert_tokens_to_ids(suffix_stop) + elif isinstance(suffix_stop, list) and len(suffix_stop) == 1: + stop_token_id = suffix_stop[0] else: stop_token_id = None if stop_token_id is not None: From 71073e65275dc973c4b273008c3d62a8ba80cbc2 Mon Sep 17 00:00:00 2001 From: Jintao Date: Tue, 18 Nov 2025 22:09:16 +0800 Subject: [PATCH 28/29] [bugfix] fix train_type full freeze_llm (#6651) --- swift/llm/model/model_arch.py | 30 +++++++++++++++--------------- swift/llm/train/tuner.py | 7 +++++-- swift/megatron/trainers/utils.py | 5 ++--- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/swift/llm/model/model_arch.py b/swift/llm/model/model_arch.py index fed4ff8055..d62daa25b6 100644 --- a/swift/llm/model/model_arch.py +++ b/swift/llm/model/model_arch.py @@ -337,7 +337,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.llava_hf, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner='model.multi_modal_projector', vision_tower='model.vision_tower', )) @@ -362,7 +362,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.llava_next_video_hf, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner=['model.multi_modal_projector'], vision_tower='model.vision_tower')) else: @@ -400,7 +400,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.interns1, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner='model.multi_modal_projector', vision_tower='model.vision_tower', )) @@ -521,7 +521,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.qwen2_vl, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner='model.visual.merger', vision_tower='model.visual', )) @@ -529,7 +529,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.qwen2_vl, - language_model='model', + language_model=['model', 'lm_head'], aligner='visual.merger', vision_tower='visual', )) @@ -537,7 +537,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.qwen3_vl, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner=['model.visual.merger', 'model.visual.deepstack_merger_list'], vision_tower='model.visual', )) @@ -545,7 +545,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.qwen2_5_omni, - language_model='thinker.model', + language_model=['thinker.model', 'thinker.lm_head'], vision_tower=['thinker.audio_tower', 'thinker.visual'], aligner=['thinker.audio_tower.proj', 'thinker.visual.merger'], generator=['talker', 'token2wav'], @@ -554,7 +554,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.qwen3_omni, - language_model='thinker.model', + language_model=['thinker.model', 'thinker.lm_head'], vision_tower=['thinker.audio_tower', 'thinker.visual'], aligner=[ 'thinker.audio_tower.proj1', 'thinker.audio_tower.proj2', 'thinker.visual.merger', @@ -574,7 +574,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.step_audio2_mini, - language_model='model', + language_model=['model', 'lm_head'], aligner=['adapter'], vision_tower=['encoder'], )) @@ -589,7 +589,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.glm4_1v, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner='model.visual.merger', vision_tower='model.visual', )) @@ -622,7 +622,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.ernie_vl, - language_model='model', + language_model=['model', 'lm_head'], aligner='model.resampler_model', vision_tower='vision_model', )) @@ -631,7 +631,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.llama3_2_vision, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner='model.multi_modal_projector', vision_tower='model.vision_model', )) @@ -696,7 +696,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.gemma3n, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner=['model.embed_vision', 'model.embed_audio'], vision_tower=['model.vision_tower', 'model.audio_tower'], )) @@ -704,7 +704,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.keye_vl, - language_model='model', + language_model=['model', 'lm_head'], aligner='mlp_AR', vision_tower='visual', )) @@ -717,7 +717,7 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non register_model_arch( MultiModelKeys( MLLMModelArch.llava_onevision1_5, - language_model='model.language_model', + language_model=['model.language_model', 'lm_head'], aligner='model.visual.merger', vision_tower='model.visual', )) diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py index b8cac28064..286a9f4b04 100644 --- a/swift/llm/train/tuner.py +++ b/swift/llm/train/tuner.py @@ -111,13 +111,16 @@ def get_multimodal_target_regex( res = [] for module in modules: rejected_modules = [] - if not freeze_vit: + if not freeze_vit or not freeze_llm: for aligner in model_arch.aligner: if aligner.startswith(f'{module}.'): rejected_modules.append(aligner) sub_module = deep_getattr(model, module) - target_modules = find_all_linears(sub_module, model_arch, extra_layers) + if isinstance(sub_module, nn.Linear) and module.endswith('lm_head'): + target_modules = [] + else: + target_modules = find_all_linears(sub_module, model_arch, extra_layers) if exclude_router and model.model_info.is_moe_model: target_modules = [tm for tm in target_modules if tm not in {'gate'}] if not target_modules: diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 594561cdd8..9d4c4c96b3 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -111,9 +111,8 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): if cp_size > 1: args = get_args() keys = ['labels', 'attention_mask', 'position_ids', 'loss_scale'] - if args.is_multimodal: - keys.append('decoder_input') - else: + if not args.is_multimodal: + # Multimodal models will handle CP in input_embeds. keys.append('input_ids') packed_seq_params = batch.get('packed_seq_params') From 3ab76c1c6ddcb04e7f1c7bcca426bcfb2a85800c Mon Sep 17 00:00:00 2001 From: Jintao Date: Wed, 19 Nov 2025 02:31:00 +0800 Subject: [PATCH 29/29] [mcore-bridge] optimize gpt_bridge comm (#6659) --- swift/llm/argument/export_args.py | 5 ++ swift/megatron/model/gpt_bridge.py | 74 ++++++++++++++++++------------ 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py index 921163e4ae..16568c24b6 100644 --- a/swift/llm/argument/export_args.py +++ b/swift/llm/argument/export_args.py @@ -67,6 +67,11 @@ class ExportArguments(MergeArguments, BaseArguments): to_peft_format: bool = False exist_ok: bool = False + def load_args_from_ckpt(self) -> None: + if self.to_cached_dataset: + return + super().load_args_from_ckpt() + def _init_output_dir(self): if self.output_dir is None: ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}' diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py index 40aff04568..451925285d 100644 --- a/swift/megatron/model/gpt_bridge.py +++ b/swift/megatron/model/gpt_bridge.py @@ -64,6 +64,27 @@ def __init__(self, disable_tqmd: bool = False): self.etp_rank = mpu.get_expert_tensor_parallel_rank() self.ep_rank = mpu.get_expert_model_parallel_rank() + dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size + expert_decoder_rank_generator = mpu.RankGenerator( + tp=self.etp_size, + ep=self.ep_size, + dp=dp_size, + pp=self.pp_size, + cp=1, + order='tp-cp-ep-dp-pp', + rank_offset=0, + ) + rank = dist.get_rank() + for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'): + group = mpu.create_group( + ranks, + group_desc='EP-PP-GROUP', + ) + if rank in ranks: + self.ep_pp_size = self.ep_size * self.pp_size + self.ep_pp_group = group + self.ep_pp_rank = dist.get_rank(group) + def _init_meta_hf_model(self): with torch.device('meta'): self.hf_model, self.processor = get_model_tokenizer( @@ -198,6 +219,9 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl tensor = None if mg_weight is None else mg_weight.to('cuda') tp_size = self.etp_size if is_expert else self.tp_size tp_group = self.etp_group if is_expert else self.tp_group + pp_group = self.ep_pp_group if is_expert else self.pp_group + pp_size = self.ep_pp_size if is_expert else self.pp_size + pp_rank = self.ep_pp_rank if is_expert else self.pp_rank if tensor is not None and tp_dim is not None and tp_size > 1: if tp_dim == 0: # save memory @@ -220,34 +244,26 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl tensor = torch.cat(output, dim=tp_dim) del output # pp/ep - for parallel_state in ['ep', 'pp']: - if parallel_state == 'pp' and self.pp_size > 1: - parallel_group = self.pp_group - parallel_rank = self.pp_rank - elif parallel_state == 'ep' and is_expert and self.ep_size > 1: - parallel_group = self.ep_group - parallel_rank = self.ep_rank - else: - continue - src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda') - dist.all_reduce(src_rank, group=parallel_group) - src_rank = dist.get_global_rank(parallel_group, src_rank.item()) + if pp_size > 1: + src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda') + dist.all_reduce(src_rank, group=pp_group) + src_rank = dist.get_global_rank(pp_group, src_rank.item()) meta_data = torch.zeros(10, dtype=torch.int64, device='cuda') dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3} dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} if tensor is None: - dist.broadcast(meta_data, src=src_rank, group=parallel_group) - if meta_data[0].item() > 0: - shape = meta_data[1:1 + meta_data[0]].tolist() - dtype = dtype_mapping_r[meta_data[-1].item()] - tensor = torch.empty(shape, device='cuda', dtype=dtype) - dist.broadcast(tensor, src=src_rank, group=parallel_group) + dist.broadcast(meta_data, src=src_rank, group=pp_group) + assert meta_data[0].item() > 0, f'meta_data: {meta_data}' + shape = meta_data[1:1 + meta_data[0]].tolist() + dtype = dtype_mapping_r[meta_data[-1].item()] + tensor = torch.empty(shape, device='cuda', dtype=dtype) + dist.broadcast(tensor, src=src_rank, group=pp_group) else: meta_data[0] = tensor.ndim meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda') meta_data[-1] = dtype_mapping[tensor.dtype] - dist.broadcast(meta_data, src=src_rank, group=parallel_group) - dist.broadcast(tensor, src=src_rank, group=parallel_group) + dist.broadcast(meta_data, src=src_rank, group=pp_group) + dist.broadcast(tensor, src=src_rank, group=pp_group) assert tensor is not None, f'mg_key: {mg_key}' if offset: tensor = tensor + offset @@ -273,10 +289,10 @@ def _set_state_dict(self, is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper) if not to_mcore: state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda') - if self.pp_size > 1: + if is_expert and self.ep_pp_size > 1: + dist.all_reduce(state, group=self.ep_pp_group) + elif not is_expert and self.pp_size > 1: dist.all_reduce(state, group=self.pp_group) - if is_expert and self.ep_size > 1: - dist.all_reduce(state, group=self.ep_group) is_lora, is_modules_to_save = state if is_lora and self._is_peft_format and param_key != 'layer_norm_weight': if to_mcore: @@ -627,10 +643,10 @@ def _set_mlp_state(self, is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1, LoraParallelLinear) and self._is_peft_format is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') - if self.pp_size > 1: + if is_expert and self.ep_pp_size > 1: + dist.all_reduce(is_lora, group=self.ep_pp_group) + elif not is_expert and self.pp_size > 1: dist.all_reduce(is_lora, group=self.pp_group) - if is_expert and self.ep_size > 1: - dist.all_reduce(is_lora, group=self.ep_group) if is_lora: assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.' if mg_mlp is None: @@ -779,10 +795,10 @@ def _set_mlp_state(self, is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2, LoraParallelLinear) and self._is_peft_format is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') - if self.pp_size > 1: + if is_expert and self.ep_pp_size > 1: + dist.all_reduce(is_lora, group=self.ep_pp_group) + elif not is_expert and self.pp_size > 1: dist.all_reduce(is_lora, group=self.pp_group) - if is_expert and self.ep_size > 1: - dist.all_reduce(is_lora, group=self.ep_group) if is_lora: assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.' if mg_mlp is None: