-rw-r--r--  train.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
@@ -473,7 +473,7 @@ def train(args):
         avg_len = np.mean(all_game_lengths[-args.eval_every:]) if all_game_lengths else 0
         avg_loss_log = recent_loss / max(recent_loss_count, 1) if recent_loss_count > 0 else 0
         vs_wr = evaluate_vs_greedy_batch(model, num_players=args.num_players,
-                                         num_games=500, device=collect_device)
+                                         num_games=2000, device=collect_device)
         with open(log_path, "a") as f:
             f.write(f"{ep},{avg_len:.1f},{avg_loss_log:.4f},{vs_wr:.4f}\n")
         tqdm.write(f" [Eval ep{ep}] avg_len={avg_len:.1f} vs_greedy={vs_wr:.1%}")
@@ -571,7 +571,7 @@ if __name__ == "__main__":
     parser.add_argument("--gamma", type=float, default=0.99)
     parser.add_argument("--lam", type=float, default=0.95)
     parser.add_argument("--clip_eps", type=float, default=0.2)
-    parser.add_argument("--ent_coef", type=float, default=0.01,
+    parser.add_argument("--ent_coef", type=float, default=0.02,
                         help="Final entropy coefficient")
     parser.add_argument("--ent_start", type=float, default=None,
                         help="Initial entropy coef, linearly decays to ent_coef (default: 5x ent_coef)")
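
Context on the two changes. Raising num_games from 500 to 2000 quarters the variance of the win-rate estimate: at p ~= 0.5 the standard error of a binomial proportion is sqrt(p(1-p)/n), so it drops from about sqrt(0.25/500) ~= 2.2% to sqrt(0.25/2000) ~= 1.1%, halving the noise in the logged vs_greedy column. The --ent_start help text describes a linear decay of the entropy coefficient down to --ent_coef; a minimal sketch of such a schedule is below, assuming a progress fraction in [0, 1]. The function name and signature are illustrative, not train.py's actual implementation, which is not shown in this diff.

# Sketch of the linear entropy-coefficient decay implied by the --ent_start
# help text. Hypothetical helper; train.py's real schedule may differ.
def entropy_coef(progress, ent_coef=0.02, ent_start=None):
    """Linearly decay from ent_start to ent_coef as progress goes 0 -> 1.

    progress:  fraction of training completed, clamped to [0, 1].
    ent_coef:  final entropy coefficient (the new default, 0.02).
    ent_start: initial coefficient; defaults to 5x ent_coef per the help text.
    """
    if ent_start is None:
        ent_start = 5.0 * ent_coef
    progress = min(max(progress, 0.0), 1.0)
    return ent_start + (ent_coef - ent_start) * progress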
