diff options
| -rw-r--r-- | train.py | 19 |
1 file changed, 16 insertions, 3 deletions
@@ -349,6 +349,12 @@ def train(args): recent_loss = 0.0 recent_loss_count = 0 + # Entropy annealing: start high to escape greedy local minimum, decay to target + ent_start = args.ent_start if args.ent_start is not None else args.ent_coef * 5 + ent_end = args.ent_coef + if ent_start != ent_end: + print(f"Entropy annealing: {ent_start} → {ent_end}") + ep = 0 next_log = args.log_every next_eval = args.eval_every @@ -429,7 +435,10 @@ def train(args): surr2 = torch.clamp(ratio, 1 - args.clip_eps, 1 + args.clip_eps) * b_advs policy_loss = -torch.min(surr1, surr2).mean() value_loss = F.mse_loss(values, b_returns) - loss = policy_loss + 0.5 * value_loss - args.ent_coef * entropy.mean() + # Entropy coefficient with linear annealing + progress = min(ep / args.episodes, 1.0) + ent_coef = ent_start + (ent_end - ent_start) * progress + loss = policy_loss + 0.5 * value_loss - ent_coef * entropy.mean() optimizer.zero_grad() loss.backward() @@ -452,8 +461,9 @@ def train(args): avg_loss = recent_loss / max(recent_loss_count, 1) total_games = sum(win_counts.values()) wr0 = win_counts[0] / max(total_games, 1) + cur_ent = ent_start + (ent_end - ent_start) * min(ep / args.episodes, 1.0) pbar.set_postfix(avg_len=f"{avg_len:.1f}", loss=f"{avg_loss:.3f}", - wr0=f"{wr0:.1%}", games=total_games) + ent=f"{cur_ent:.3f}", wr0=f"{wr0:.1%}") recent_loss = 0.0 recent_loss_count = 0 next_log += args.log_every @@ -561,7 +571,10 @@ if __name__ == "__main__": parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--lam", type=float, default=0.95) parser.add_argument("--clip_eps", type=float, default=0.2) - parser.add_argument("--ent_coef", type=float, default=0.01) + parser.add_argument("--ent_coef", type=float, default=0.01, + help="Final entropy coefficient") + parser.add_argument("--ent_start", type=float, default=None, + help="Initial entropy coef, linearly decays to ent_coef (default: 5x ent_coef)") parser.add_argument("--ppo_epochs", type=int, default=4) 
parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--update_every", type=int, default=64) |
