From 8a398a5aac3da5a15f7d5490d409be737968485b Mon Sep 17 00:00:00 2001 From: dnddnjs Date: Mon, 18 May 2026 06:59:54 +0900 Subject: [PATCH 1/2] PPO tuning: LR anneal, value clipping, per-minibatch adv norm, 10M frames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three of CleanRL's 'PPO 37 details' that were missing — flagged when the 5M and 10M Breakout runs both plateaued at per-game ~75 with entropy stuck around 0.8 (policy wasn't sharpening, clip rarely activating): - Linear LR anneal from 2.5e-4 -> 0 across the run; lets late updates fine-tune instead of bouncing. - Value-function loss clipping around the old prediction (CLIP_COEF), matching the policy clipping range; stabilizes value targets. - Advantage normalization moved inside the minibatch loop instead of once per batch. Also bumps TOTAL_FRAMES 5M -> 10M to match the CleanRL Atari budget so runs are directly comparable to their published curves. lr now logged to wandb so the anneal is visible. --- 3-atari/2-ppo.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/3-atari/2-ppo.py b/3-atari/2-ppo.py index 2542f29..9ab2d2c 100644 --- a/3-atari/2-ppo.py +++ b/3-atari/2-ppo.py @@ -17,7 +17,7 @@ SAVE_PATH = "atari_ppo.pt" -TOTAL_FRAMES = 5_000_000 +TOTAL_FRAMES = 10_000_000 N_ENVS = 8 ROLLOUT_STEPS = 128 # batch = N_ENVS * ROLLOUT_STEPS = 1024 EPOCHS = 4 @@ -113,6 +113,11 @@ def policy_action(obs): ep_returns = [] for update in range(1, n_updates + 1): + # Linear LR anneal from LR -> 0 over the run (CleanRL convention). 
+ lr_now = LR * (1.0 - (update - 1) / n_updates) + for g in optimizer.param_groups: + g["lr"] = lr_now + obs_buf = np.zeros((ROLLOUT_STEPS, N_ENVS, *obs_shape), dtype=np.uint8) act_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.int64) logp_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.float32) @@ -151,12 +156,12 @@ def policy_action(obs): obs_t = torch.as_tensor(np.asarray(obs), device=device) _, last_value = model(obs_t) advantages, returns = compute_gae(rew_buf, val_buf, done_buf, last_value.cpu().numpy()) - advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # Flatten (T, N_ENVS, ...) -> (T*N_ENVS, ...) obs_t = torch.as_tensor(obs_buf.reshape(batch_size, *obs_shape), device=device) act_t = torch.as_tensor(act_buf.reshape(batch_size), device=device) old_logp_t = torch.as_tensor(logp_buf.reshape(batch_size), device=device) + old_val_t = torch.as_tensor(val_buf.reshape(batch_size), device=device) adv_t = torch.as_tensor(advantages.reshape(batch_size), device=device) ret_t = torch.as_tensor(returns.reshape(batch_size), device=device) @@ -173,11 +178,22 @@ def policy_action(obs): new_logp = dist.log_prob(act_t[mb]) entropy = dist.entropy().mean() + # Advantage normalization per minibatch (CleanRL convention). + mb_adv = adv_t[mb] + mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8) + ratio = (new_logp - old_logp_t[mb]).exp() - unclipped = ratio * adv_t[mb] - clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * adv_t[mb] + unclipped = ratio * mb_adv + clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * mb_adv policy_loss = -torch.min(unclipped, clipped).mean() - value_loss = (values - ret_t[mb]).pow(2).mean() + + # Value loss with clipping around the old value prediction. 
+ v_clipped = old_val_t[mb] + torch.clamp( + values - old_val_t[mb], -CLIP_COEF, CLIP_COEF) + vl_unclipped = (values - ret_t[mb]).pow(2) + vl_clipped = (v_clipped - ret_t[mb]).pow(2) + value_loss = 0.5 * torch.max(vl_unclipped, vl_clipped).mean() + loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy optimizer.zero_grad() @@ -201,6 +217,7 @@ def policy_action(obs): "policy_loss": pl_sum / n_mb, "value_loss": vl_sum / n_mb, "entropy": ent_sum / n_mb, + "lr": lr_now, } if ep_returns: log["recent_mean_return"] = float(np.mean(ep_returns[-20:])) From 69a12a7b2f003d37862d535c7a7aa75f62c71435 Mon Sep 17 00:00:00 2001 From: dnddnjs Date: Sat, 23 May 2026 22:17:16 +0900 Subject: [PATCH 2/2] Atari: shrink DQN buffer 4x, fix life-loss reset, track per-game returns - ReplayBuffer stores single frames and stacks 4 at sample time (~28GB -> ~7GB). - LifeLossTerminalEnv signals terminal on life loss but defers real reset to game-over, so noop_max + FIRE no longer fire every life and GAE/Q chains break only at the right boundary. - DQN: BATCH_SIZE 64 -> 32, TARGET_UPDATE_EVERY 1000 -> 250 train steps (~1k env frames), EPSILON_END 0.05 -> 0.01; also TOTAL_FRAMES 1M -> 10M, BUFFER_CAPACITY 100k -> 1M, LEARN_START 10k -> 80k, EPSILON_DECAY_FRAMES 250k -> 1M. - Log per-life and per-game returns separately (DQN and PPO). --- 3-atari/1-dqn.py | 85 +++++++++++++++++++++++++++++++++--------------- 3-atari/2-ppo.py | 20 +++++++++--- 3-atari/env.py | 36 +++++++++++++++++++- 3 files changed, 110 insertions(+), 31 deletions(-) diff --git a/3-atari/1-dqn.py b/3-atari/1-dqn.py index 2ed9132..1f508f9 100644 --- a/3-atari/1-dqn.py +++ b/3-atari/1-dqn.py @@ -11,7 +11,6 @@ serious training. 
""" import random -import sys from collections import deque import numpy as np @@ -23,17 +22,17 @@ SAVE_PATH = "atari_dqn.pt" -TOTAL_FRAMES = 1_000_000 # bump to ~10M for paper-quality results -BUFFER_CAPACITY = 100_000 # bump to 1M with enough RAM -BATCH_SIZE = 64 +TOTAL_FRAMES = 10_000_000 # Nature uses 50M agent steps; 10M is laptop-friendly +BUFFER_CAPACITY = 1_000_000 # Nature standard; ~7GB RAM (uint8, single frames stacked at sample time) +BATCH_SIZE = 32 GAMMA = 0.99 LR = 1e-4 -LEARN_START = 10_000 # frames of pure exploration before training begins +LEARN_START = 80_000 # frames of pure exploration before training begins TRAIN_EVERY = 4 -TARGET_UPDATE_EVERY = 1_000 # in training steps, not env steps +TARGET_UPDATE_EVERY = 250 # in training steps, not env steps (~1k env frames) EPSILON_START = 1.0 -EPSILON_END = 0.05 -EPSILON_DECAY_FRAMES = 250_000 # linear decay from start to end over this many frames +EPSILON_END = 0.01 +EPSILON_DECAY_FRAMES = 1_000_000 # linear decay from start to end over this many frames # Standard Nature CNN. @@ -57,34 +56,59 @@ def forward(self, x): class ReplayBuffer: - """Uint8 replay buffer — far more memory-efficient than storing floats.""" + """Single-frame uint8 buffer — stacks of 4 are reconstructed at sample time, + cutting RAM ~4x vs. 
storing the full stack per slot.""" - def __init__(self, capacity, obs_shape): + def __init__(self, capacity, frame_shape=(84, 84), stack=4): self.capacity = capacity - self.obs = np.zeros((capacity, *obs_shape), dtype=np.uint8) - self.next_obs = np.zeros((capacity, *obs_shape), dtype=np.uint8) - self.actions = np.zeros(capacity, dtype=np.int64) - self.rewards = np.zeros(capacity, dtype=np.float32) - self.dones = np.zeros(capacity, dtype=np.float32) + self.stack = stack + self.frames = np.zeros((capacity, *frame_shape), dtype=np.uint8) + self.actions = np.zeros(capacity, dtype=np.int64) + self.rewards = np.zeros(capacity, dtype=np.float32) + self.dones = np.zeros(capacity, dtype=np.float32) self.idx = 0 self.size = 0 - def push(self, obs, action, reward, next_obs, done): - self.obs[self.idx] = obs + def push(self, frame, action, reward, done): + self.frames[self.idx] = frame self.actions[self.idx] = action self.rewards[self.idx] = reward - self.next_obs[self.idx] = next_obs self.dones[self.idx] = float(done) self.idx = (self.idx + 1) % self.capacity self.size = min(self.size + 1, self.capacity) + def _stack(self, idx): + # Gather frames[idx-stack+1 .. idx]; newest at last channel. + offsets = np.arange(self.stack) + gather = (idx[:, None] - (self.stack - 1) + offsets[None, :]) % self.capacity + out = self.frames[gather] + # Zero out frames sitting before an episode boundary inside the stack. + # dones at the (stack-1) older positions mark where a prior episode ended. + older = self.dones[gather[:, :-1]].astype(bool) + # Once we cross any done walking newest→oldest, everything older is invalid. + invalid = np.cumsum(older[:, ::-1], axis=1)[:, ::-1] > 0 + mask = np.concatenate([~invalid, np.ones((idx.shape[0], 1), dtype=bool)], axis=1) + return out * mask[:, :, None, None] + def sample(self, batch_size, device): - idx = np.random.randint(0, self.size, size=batch_size) + # Reject indices whose stack would straddle the write head (stale frames). 
+ while True: + if self.size < self.capacity: + if self.size < self.stack + 2: + raise RuntimeError("buffer too small to sample yet") + idx = np.random.randint(self.stack - 1, self.size - 1, size=batch_size) + break + idx = np.random.randint(0, self.capacity, size=batch_size) + age = (idx - self.idx) % self.capacity # 0 = oldest frame (at the write head), capacity - 1 = newest + if np.all((age >= self.stack - 1) & (age < self.capacity - 1)): + break + states = self._stack(idx) + next_states = self._stack((idx + 1) % self.capacity) return ( - torch.as_tensor(self.obs[idx], device=device), + torch.as_tensor(states, device=device), torch.as_tensor(self.actions[idx], device=device), torch.as_tensor(self.rewards[idx], device=device), - torch.as_tensor(self.next_obs[idx], device=device), + torch.as_tensor(next_states, device=device), torch.as_tensor(self.dones[idx], device=device), ) @@ -130,10 +154,12 @@ def greedy_action(obs): print(f"device: {device}, env: {args.env}, actions: {n_actions}") - buffer = ReplayBuffer(BUFFER_CAPACITY, env.observation_space.shape) + buffer = ReplayBuffer(BUFFER_CAPACITY) obs, _ = env.reset() - ep_return = 0.0 + ep_return = 0.0 # accumulates within one life (LifeLossTerminalEnv ends an "episode" per life) + game_return = 0.0 # accumulates across all lives until real game-over recent_returns = deque(maxlen=20) + recent_game_returns = deque(maxlen=20) train_step = 0 last_loss = 0.0 @@ -146,18 +172,23 @@ def greedy_action(obs): else: action = greedy_action(obs) - next_obs, reward, terminated, truncated, _ = env.step(action) + next_obs, reward, terminated, truncated, info = env.step(action) done = terminated or truncated # Reward clipping (DeepMind standard) — keeps Q-values from blowing up # when one game has rewards in tens and another in hundreds. clipped = np.sign(reward) - buffer.push(np.asarray(obs), action, clipped, np.asarray(next_obs), done) + # FrameStack gives (4, 84, 84); store just the newest frame and stack at sample time. 
+ buffer.push(np.asarray(obs)[-1], action, clipped, done) ep_return += reward + game_return += reward obs = next_obs if done: recent_returns.append(ep_return) ep_return = 0.0 + if info.get("game_over", True): + recent_game_returns.append(game_return) + game_return = 0.0 obs, _ = env.reset() # Training. @@ -182,12 +213,14 @@ def greedy_action(obs): # Logging. if frame % 10_000 == 0: mean = float(np.mean(recent_returns)) if recent_returns else 0.0 + game_mean = float(np.mean(recent_game_returns)) if recent_game_returns else 0.0 print(f"frame: {frame:>8} eps: {epsilon(frame):.3f} " - f"recent_mean_return: {mean:.1f} buffer: {buffer.size}") + f"per_life: {mean:.1f} per_game: {game_mean:.1f} buffer: {buffer.size}") if args.wandb: wandb.log({ "global_step": frame, "recent_mean_return": mean, + "recent_mean_game_return": game_mean, "epsilon": epsilon(frame), "loss": last_loss, "buffer_size": buffer.size, diff --git a/3-atari/2-ppo.py b/3-atari/2-ppo.py index 9ab2d2c..9082858 100644 --- a/3-atari/2-ppo.py +++ b/3-atari/2-ppo.py @@ -109,8 +109,10 @@ def policy_action(obs): frames_per_update = batch_size n_updates = TOTAL_FRAMES // frames_per_update obs, _ = envs.reset() - ep_returns_per_env = np.zeros(N_ENVS, dtype=np.float32) + ep_returns_per_env = np.zeros(N_ENVS, dtype=np.float32) # per-life (resets every life loss) + game_returns_per_env = np.zeros(N_ENVS, dtype=np.float32) # per-game (resets only on real game-over) ep_returns = [] + game_returns = [] for update in range(1, n_updates + 1): # Linear LR anneal from LR -> 0 over the run (CleanRL convention). 
@@ -139,16 +141,22 @@ def policy_action(obs): logp_buf[t] = logp.cpu().numpy() val_buf[t] = value.cpu().numpy() - next_obs, reward, terminated, truncated, _ = envs.step(act_buf[t]) + next_obs, reward, terminated, truncated, info = envs.step(act_buf[t]) done = np.logical_or(terminated, truncated) ep_returns_per_env += reward + game_returns_per_env += reward rew_buf[t] = np.sign(reward).astype(np.float32) # DeepMind reward clipping done_buf[t] = done.astype(np.float32) + # LifeLossTerminalEnv tags each step's info with game_over (True only on real game-over). + game_over = info.get("game_over", done) for i in range(N_ENVS): if done[i]: ep_returns.append(float(ep_returns_per_env[i])) ep_returns_per_env[i] = 0.0 + if bool(game_over[i]): + game_returns.append(float(game_returns_per_env[i])) + game_returns_per_env[i] = 0.0 obs = next_obs # --- GAE --- @@ -208,9 +216,11 @@ def policy_action(obs): global_step = update * frames_per_update if ep_returns: - recent = ep_returns[-20:] + life_mean = float(np.mean(ep_returns[-20:])) + game_mean = float(np.mean(game_returns[-20:])) if game_returns else 0.0 print(f"update: {update:>4} frames: {global_step:>8} " - f"recent_mean_return: {np.mean(recent):.1f} episodes: {len(ep_returns)}") + f"per_life: {life_mean:.1f} per_game: {game_mean:.1f} " + f"lives: {len(ep_returns)} games: {len(game_returns)}") if args.wandb: log = { "global_step": global_step, @@ -221,6 +231,8 @@ def policy_action(obs): } if ep_returns: log["recent_mean_return"] = float(np.mean(ep_returns[-20:])) + if game_returns: + log["recent_mean_game_return"] = float(np.mean(game_returns[-20:])) wandb.log(log, step=global_step) torch.save(model.state_dict(), SAVE_PATH) diff --git a/3-atari/env.py b/3-atari/env.py index 2f3d1b9..1fb8513 100644 --- a/3-atari/env.py +++ b/3-atari/env.py @@ -29,6 +29,39 @@ def reset(self, **kwargs): obs, _ = self.env.reset(**kwargs) return obs, {} + +# Treats each life as its own episode for bootstrapping (so Q-targets / GAE don't +# 
value-chain across deaths) but only resets the real game when all lives are +# gone. Without this, every life loss triggers a full env.reset() — burning +# frames on noop_max + FIRE and breaking long-horizon credit assignment. +class LifeLossTerminalEnv(gym.Wrapper): + def __init__(self, env): + super().__init__(env) + self.lives = 0 + self.game_over = True + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + self.game_over = terminated or truncated + lives = info.get("lives", 0) + if 0 < lives < self.lives: + terminated = True + self.lives = lives + info["game_over"] = self.game_over + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + if self.game_over: + obs, info = self.env.reset(**kwargs) + else: + # Fake terminal from a life loss — advance one frame instead of + # resetting so the game keeps its remaining lives. + obs, _, terminated, truncated, info = self.env.step(0) + if terminated or truncated: + obs, info = self.env.reset(**kwargs) + self.lives = info.get("lives", 0) + return obs, info + ENV_IDS = { "breakout": "ALE/Breakout-v5", "pong": "ALE/Pong-v5", @@ -61,12 +94,13 @@ def make_env(args): noop_max=30, frame_skip=4, screen_size=84, - terminal_on_life_loss=True, + terminal_on_life_loss=False, # handled by LifeLossTerminalEnv below grayscale_obs=True, scale_obs=False, # keep uint8; we normalize in the model ) if "FIRE" in env.unwrapped.get_action_meanings(): env = FireResetEnv(env) + env = LifeLossTerminalEnv(env) env = gym.wrappers.FrameStackObservation(env, stack_size=4) return env