diff --git a/3-atari/2-ppo.py b/3-atari/2-ppo.py index 2542f29..9ab2d2c 100644 --- a/3-atari/2-ppo.py +++ b/3-atari/2-ppo.py @@ -17,7 +17,7 @@ SAVE_PATH = "atari_ppo.pt" -TOTAL_FRAMES = 5_000_000 +TOTAL_FRAMES = 10_000_000 N_ENVS = 8 ROLLOUT_STEPS = 128 # batch = N_ENVS * ROLLOUT_STEPS = 1024 EPOCHS = 4 @@ -113,6 +113,11 @@ def policy_action(obs): ep_returns = [] for update in range(1, n_updates + 1): + # Linear LR anneal from LR -> 0 over the run (CleanRL convention). + lr_now = LR * (1.0 - (update - 1) / n_updates) + for g in optimizer.param_groups: + g["lr"] = lr_now + obs_buf = np.zeros((ROLLOUT_STEPS, N_ENVS, *obs_shape), dtype=np.uint8) act_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.int64) logp_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.float32) @@ -151,12 +156,12 @@ def policy_action(obs): obs_t = torch.as_tensor(np.asarray(obs), device=device) _, last_value = model(obs_t) advantages, returns = compute_gae(rew_buf, val_buf, done_buf, last_value.cpu().numpy()) - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # Flatten (T, N_ENVS, ...) -> (T*N_ENVS, ...) obs_t = torch.as_tensor(obs_buf.reshape(batch_size, *obs_shape), device=device) act_t = torch.as_tensor(act_buf.reshape(batch_size), device=device) old_logp_t = torch.as_tensor(logp_buf.reshape(batch_size), device=device) + old_val_t = torch.as_tensor(val_buf.reshape(batch_size), device=device) adv_t = torch.as_tensor(advantages.reshape(batch_size), device=device) ret_t = torch.as_tensor(returns.reshape(batch_size), device=device) @@ -173,11 +178,22 @@ def policy_action(obs): new_logp = dist.log_prob(act_t[mb]) entropy = dist.entropy().mean() + # Advantage normalization per minibatch (CleanRL convention). + mb_adv = adv_t[mb] + mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8) + ratio = (new_logp - old_logp_t[mb]).exp() - unclipped = ratio * adv_t[mb] - clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * adv_t[mb] + unclipped = ratio * mb_adv + clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * mb_adv policy_loss = -torch.min(unclipped, clipped).mean() - value_loss = (values - ret_t[mb]).pow(2).mean() + + # Value loss with clipping around the old value prediction. + v_clipped = old_val_t[mb] + torch.clamp( + values - old_val_t[mb], -CLIP_COEF, CLIP_COEF) + vl_unclipped = (values - ret_t[mb]).pow(2) + vl_clipped = (v_clipped - ret_t[mb]).pow(2) + value_loss = 0.5 * torch.max(vl_unclipped, vl_clipped).mean() + loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy optimizer.zero_grad() @@ -201,6 +217,7 @@ def policy_action(obs): "policy_loss": pl_sum / n_mb, "value_loss": vl_sum / n_mb, "entropy": ent_sum / n_mb, + "lr": lr_now, } if ep_returns: log["recent_mean_return"] = float(np.mean(ep_returns[-20:]))