diff options
| author | haoyuren <13851610112@163.com> | 2026-02-22 11:36:53 -0600 |
|---|---|---|
| committer | haoyuren <13851610112@163.com> | 2026-02-22 11:36:53 -0600 |
| commit | dc8421e251f059e2136d5535bca2182af67fff75 (patch) | |
| tree | 13456ff8df1ac4e4ef839f30c97916c6bda232d6 | |
| parent | 3887054e02e622ca2cb7878bc0dec63d28c7f223 (diff) | |
Separate CPU collect / GPU train, add training CSV log
- Game collection always on CPU, PPO update on GPU (avoids per-step transfer overhead)
- Log avg_len, loss, vs_greedy win rate to CSV every --eval_every episodes (default 10k)
- Add --eval_every flag for periodic evaluation
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
| -rw-r--r-- | train.py | 86 |
1 file changed, 59 insertions, 27 deletions
@@ -308,27 +308,40 @@ def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer, # Training Loop # --------------------------------------------------------------------------- def train(args): - device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"Device: {device}") + train_device = "cuda" if torch.cuda.is_available() else "cpu" + collect_device = "cpu" # env simulation always on CPU + print(f"Train device: {train_device}, Collect device: {collect_device}") print(f"Training for {args.num_players} players, {args.episodes} episodes") - model = PolicyValueNet().to(device) + # Model lives on CPU for game collection; moves to GPU for PPO updates + model = PolicyValueNet().to(collect_device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) # Greedy warmup: imitate greedy play before self-play if args.greedy_warmup > 0: + if train_device != collect_device: + model.to(train_device) greedy_warmup(model, optimizer, args.num_players, - num_games=args.greedy_warmup, device=device) + num_games=args.greedy_warmup, device=train_device) + if train_device != collect_device: + model.to(collect_device) + + # Training log + log_path = f"{args.save_path}_log.csv" + with open(log_path, "w") as f: + f.write("episode,avg_len,loss,vs_greedy_wr\n") # Stats win_counts = defaultdict(int) game_lengths = [] batch_transitions = [] + recent_loss = 0.0 + recent_loss_count = 0 pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep") for ep in pbar: env = BlazingEightsEnv(num_players=args.num_players) - trajectories = collect_game(env, model, device) + trajectories = collect_game(env, model, collect_device) # Record stats if env.done: @@ -340,43 +353,36 @@ def train(args): returns, advantages = compute_gae(trans_list, gamma=args.gamma, lam=args.lam) for i, t in enumerate(trans_list): t.reward = returns[i] if i < len(returns) else t.reward - # Store advantage in a hacky way: overwrite reward with return, - # and we'll use (return - value) as 
advantage in update batch_transitions.extend(trans_list) # Update every `update_every` episodes - if ep % args.update_every == 0: - # Recompute advantages from stored returns and values - for t in batch_transitions: - pass # returns already in t.reward - - # Build proper advantages - for t in batch_transitions: - # t.reward is now the GAE return; advantage = return - value - t.reward = t.reward # this is the return - # We'll set the advantage in the update - # Actually, let's just pass returns and let update compute + if ep % args.update_every == 0 and batch_transitions: returns_for_update = np.array([t.reward for t in batch_transitions]) values_for_update = np.array([t.value for t in batch_transitions]) advs = returns_for_update - values_for_update - # Overwrite for the update function obs_arr = np.array([t.obs for t in batch_transitions]) actions_arr = np.array([t.action for t in batch_transitions]) old_lp_arr = np.array([t.log_prob for t in batch_transitions]) masks_arr = np.array([t.legal_mask for t in batch_transitions]) - obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=device) - actions_t = torch.tensor(actions_arr, dtype=torch.long, device=device) - old_lp_t = torch.tensor(old_lp_arr, dtype=torch.float32, device=device) - masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=device) - returns_t = torch.tensor(returns_for_update, dtype=torch.float32, device=device) - advs_t = torch.tensor(advs, dtype=torch.float32, device=device) + # Move model to train device for PPO update + if train_device != collect_device: + model.to(train_device) + + obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=train_device) + actions_t = torch.tensor(actions_arr, dtype=torch.long, device=train_device) + old_lp_t = torch.tensor(old_lp_arr, dtype=torch.float32, device=train_device) + masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=train_device) + returns_t = torch.tensor(returns_for_update, dtype=torch.float32, device=train_device) + advs_t = 
torch.tensor(advs, dtype=torch.float32, device=train_device) if len(advs_t) > 1: advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8) - # Manual PPO update + # PPO update + batch_loss = 0.0 + n_updates = 0 for _ in range(args.ppo_epochs): indices = np.arange(len(batch_transitions)) np.random.shuffle(indices) @@ -404,14 +410,38 @@ def train(args): nn.utils.clip_grad_norm_(model.parameters(), 0.5) optimizer.step() + batch_loss += loss.item() + n_updates += 1 + + recent_loss += batch_loss / max(n_updates, 1) + recent_loss_count += 1 + + # Move model back to CPU for collection + if train_device != collect_device: + model.to(collect_device) + batch_transitions = [] # Logging if ep % args.log_every == 0: avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0 + avg_loss = recent_loss / max(recent_loss_count, 1) total_games = sum(win_counts.values()) wr0 = win_counts[0] / max(total_games, 1) - pbar.set_postfix(avg_len=f"{avg_len:.1f}", wr0=f"{wr0:.1%}", games=total_games) + pbar.set_postfix(avg_len=f"{avg_len:.1f}", loss=f"{avg_loss:.3f}", + wr0=f"{wr0:.1%}", games=total_games) + recent_loss = 0.0 + recent_loss_count = 0 + + # Evaluate vs greedy + write log every eval_every episodes + if ep % args.eval_every == 0: + avg_len = np.mean(game_lengths[-args.eval_every:]) if game_lengths else 0 + avg_loss_log = recent_loss / max(recent_loss_count, 1) if recent_loss_count > 0 else 0 + vs_wr = evaluate_vs_random(model, num_players=args.num_players, + num_games=500, device=collect_device) + with open(log_path, "a") as f: + f.write(f"{ep},{avg_len:.1f},{avg_loss_log:.4f},{vs_wr:.4f}\n") + tqdm.write(f" [Eval ep{ep}] avg_len={avg_len:.1f} vs_greedy={vs_wr:.1%}") # Save checkpoint if ep % args.save_every == 0: @@ -432,6 +462,7 @@ def train(args): "num_players": args.num_players, }, f"{args.save_path}_final.pt") print(f"Training complete. 
Final model saved to {args.save_path}_final.pt") + print(f"Training log saved to {log_path}") return model @@ -475,6 +506,7 @@ if __name__ == "__main__": parser.add_argument("--batch_size", type=int, default=256) parser.add_argument("--update_every", type=int, default=64) parser.add_argument("--log_every", type=int, default=1000) + parser.add_argument("--eval_every", type=int, default=10000) parser.add_argument("--save_every", type=int, default=10000) parser.add_argument("--save_path", type=str, default="blazing_ppo") parser.add_argument("--greedy_warmup", type=int, default=2000, |
