summary | refs | log | tree | commit | diff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r-- train.py 19
1 files changed, 16 insertions, 3 deletions
diff --git a/train.py b/train.py
index 31ce926..fd27f44 100644
--- a/train.py
+++ b/train.py
@@ -349,6 +349,12 @@ def train(args):
recent_loss = 0.0
recent_loss_count = 0
+ # Entropy annealing: start high to escape greedy local minimum, decay to target
+ ent_start = args.ent_start if args.ent_start is not None else args.ent_coef * 5
+ ent_end = args.ent_coef
+ if ent_start != ent_end:
+ print(f"Entropy annealing: {ent_start} → {ent_end}")
+
ep = 0
next_log = args.log_every
next_eval = args.eval_every
@@ -429,7 +435,10 @@ def train(args):
surr2 = torch.clamp(ratio, 1 - args.clip_eps, 1 + args.clip_eps) * b_advs
policy_loss = -torch.min(surr1, surr2).mean()
value_loss = F.mse_loss(values, b_returns)
- loss = policy_loss + 0.5 * value_loss - args.ent_coef * entropy.mean()
+ # Entropy coefficient with linear annealing
+ progress = min(ep / args.episodes, 1.0)
+ ent_coef = ent_start + (ent_end - ent_start) * progress
+ loss = policy_loss + 0.5 * value_loss - ent_coef * entropy.mean()
optimizer.zero_grad()
loss.backward()
@@ -452,8 +461,9 @@ def train(args):
avg_loss = recent_loss / max(recent_loss_count, 1)
total_games = sum(win_counts.values())
wr0 = win_counts[0] / max(total_games, 1)
+ cur_ent = ent_start + (ent_end - ent_start) * min(ep / args.episodes, 1.0)
pbar.set_postfix(avg_len=f"{avg_len:.1f}", loss=f"{avg_loss:.3f}",
- wr0=f"{wr0:.1%}", games=total_games)
+ ent=f"{cur_ent:.3f}", wr0=f"{wr0:.1%}")
recent_loss = 0.0
recent_loss_count = 0
next_log += args.log_every
@@ -561,7 +571,10 @@ if __name__ == "__main__":
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--lam", type=float, default=0.95)
parser.add_argument("--clip_eps", type=float, default=0.2)
- parser.add_argument("--ent_coef", type=float, default=0.01)
+ parser.add_argument("--ent_coef", type=float, default=0.01,
+ help="Final entropy coefficient")
+ parser.add_argument("--ent_start", type=float, default=None,
+ help="Initial entropy coef, linearly decays to ent_coef (default: 5x ent_coef)")
parser.add_argument("--ppo_epochs", type=int, default=4)
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--update_every", type=int, default=64)