1 change: 1 addition & 0 deletions README.md
@@ -115,6 +115,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |
| JSD (Jensen-Shannon Divergence) | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` |


---
2 changes: 2 additions & 0 deletions README_zh.md
@@ -138,6 +138,8 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能：
| CISPO [[论文](https://arxiv.org/pdf/2506.13585)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[论文](https://arxiv.org/pdf/2511.20347)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[博客](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[论文](https://arxiv.org/pdf/2306.13649)] | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |
| JSD(Jensen-Shannon 散度) | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` |



---
Binary file added docs/sphinx_doc/assets/jsd_acc.png
Binary file added docs/sphinx_doc/assets/jsd_kl.png
1 change: 1 addition & 0 deletions docs/sphinx_doc/source/main.md
@@ -89,6 +89,7 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor
| CISPO [[Paper](https://arxiv.org/pdf/2506.13585)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[Paper](https://arxiv.org/pdf/2511.20347)] | - | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[Paper](https://arxiv.org/pdf/2306.13649)] | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |
| JSD (Jensen-Shannon Divergence) | [[GSM8K Example](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[Code](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` |



1 change: 1 addition & 0 deletions docs/sphinx_doc/source_zh/main.md
@@ -85,6 +85,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能：
| CISPO [[论文](https://arxiv.org/pdf/2506.13585)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py)] | `algorithm_type: cispo` |
| SAPO [[论文](https://arxiv.org/pdf/2511.20347)] | - | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py)] | `algorithm_type: sapo` |
| On-Policy Distillation [[博客](https://thinkingmachines.ai/blog/on-policy-distillation/)] [[论文](https://arxiv.org/pdf/2306.13649)] | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/common/workflows/on_policy_distill_workflow.py)] | `algorithm_type: on_policy_distill` |
| JSD(Jensen-Shannon 散度) | [[GSM8K 示例](https://github.com/agentscope-ai/Trinity-RFT/tree/main/examples/opd_gsm8k/opd_gsm8k_jsd.yaml)] | [[代码](https://github.com/agentscope-ai/Trinity-RFT/tree/main/trinity/algorithm/advantage_fn/jsd_advantage.py)] | `algorithm_type: jsd` |



34 changes: 33 additions & 1 deletion examples/opd_gsm8k/README.md
@@ -8,27 +8,59 @@ On-Policy Distillation is a knowledge distillation method, where in this example
3. The advantage is computed as: `advantages = kl_coef * (teacher_logprobs - student_logprobs)` (see the toy sketch after this list)
4. The student model is trained to minimize this KL divergence, effectively learning from the teacher

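To make step 3 concrete, here is a toy sketch of the per-token advantage (the tensor values below are made up purely for illustration):

```python
import torch

# Made-up per-token log-probs for one 3-token response; the last position is padding.
teacher_logprobs = torch.tensor([[-0.4, -1.1, -0.7]])
student_logprobs = torch.tensor([[-0.9, -1.0, -2.3]])
response_mask = torch.tensor([[1.0, 1.0, 0.0]])

kl_coef = 1.0
# Step 3: scaled teacher-student log-prob gap, masked to response tokens.
advantages = kl_coef * (teacher_logprobs - student_logprobs) * response_mask
print(advantages)  # tensor([[ 0.5000, -0.1000,  0.0000]])
```
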
## Key Configuration
## Algorithm Variants

### 1. On-Policy Distill (KL-based)

Uses KL divergence for advantage computation:
- **Algorithm**: `on_policy_distill`
- **Workflow**: `on_policy_distill_workflow`
- **Advantage**: `advantages = kl_coef * (teacher_logprobs - student_logprobs)`
- **Config file**: `opd_gsm8k.yaml`

### 2. JSD (Jensen-Shannon Divergence)

Uses Jensen-Shannon Divergence for advantage computation:
- **Algorithm**: `jsd`
- **Workflow**: `on_policy_distill_math_workflow`
- **Advantage**: `advantages = -kl_coef * JSD`, where `JSD(P||Q) = lambda_coef * KL(P||M) + (1 - lambda_coef) * KL(Q||M)` with `P` the teacher, `Q` the student, and mixture `M = lambda_coef * P + (1 - lambda_coef) * Q` (this reduces to `M = (P+Q)/2` when `lambda_coef = 0.5`); a sketch follows this list
- **Parameters**:
- `kl_coef`: Overall scaling coefficient for advantages (default: 1.0)
- `lambda_coef`: Weight for mixing KL(P||M) and KL(Q||M) in JSD (default: 0.5). When lambda=0.5, this gives the standard symmetric JSD.
- **Config file**: `opd_gsm8k_jsd.yaml`

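To make the JSD advantage concrete, here is a minimal log-space sketch mirroring the computation in `jsd_advantage.py` (toy values; the real implementation also handles the `lambda_coef` 0/1 edge cases, temperature scaling, and chunking):

```python
import torch

def jsd_per_token(student_logprobs, teacher_logprobs, lambda_coef=0.5):
    """Per-token JSD in log-space: lambda*KL(teacher||M) + (1-lambda)*KL(student||M)."""
    s, t = student_logprobs, teacher_logprobs
    log_l = torch.log(torch.tensor(lambda_coef))
    log_1ml = torch.log1p(torch.tensor(-lambda_coef))
    # log M = log(lambda*teacher + (1-lambda)*student), via logsumexp to stay in log-space
    log_m = torch.logsumexp(torch.stack([s + log_1ml, t + log_l]), dim=0)
    return lambda_coef * (t - log_m) + (1 - lambda_coef) * (s - log_m)

# Made-up per-token log-probs of the sampled tokens for a 3-token response.
student = torch.tensor([-0.9, -1.0, -2.3])
teacher = torch.tensor([-0.4, -1.1, -0.7])
kl_coef = 1.0
advantages = -kl_coef * jsd_per_token(student, teacher)  # closer to 0 (higher) when student matches teacher
```
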
## Key Configuration

- **Student Model**: `Qwen/Qwen2.5-1.5B-Instruct`
- **Teacher Model**: `Qwen/Qwen2.5-Math-7B-Instruct` (configured as auxiliary model)

## Running the Example

Download the student and teacher model checkpoints and modify your config file as needed, then run:

For KL-based OPD:
```bash
trinity run examples/opd_gsm8k/opd_gsm8k.yaml
```

For JSD-based OPD:
```bash
trinity run examples/opd_gsm8k/opd_gsm8k_jsd.yaml
```

Then you are all set! It should be pretty simple 😄, and training should converge very quickly.



### KL-based OPD Results

![](../../docs/sphinx_doc/assets/opd_acc.png)
![](../../docs/sphinx_doc/assets/opd_kl.png)

### JSD-based OPD Results

![](../../docs/sphinx_doc/assets/jsd_acc.png)
![](../../docs/sphinx_doc/assets/jsd_kl.png)

## References

75 changes: 75 additions & 0 deletions examples/opd_gsm8k/opd_gsm8k_jsd.yaml
@@ -0,0 +1,75 @@
project: "Trinity-RFT-gsm8k-opd"
name: "qwen2.5-1.5B-distill-from-math-7B-lr1e-5"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: jsd
repeat_times: 8
optimizer:
lr: 1e-5
advantage_fn_args:
kl_coef: 1.0
lambda_coef: 0.5
model:
# Student model
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_response_tokens: 1024
max_model_len: 2048
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 1
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH,openai/gsm8k}
subset_name: main
split: train
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
# Use on_policy_distill_math_workflow for Qwen2.5-Math style format with accuracy reward
default_workflow_type: 'on_policy_distill_math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_opd_buffer
storage_type: queue
explorer:
eval_interval: 50
runner_per_model: 8
rollout_model:
# Student model for rollout
engine_num: 4
tensor_parallel_size: 1
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
auxiliary_models:
# Teacher model for distillation
- model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-Math-7B-Instruct}
engine_num: 1
tensor_parallel_size: 2
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
max_model_len: 4096
max_prompt_tokens: 2048
max_response_tokens: 1024
synchronizer:
sync_method: 'nccl'
sync_interval: 1
sync_timeout: 1200
trainer:
save_interval: 100
grad_clip: 1.0
use_dynamic_bsz: true
max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
monitor:
monitor_type: wandb
1 change: 1 addition & 0 deletions trinity/algorithm/__init__.py
@@ -28,6 +28,7 @@
"rec": "trinity.algorithm.algorithm.RECAlgorithm",
"multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm",
"on_policy_distill": "trinity.algorithm.algorithm.OnPolicyDistillAlgorithm",
"jsd": "trinity.algorithm.algorithm.JSDAlgorithm",
},
)

1 change: 1 addition & 0 deletions trinity/algorithm/advantage_fn/__init__.py
@@ -18,6 +18,7 @@
"asymre_verl": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREAdvantageFn",
"rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage",
"on_policy_distill": "trinity.algorithm.advantage_fn.on_policy_distill_advantage.OnPolicyDistillAdvantage",
"jsd": "trinity.algorithm.advantage_fn.jsd_advantage.JSDAdvantage",
},
)

202 changes: 202 additions & 0 deletions trinity/algorithm/advantage_fn/jsd_advantage.py
@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
"""Jensen-Shannon Divergence (JSD) advantage computation.

JSD(P||Q) = beta * KL(teacher||M) + (1-beta) * KL(student||M), where M = beta*teacher + (1-beta)*student.
When beta=0.5, this gives the standard symmetric JSD. All computations in log-space (no exp).
Aligned with SWIFT: beta=0/1 yield pure KL; temperature and optional chunking supported.
"""

from typing import Dict, Optional, Tuple

import torch
from verl import DataProto

from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn


class JSDAdvantage(AdvantageFn):
    """Advantage function using Jensen-Shannon Divergence (log-space, SWIFT-aligned).

    Computes JSD in log-space only:
    - beta=0: JSD = KL(student || teacher) [pure KL]
    - beta=1: JSD = KL(teacher || student) [pure KL]
    - else: JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M), M = mixture in log-space.

    The teacher_logprobs should be stored in Experience.teacher_logprobs
    by the workflow during exploration.
    """

    def __init__(
        self,
        lambda_coef: float = 0.5,
        kl_coef: float = 1.0,
        temperature: float = 1.0,
        chunk_size: Optional[int] = None,
    ) -> None:
        """Initialize JSD advantage function.

        Args:
            lambda_coef: Weight beta for mixture. JSD = beta*KL(teacher||M) + (1-beta)*KL(student||M).
                beta=0 => KL(student||teacher), beta=1 => KL(teacher||student). Range: [0, 1].
            kl_coef: Overall scaling coefficient for advantages.
            temperature: Temperature scaling for log-probs (log_probs / temperature). 1.0 = no scaling.
            chunk_size: If set, process flattened valid tokens in chunks to reduce peak memory; None = no chunking.
        """
        self.lambda_coef = lambda_coef
        self.kl_coef = kl_coef
        self.temperature = temperature
        self.chunk_size = chunk_size

    def _js_divergence_per_token(
        self,
        student_logprobs: torch.Tensor,
        teacher_logprobs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute JSD per token in log-space only (no exp).
        - beta=0: KL(student || teacher) = student_logprobs - teacher_logprobs
        - beta=1: KL(teacher || student) = teacher_logprobs - student_logprobs
        - else: mixture log_m = logsumexp([s+log(1-beta), t+log(beta)]); JSD = beta*(t-log_m) + (1-beta)*(s-log_m).
        """
        beta = self.lambda_coef
        s = student_logprobs
        t = teacher_logprobs

        if beta == 0:
            # Pure KL(student || teacher)
            jsd = s - t
        elif beta == 1:
            # Pure KL(teacher || student)
            jsd = t - s
        else:
            # Precompute log coefficients once
            dtype, device = s.dtype, s.device
            beta_t = torch.tensor(beta, dtype=dtype, device=device)
            log_beta = torch.log(beta_t)
            log_1_minus_beta = torch.log1p(-beta_t)

            # log(mixture) = log(beta*teacher + (1-beta)*student)
            mixture_log = torch.logsumexp(
                torch.stack([s + log_1_minus_beta, t + log_beta]),
                dim=0,
            )
            # KL(teacher || mixture) = t - log_m, KL(student || mixture) = s - log_m (log-space, no exp)
            kl_teacher = t - mixture_log
            kl_student = s - mixture_log
            jsd = beta_t * kl_teacher + (1 - beta_t) * kl_student

        if mask is not None:
            jsd = jsd * mask
        return jsd

    def _js_divergence_per_token_chunked(
        self,
        student_logprobs: torch.Tensor,
        teacher_logprobs: torch.Tensor,
        response_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute JSD per token with optional chunking over valid positions (for memory)."""
        flat_s = student_logprobs.reshape(-1)
        flat_t = teacher_logprobs.reshape(-1)
        flat_mask = response_mask.reshape(-1)
        valid = flat_mask > 0
        n_valid = valid.sum().item()
        if n_valid == 0:
            return (flat_s * 0).reshape_as(response_mask)

        s_valid = flat_s[valid]
        t_valid = flat_t[valid]
        chunk_size = self.chunk_size or n_valid
        beta = self.lambda_coef

        if beta == 0:
            jsd_valid = s_valid - t_valid
        elif beta == 1:
            jsd_valid = t_valid - s_valid
        else:
            dtype, device = s_valid.dtype, s_valid.device
            beta_t = torch.tensor(beta, dtype=dtype, device=device)
            log_beta = torch.log(beta_t)
            log_1_minus_beta = torch.log1p(-beta_t)
            jsd_valid = s_valid.new_zeros(s_valid.shape)
            for start in range(0, n_valid, chunk_size):
                end = min(start + chunk_size, n_valid)
                s_chunk = s_valid[start:end]
                t_chunk = t_valid[start:end]
                mixture_log = torch.logsumexp(
                    torch.stack([s_chunk + log_1_minus_beta, t_chunk + log_beta]),
                    dim=0,
                )
                kl_t = t_chunk - mixture_log
                kl_s = s_chunk - mixture_log
                jsd_valid[start:end] = beta_t * kl_t + (1 - beta_t) * kl_s

        out = flat_s.new_zeros(flat_s.shape)
        out[valid] = jsd_valid
        return out.reshape_as(response_mask)

    def __call__(self, exps: DataProto, **kwargs) -> Tuple[DataProto, Dict]:
        """Compute advantages using JSD.

        Advantages are computed directly from JSD: advantages = -kl_coef * JSD
        Since we want to minimize JSD, we use negative JSD as advantage.
        Lower JSD (better alignment with teacher) → higher advantage.
        The advantage guides the policy gradient to reduce JSD.

        Args:
            exps: DataProto containing:
                - old_log_probs: student's sampling logprobs [batch, seq]
                - teacher_logprobs: teacher's logprobs [batch, seq]
                - response_mask: mask for response tokens [batch, seq]

        Returns:
            exps: DataProto with advantages and returns added
            metrics: Dict with jsd and advantage statistics
        """
        metrics = {}

        old_log_probs = exps.batch["old_log_probs"]  # student sampling logprobs
        teacher_log_probs = exps.batch["teacher_logprobs"]
        response_mask = exps.batch["response_mask"]

        # Temperature scaling (align with SWIFT: logits / T => apply to log-probs here)
        if self.temperature != 1.0:
            old_log_probs = old_log_probs / self.temperature
            teacher_log_probs = teacher_log_probs / self.temperature

        # Compute JSD per token (with optional chunking for memory)
        if self.chunk_size is not None:
            jsd_per_token = self._js_divergence_per_token_chunked(
                old_log_probs, teacher_log_probs, response_mask
            )
        else:
            jsd_per_token = self._js_divergence_per_token(
                old_log_probs, teacher_log_probs, mask=response_mask
            )

        # For advantage function, use JSD directly
        # Since we want to minimize JSD, we use negative JSD as advantage
        advantages = -self.kl_coef * jsd_per_token
        advantages = advantages * response_mask

        exps.batch["advantages"] = advantages
        exps.batch["returns"] = advantages.clone()

        # JSD metrics (over valid tokens)
        jsd_sum = (jsd_per_token * response_mask).sum(dim=-1)
        metrics["jsd/mean"] = jsd_sum.mean().item()
        metrics["jsd/std"] = jsd_sum.std().item() if jsd_sum.numel() > 1 else 0.0

        metrics["advantages/mean"] = advantages.sum(dim=-1).mean().item()

        return exps, metrics

    @classmethod
    def default_args(cls) -> Dict:
        return {
            "lambda_coef": 0.5,
            "kl_coef": 1.0,
            "temperature": 1.0,
            "chunk_size": None,
        }
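
A quick sanity check one might run from the repo root (assuming `torch` and `verl` are installed so this module imports): the chunked and unchunked paths should produce identical per-token JSD on masked toy inputs.

```python
import torch

from trinity.algorithm.advantage_fn.jsd_advantage import JSDAdvantage

adv = JSDAdvantage(lambda_coef=0.5, kl_coef=1.0, chunk_size=2)

# Fake per-token log-probs [batch=2, seq=5] and a response mask (made-up values).
student = -torch.rand(2, 5) * 2.0
teacher = -torch.rand(2, 5) * 2.0
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]], dtype=student.dtype)

full = adv._js_divergence_per_token(student, teacher, mask=mask)
chunked = adv._js_divergence_per_token_chunked(student, teacher, mask)
assert torch.allclose(full, chunked, atol=1e-6)
```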