Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions 3-atari/2-ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


SAVE_PATH = "atari_ppo.pt"
TOTAL_FRAMES = 5_000_000
TOTAL_FRAMES = 10_000_000
N_ENVS = 8
ROLLOUT_STEPS = 128 # batch = N_ENVS * ROLLOUT_STEPS = 1024
EPOCHS = 4
Expand Down Expand Up @@ -113,6 +113,11 @@ def policy_action(obs):
ep_returns = []

for update in range(1, n_updates + 1):
# Linear LR anneal from LR -> 0 over the run (CleanRL convention).
lr_now = LR * (1.0 - (update - 1) / n_updates)
for g in optimizer.param_groups:
g["lr"] = lr_now

obs_buf = np.zeros((ROLLOUT_STEPS, N_ENVS, *obs_shape), dtype=np.uint8)
act_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.int64)
logp_buf = np.zeros((ROLLOUT_STEPS, N_ENVS), dtype=np.float32)
Expand Down Expand Up @@ -151,12 +156,12 @@ def policy_action(obs):
obs_t = torch.as_tensor(np.asarray(obs), device=device)
_, last_value = model(obs_t)
advantages, returns = compute_gae(rew_buf, val_buf, done_buf, last_value.cpu().numpy())
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# Flatten (T, N_ENVS, ...) -> (T*N_ENVS, ...)
obs_t = torch.as_tensor(obs_buf.reshape(batch_size, *obs_shape), device=device)
act_t = torch.as_tensor(act_buf.reshape(batch_size), device=device)
old_logp_t = torch.as_tensor(logp_buf.reshape(batch_size), device=device)
old_val_t = torch.as_tensor(val_buf.reshape(batch_size), device=device)
adv_t = torch.as_tensor(advantages.reshape(batch_size), device=device)
ret_t = torch.as_tensor(returns.reshape(batch_size), device=device)

Expand All @@ -173,11 +178,22 @@ def policy_action(obs):
new_logp = dist.log_prob(act_t[mb])
entropy = dist.entropy().mean()

# Advantage normalization per minibatch (CleanRL convention).
mb_adv = adv_t[mb]
mb_adv = (mb_adv - mb_adv.mean()) / (mb_adv.std() + 1e-8)

ratio = (new_logp - old_logp_t[mb]).exp()
unclipped = ratio * adv_t[mb]
clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * adv_t[mb]
unclipped = ratio * mb_adv
clipped = torch.clamp(ratio, 1 - CLIP_COEF, 1 + CLIP_COEF) * mb_adv
policy_loss = -torch.min(unclipped, clipped).mean()
value_loss = (values - ret_t[mb]).pow(2).mean()

# Value loss with clipping around the old value prediction.
v_clipped = old_val_t[mb] + torch.clamp(
values - old_val_t[mb], -CLIP_COEF, CLIP_COEF)
vl_unclipped = (values - ret_t[mb]).pow(2)
vl_clipped = (v_clipped - ret_t[mb]).pow(2)
value_loss = 0.5 * torch.max(vl_unclipped, vl_clipped).mean()

loss = policy_loss + VALUE_COEF * value_loss - ENTROPY_COEF * entropy

optimizer.zero_grad()
Expand All @@ -201,6 +217,7 @@ def policy_action(obs):
"policy_loss": pl_sum / n_mb,
"value_loss": vl_sum / n_mb,
"entropy": ent_sum / n_mb,
"lr": lr_now,
}
if ep_returns:
log["recent_mean_return"] = float(np.mean(ep_returns[-20:]))
Expand Down