author    haoyuren <13851610112@163.com>  2026-02-22 12:09:01 -0600
committer haoyuren <13851610112@163.com>  2026-02-22 12:09:01 -0600
commit    0735c68037566ae6731ac5dd349329b1c8d44851
tree      1adc41bdce029d627dc7b318a8dd379630325ec3
parent    800e1f1f33d93cb7a1812dff1dc0ef85289ef075
Add entropy annealing to escape greedy local minimum after warmup
After behavioral cloning warmup, the policy is very peaked on greedy
actions. Start with a higher entropy coefficient (default: 5x ent_coef)
and linearly decay it to the target, encouraging exploration of
non-greedy strategies early in training.

New arg: --ent_start (default: 5x --ent_coef)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
-rw-r--r--  train.py  19
1 file changed, 16 insertions(+), 3 deletions(-)
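For reference, a minimal standalone sketch of the linear annealing schedule the diff below implements. The values ent_coef=0.01 and episodes=10_000 are illustrative placeholders, not taken from this repo's training config:

    # Sketch of the entropy-coefficient schedule added in this commit.
    ent_coef = 0.01           # final (target) entropy coefficient
    ent_start = ent_coef * 5  # default starting coefficient when --ent_start is unset
    episodes = 10_000         # total training episodes (placeholder)

    def annealed_ent_coef(ep):
        """Linearly decay from ent_start to ent_coef as training progresses."""
        progress = min(ep / episodes, 1.0)
        return ent_start + (ent_coef - ent_start) * progress

    assert annealed_ent_coef(0) == ent_start                    # full bonus right after warmup
    assert abs(annealed_ent_coef(episodes) - ent_coef) < 1e-12  # target by the last episode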
diff --git a/train.py b/train.py
index 31ce926..fd27f44 100644
--- a/train.py
+++ b/train.py
@@ -349,6 +349,12 @@ def train(args):
recent_loss = 0.0
recent_loss_count = 0
+ # Entropy annealing: start high to escape greedy local minimum, decay to target
+ ent_start = args.ent_start if args.ent_start is not None else args.ent_coef * 5
+ ent_end = args.ent_coef
+ if ent_start != ent_end:
+ print(f"Entropy annealing: {ent_start} → {ent_end}")
+
ep = 0
next_log = args.log_every
next_eval = args.eval_every
@@ -429,7 +435,10 @@ def train(args):
surr2 = torch.clamp(ratio, 1 - args.clip_eps, 1 + args.clip_eps) * b_advs
policy_loss = -torch.min(surr1, surr2).mean()
value_loss = F.mse_loss(values, b_returns)
- loss = policy_loss + 0.5 * value_loss - args.ent_coef * entropy.mean()
+ # Entropy coefficient with linear annealing
+ progress = min(ep / args.episodes, 1.0)
+ ent_coef = ent_start + (ent_end - ent_start) * progress
+ loss = policy_loss + 0.5 * value_loss - ent_coef * entropy.mean()
optimizer.zero_grad()
loss.backward()
@@ -452,8 +461,9 @@ def train(args):
avg_loss = recent_loss / max(recent_loss_count, 1)
total_games = sum(win_counts.values())
wr0 = win_counts[0] / max(total_games, 1)
+ cur_ent = ent_start + (ent_end - ent_start) * min(ep / args.episodes, 1.0)
pbar.set_postfix(avg_len=f"{avg_len:.1f}", loss=f"{avg_loss:.3f}",
- wr0=f"{wr0:.1%}", games=total_games)
+ ent=f"{cur_ent:.3f}", wr0=f"{wr0:.1%}")
recent_loss = 0.0
recent_loss_count = 0
next_log += args.log_every
@@ -561,7 +571,10 @@ if __name__ == "__main__":
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--lam", type=float, default=0.95)
parser.add_argument("--clip_eps", type=float, default=0.2)
- parser.add_argument("--ent_coef", type=float, default=0.01)
+ parser.add_argument("--ent_coef", type=float, default=0.01,
+ help="Final entropy coefficient")
+ parser.add_argument("--ent_start", type=float, default=None,
+ help="Initial entropy coef, linearly decays to ent_coef (default: 5x ent_coef)")
parser.add_argument("--ppo_epochs", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--update_every", type=int, default=64)