diff --git a/README.md b/README.md index 048800de670..de9c43af49f 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | | On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | +| JSD (Jensen-Shannon Divergence) | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` | --- diff --git a/README_zh.md b/README_zh.md index 657d4f6bcde..5642481dd2d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -138,6 +138,8 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: | CISPO [[论文](https://arxiv.org/pdf/2506.13585)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[论文](https://arxiv.org/pdf/2511.20347)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | | On-Policy Distillation [[博客](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[论文](https://arxiv.org/pdf/2306.13649)] | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | +| JSD(Jensen-Shannon 散度) | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` | + --- diff --git a/docs/sphinx_doc/assets/jsd_acc.png b/docs/sphinx_doc/assets/jsd_acc.png new file mode 100644 index 00000000000..51dee1d6ef1 Binary files /dev/null and b/docs/sphinx_doc/assets/jsd_acc.png differ diff --git a/docs/sphinx_doc/assets/jsd_kl.png b/docs/sphinx_doc/assets/jsd_kl.png new file mode 100644 index 00000000000..8e75512549f Binary files /dev/null and b/docs/sphinx_doc/assets/jsd_kl.png differ diff --git a/docs/sphinx_doc/source/main.md b/docs/sphinx_doc/source/main.md index 6340212f4a9..4caf6987276 100644 --- a/docs/sphinx_doc/source/main.md +++ b/docs/sphinx_doc/source/main.md @@ -89,6 +89,7 @@ We list some algorithms supported by Trinity-RFT in the following table. 
For mor | CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | | On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | +| JSD (Jensen-Shannon Divergence) | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` | diff --git a/docs/sphinx_doc/source_zh/main.md b/docs/sphinx_doc/source_zh/main.md index 05ab256e15b..1252742d4ec 100644 --- a/docs/sphinx_doc/source_zh/main.md +++ b/docs/sphinx_doc/source_zh/main.md @@ -85,6 +85,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: | CISPO [[论文](https://arxiv.org/pdf/2506.13585)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` | | SAPO [[论文](https://arxiv.org/pdf/2511.20347)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` | | On-Policy Distillation [[博客](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[论文](https://arxiv.org/pdf/2306.13649)] | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` | +| JSD(Jensen-Shannon 散度) | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` | diff --git a/examples/opd_gsm8k/README.md b/examples/opd_gsm8k/README.md index eb0f9d6dadf..bd911a994a1 100644 --- a/examples/opd_gsm8k/README.md +++ b/examples/opd_gsm8k/README.md @@ -8,27 +8,59 @@ On-Policy Distillation is a knowledge distillation method, where in this example 3. The advantage is computed as: `advantages = kl_coef * (teacher_logprobs - student_logprobs)` 4. The student model is trained to minimize this KL divergence, effectively learning from the teacher -## Key Configuration +## Algorithm Variants + +### 1. On-Policy Distill (KL-based) +Uses KL divergence for advantage computation: - **Algorithm**: `on_policy_distill` - **Workflow**: `on_policy_distill_workflow` +- **Advantage**: `advantages = kl_coef * (teacher_logprobs - student_logprobs)` +- **Config file**: `opd_gsm8k.yaml` + +### 2. 
JSD (Jensen-Shannon Divergence) + +Uses Jensen-Shannon Divergence for advantage computation: +- **Algorithm**: `jsd` +- **Workflow**: `on_policy_distill_math_workflow` +- **Advantage**: `advantages = -kl_coef * JSD`, where `JSD(P||Q) = lambda_coef * KL(P||M) + (1-lambda_coef) * KL(Q||M)` and `M = lambda_coef * P + (1-lambda_coef) * Q` (with `P` the teacher and `Q` the student; `M` reduces to `(P+Q)/2` when `lambda_coef = 0.5`) +- **Parameters**: + - `kl_coef`: Overall scaling coefficient for advantages (default: 1.0) + - `lambda_coef`: Weight for mixing KL(P||M) and KL(Q||M) in JSD (default: 0.5). When lambda=0.5, this gives the standard symmetric JSD. +- **Config file**: `opd_gsm8k_jsd.yaml` + +## Key Configuration + - **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct` - **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as auxiliary model) ## Running the Example Download the model checkpoint and modify your config file, then run: + +For KL-based OPD: ```bash trinity run examples/opd_gsm8k/opd_gsm8k.yaml ``` +For JSD-based OPD: +```bash +trinity run examples/opd_gsm8k/opd_gsm8k_jsd.yaml +``` + Then you are all set! It should be pretty simple😄, and the training should converge very quickly. +### KL-based OPD Results + ![](../../docs/sphinx_doc/assets/opd_acc.png) ![](../../docs/sphinx_doc/assets/opd_kl.png) +### JSD-based OPD Results + +![](../../docs/sphinx_doc/assets/jsd_acc.png) +![](../../docs/sphinx_doc/assets/jsd_kl.png) ## References diff --git a/examples/opd_gsm8k/opd_gsm8k_jsd.yaml b/examples/opd_gsm8k/opd_gsm8k_jsd.yaml new file mode 100644 index 00000000000..b8670d28638 --- /dev/null +++ b/examples/opd_gsm8k/opd_gsm8k_jsd.yaml @@ -0,0 +1,75 @@ +project: "Trinity-RFT-gsm8k-opd" +name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5" +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: jsd + repeat_times: 8 + optimizer: + lr: 1e-5 + advantage_fn_args: + kl_coef: 1.0 + lambda_coef: 0.5 +model: + # Student model + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct} + max_response_tokens: 1024 + max_model_len: 2048 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 96 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k} + subset_name: main + split: train + format: + prompt_key: 'question' + response_key: 'answer' + rollout_args: + temperature: 1.0 + # Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward + default_workflow_type: 'on_policy_distill_math_workflow' + trainer_input: + experience_buffer: + name: gsm8k_opd_buffer + storage_type: queue +explorer: + eval_interval: 50 + runner_per_model: 8 + rollout_model: + # Student model for rollout + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + auxiliary_models: + # Teacher model for distillation + - model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct} + engine_num: 1 + tensor_parallel_size: 2 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 + max_model_len: 4096 + max_prompt_tokens: 2048 + max_response_tokens: 1024 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + save_interval: 100 + grad_clip: 1.0 + use_dynamic_bsz: true + max_token_len_per_gpu: 16384 + ulysses_sequence_parallel_size: 1 +monitor: + monitor_type: wandb diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 979d08a779b..52bb605bcd1 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py 
@@ -28,6 +28,7 @@ "rec": "trinity.algorithm.algorithm.RECAlgorithm", "multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm", "on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm", + "jsd": "trinity.algorithm.algorithm.JSDAlgorithm", }, ) diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 3b4dfe887e2..239862ba580 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -18,6 +18,7 @@ "asymre_verl": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREAdvantageFn", "rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage", "on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage", + "jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage", }, ) diff --git a/trinity/algorithm/advantage_fn/jsd_advantage.py b/trinity/algorithm/advantage_fn/jsd_advantage.py new file mode 100644 index 00000000000..75ffe6894de --- /dev/null +++ b/trinity/algorithm/advantage_fn/jsd_advantage.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +"""Jensen-Shannon Divergence (JSD) advantage computation. + +JSD(P||Q) = beta * KL(teacher||M) + (1-beta) * KL(student||M), where M = beta*teacher + (1-beta)*student. +When beta=0.5, this gives the standard symmetric JSD. All computations in log-space (no exp). +Aligned with SWIFT: beta=0/1 yield pure KL; temperature and optional chunking supported. +""" + +from typing import Dict, Optional, Tuple + +import torch +from verl import DataProto + +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn + + +class JSDAdvantage(AdvantageFn): + """Advantage function using Jensen-Shannon Divergence (log-space, SWIFT-aligned). + + Computes JSD in log-space only: + - beta=0: JSD = KL(student || teacher) [pure KL] + - beta=1: JSD = KL(teacher || student) [pure KL] + - else: JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M), M = mixture in log-space. + + The teacher_logprobs should be stored in Experience.teacher_logprobs + by the workflow during exploration. + """ + + def __init__( + self, + lambda_coef: float = 0.5, + kl_coef: float = 1.0, + temperature: float = 1.0, + chunk_size: Optional[int] = None, + ) -> None: + """Initialize JSD advantage function. + + Args: + lambda_coef: Weight beta for mixture. JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M). + beta=0 => KL(student||teacher), beta=1 => KL(teacher||student). Range: [0, 1]. + kl_coef: Overall scaling coefficient for advantages. + temperature: Temperature scaling for log-probs (log_probs / temperature). 1.0 = no scaling. + chunk_size: If set, process flattened valid tokens in chunks to reduce peak memory; None = no chunking. + """ + self.lambda_coef = lambda_coef + self.kl_coef = kl_coef + self.temperature = temperature + self.chunk_size = chunk_size + + def _js_divergence_per_token( + self, + student_logprobs: torch.Tensor, + teacher_logprobs: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute JSD per token in log-space only (no exp). + - beta=0: KL(student || teacher) = student_logprobs - teacher_logprobs + - beta=1: KL(teacher || student) = teacher_logprobs - student_logprobs + - else: mixture log_m = logsumexp([s+log(1-beta), t+log(beta)]); JSD = beta*(t-log_m) + (1-beta)*(s-log_m). 
+ """ + beta = self.lambda_coef + s = student_logprobs + t = teacher_logprobs + + if beta == 0: + # Pure KL(student || teacher) + jsd = s - t + elif beta == 1: + # Pure KL(teacher || student) + jsd = t - s + else: + # Precompute log coefficients once + dtype, device = s.dtype, s.device + beta_t = torch.tensor(beta, dtype=dtype, device=device) + log_beta = torch.log(beta_t) + log_1_minus_beta = torch.log1p(-beta_t) + + # log(mixture) = log(beta*teacher + (1-beta)*student) + mixture_log = torch.logsumexp( + torch.stack([s + log_1_minus_beta, t + log_beta]), + dim=0, + ) + # KL(teacher || mixture) = t - log_m, KL(student || mixture) = s - log_m (log-space, no exp) + kl_teacher = t - mixture_log + kl_student = s - mixture_log + jsd = beta_t * kl_teacher + (1 - beta_t) * kl_student + + if mask is not None: + jsd = jsd * mask + return jsd + + def _js_divergence_per_token_chunked( + self, + student_logprobs: torch.Tensor, + teacher_logprobs: torch.Tensor, + response_mask: torch.Tensor, + ) -> torch.Tensor: + """Compute JSD per token with optional chunking over valid positions (for memory).""" + flat_s = student_logprobs.reshape(-1) + flat_t = teacher_logprobs.reshape(-1) + flat_mask = response_mask.reshape(-1) + valid = flat_mask > 0 + n_valid = valid.sum().item() + if n_valid == 0: + return (flat_s * 0).reshape_as(response_mask) + + s_valid = flat_s[valid] + t_valid = flat_t[valid] + chunk_size = self.chunk_size or n_valid + beta = self.lambda_coef + + if beta == 0: + jsd_valid = s_valid - t_valid + elif beta == 1: + jsd_valid = t_valid - s_valid + else: + dtype, device = s_valid.dtype, s_valid.device + beta_t = torch.tensor(beta, dtype=dtype, device=device) + log_beta = torch.log(beta_t) + log_1_minus_beta = torch.log1p(-beta_t) + jsd_valid = s_valid.new_zeros(s_valid.shape) + for start in range(0, n_valid, chunk_size): + end = min(start + chunk_size, n_valid) + s_chunk = s_valid[start:end] + t_chunk = t_valid[start:end] + mixture_log = torch.logsumexp( + torch.stack([s_chunk + log_1_minus_beta, t_chunk + log_beta]), + dim=0, + ) + kl_t = t_chunk - mixture_log + kl_s = s_chunk - mixture_log + jsd_valid[start:end] = beta_t * kl_t + (1 - beta_t) * kl_s + + out = flat_s.new_zeros(flat_s.shape) + out[valid] = jsd_valid + return out.reshape_as(response_mask) + + def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]: + """Compute advantages using JSD. + + Advantages are computed directly from JSD: advantages = -kl_coef * JSD + Since we want to minimize JSD, we use negative JSD as advantage. + Lower JSD (better alignment with teacher) → higher advantage. + The advantage guides the policy gradient to reduce JSD. 
+ + Args: + exps: DataProto containing: + - old_log_probs: student's sampling logprobs [batch, seq] + - teacher_logprobs: teacher's logprobs [batch, seq] + - response_mask: mask for response tokens [batch, seq] + + Returns: + exps: DataProto with advantages and returns added + metrics: Dict with jsd and advantage statistics + """ + metrics = {} + + old_log_probs = exps.batch["old_log_probs"] # student sampling logprobs + teacher_log_probs = exps.batch["teacher_logprobs"] + response_mask = exps.batch["response_mask"] + + # Temperature scaling (align with SWIFT: logits / T => apply to log-probs here) + if self.temperature != 1.0: + old_log_probs = old_log_probs / self.temperature + teacher_log_probs = teacher_log_probs / self.temperature + + # Compute JSD per token (with optional chunking for memory) + if self.chunk_size is not None: + jsd_per_token = self._js_divergence_per_token_chunked( + old_log_probs, teacher_log_probs, response_mask + ) + else: + jsd_per_token = self._js_divergence_per_token( + old_log_probs, teacher_log_probs, mask=response_mask + ) + + # For advantage function, use JSD directly + # Since we want to minimize JSD, we use negative JSD as advantage + advantages = -self.kl_coef * jsd_per_token + advantages = advantages * response_mask + + exps.batch["advantages"] = advantages + exps.batch["returns"] = advantages.clone() + + # JSD metrics (over valid tokens) + jsd_sum = (jsd_per_token * response_mask).sum(dim=-1) + metrics["jsd/mean"] = jsd_sum.mean().item() + metrics["jsd/std"] = jsd_sum.std().item() if jsd_sum.numel() > 1 else 0.0 + + metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item() + + return exps, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "lambda_coef": 0.5, + "kl_coef": 1.0, + "temperature": 1.0, + "chunk_size": None, + } diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 5206d3c5138..5352bba449d 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -513,3 +513,30 @@ def default_config(cls) -> Dict: "kl_loss_fn": "none", "entropy_loss_fn": "none", } + + +class JSDAlgorithm(AlgorithmType): + """JSD (Jensen-Shannon Divergence) Algorithm. + + Uses JSD between teacher and student for distillation. + Same structure as On-Policy Distill but with JSD advantage function. + """ + + use_critic: bool = False + use_reference: bool = False + compute_advantage_in_trainer: bool = True # advantage_fn computes JSD from teacher_logprobs + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 8, + "advantage_fn": "jsd", + "advantage_fn_args": {"kl_coef": 1.0, "lambda_coef": 0.5}, + "sample_strategy": "default", + "policy_loss_fn": "ppo", + "kl_penalty_fn": "none", + "kl_loss_fn": "none", + "entropy_loss_fn": "none", + }
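Reviewer note: below is a minimal standalone sketch of the per-token, log-space JSD that `JSDAdvantage` computes (plain PyTorch; the helper name `jsd_per_token` and all numbers are illustrative only, not part of the repo). It checks that `lambda_coef = 0.5` is symmetric in student and teacher, and that `lambda_coef = 0` reduces to the per-token log-ratio, so `-kl_coef * JSD` recovers the KL-based `on_policy_distill` advantage `kl_coef * (teacher_logprobs - student_logprobs)`.

```python
import torch


def jsd_per_token(student_logprobs: torch.Tensor,
                  teacher_logprobs: torch.Tensor,
                  beta: float) -> torch.Tensor:
    """Per-token JSD estimate from sampled-token log-probs, computed entirely in log-space."""
    s, t = student_logprobs, teacher_logprobs
    if beta == 0:
        return s - t                                   # KL(student || teacher)
    if beta == 1:
        return t - s                                   # KL(teacher || student)
    beta_t = torch.tensor(beta, dtype=s.dtype, device=s.device)
    log_beta, log_1m_beta = torch.log(beta_t), torch.log1p(-beta_t)
    # log M = log(beta * teacher + (1 - beta) * student), via logsumexp for stability
    log_m = torch.logsumexp(torch.stack([s + log_1m_beta, t + log_beta]), dim=0)
    return beta_t * (t - log_m) + (1 - beta_t) * (s - log_m)


# Two sequences, four response tokens each; sampled-token probabilities are made up.
student = torch.log(torch.tensor([[0.2, 0.5, 0.1, 0.4], [0.3, 0.6, 0.2, 0.5]]))
teacher = torch.log(torch.tensor([[0.4, 0.5, 0.3, 0.6], [0.3, 0.7, 0.4, 0.5]]))

kl_coef = 1.0
jsd = jsd_per_token(student, teacher, beta=0.5)
advantages = -kl_coef * jsd                            # advantages = -kl_coef * JSD

# lambda_coef = 0.5 is symmetric in student and teacher.
assert torch.allclose(jsd, jsd_per_token(teacher, student, beta=0.5))
# lambda_coef = 0 gives student - teacher per token, so -kl_coef * JSD matches the
# KL-based OPD advantage kl_coef * (teacher_logprobs - student_logprobs).
assert torch.allclose(-kl_coef * jsd_per_token(student, teacher, beta=0.0),
                      kl_coef * (teacher - student))
print(advantages)
```

The chunked path (`_js_divergence_per_token_chunked`) computes the same quantity, just restricted to positions where `response_mask > 0` and processed in `chunk_size` slices to bound peak memory.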