summary refs log tree commit diff
path: root/train.py
diff options
context:
space:
mode:
author haoyuren <13851610112@163.com> 2026-02-22 11:36:53 -0600
committer haoyuren <13851610112@163.com> 2026-02-22 11:36:53 -0600
commit dc8421e251f059e2136d5535bca2182af67fff75 (patch)
tree 13456ff8df1ac4e4ef839f30c97916c6bda232d6 /train.py
parent 3887054e02e622ca2cb7878bc0dec63d28c7f223 (diff)
Separate CPU collect / GPU train, add training CSV log
- Game collection always on CPU, PPO update on GPU (avoids per-step transfer overhead)
- Log avg_len, loss, vs_greedy win rate to CSV every 10k episodes
- Add --eval_every flag for periodic evaluation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'train.py')
-rw-r--r-- train.py 86
1 files changed, 59 insertions, 27 deletions
diff --git a/train.py b/train.py
index d247cbb..65fb633 100644
--- a/train.py
+++ b/train.py
@@ -308,27 +308,40 @@ def greedy_warmup(model: PolicyValueNet, optimizer: torch.optim.Optimizer,
# Training Loop
# ---------------------------------------------------------------------------
def train(args):
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"Device: {device}")
+ train_device = "cuda" if torch.cuda.is_available() else "cpu"
+ collect_device = "cpu" # env simulation always on CPU
+ print(f"Train device: {train_device}, Collect device: {collect_device}")
print(f"Training for {args.num_players} players, {args.episodes} episodes")
- model = PolicyValueNet().to(device)
+ # Model lives on CPU for game collection; moves to GPU for PPO updates
+ model = PolicyValueNet().to(collect_device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# Greedy warmup: imitate greedy play before self-play
if args.greedy_warmup > 0:
+ if train_device != collect_device:
+ model.to(train_device)
greedy_warmup(model, optimizer, args.num_players,
- num_games=args.greedy_warmup, device=device)
+ num_games=args.greedy_warmup, device=train_device)
+ if train_device != collect_device:
+ model.to(collect_device)
+
+ # Training log
+ log_path = f"{args.save_path}_log.csv"
+ with open(log_path, "w") as f:
+ f.write("episode,avg_len,loss,vs_greedy_wr\n")
# Stats
win_counts = defaultdict(int)
game_lengths = []
batch_transitions = []
+ recent_loss = 0.0
+ recent_loss_count = 0
pbar = tqdm(range(1, args.episodes + 1), desc="Training", unit="ep")
for ep in pbar:
env = BlazingEightsEnv(num_players=args.num_players)
- trajectories = collect_game(env, model, device)
+ trajectories = collect_game(env, model, collect_device)
# Record stats
if env.done:
@@ -340,43 +353,36 @@ def train(args):
returns, advantages = compute_gae(trans_list, gamma=args.gamma, lam=args.lam)
for i, t in enumerate(trans_list):
t.reward = returns[i] if i < len(returns) else t.reward
- # Store advantage in a hacky way: overwrite reward with return,
- # and we'll use (return - value) as advantage in update
batch_transitions.extend(trans_list)
# Update every `update_every` episodes
- if ep % args.update_every == 0:
- # Recompute advantages from stored returns and values
- for t in batch_transitions:
- pass # returns already in t.reward
-
- # Build proper advantages
- for t in batch_transitions:
- # t.reward is now the GAE return; advantage = return - value
- t.reward = t.reward # this is the return
- # We'll set the advantage in the update
- # Actually, let's just pass returns and let update compute
+ if ep % args.update_every == 0 and batch_transitions:
returns_for_update = np.array([t.reward for t in batch_transitions])
values_for_update = np.array([t.value for t in batch_transitions])
advs = returns_for_update - values_for_update
- # Overwrite for the update function
obs_arr = np.array([t.obs for t in batch_transitions])
actions_arr = np.array([t.action for t in batch_transitions])
old_lp_arr = np.array([t.log_prob for t in batch_transitions])
masks_arr = np.array([t.legal_mask for t in batch_transitions])
- obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=device)
- actions_t = torch.tensor(actions_arr, dtype=torch.long, device=device)
- old_lp_t = torch.tensor(old_lp_arr, dtype=torch.float32, device=device)
- masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=device)
- returns_t = torch.tensor(returns_for_update, dtype=torch.float32, device=device)
- advs_t = torch.tensor(advs, dtype=torch.float32, device=device)
+ # Move model to train device for PPO update
+ if train_device != collect_device:
+ model.to(train_device)
+
+ obs_t = torch.tensor(obs_arr, dtype=torch.float32, device=train_device)
+ actions_t = torch.tensor(actions_arr, dtype=torch.long, device=train_device)
+ old_lp_t = torch.tensor(old_lp_arr, dtype=torch.float32, device=train_device)
+ masks_t = torch.tensor(masks_arr, dtype=torch.float32, device=train_device)
+ returns_t = torch.tensor(returns_for_update, dtype=torch.float32, device=train_device)
+ advs_t = torch.tensor(advs, dtype=torch.float32, device=train_device)
if len(advs_t) > 1:
advs_t = (advs_t - advs_t.mean()) / (advs_t.std() + 1e-8)
- # Manual PPO update
+ # PPO update
+ batch_loss = 0.0
+ n_updates = 0
for _ in range(args.ppo_epochs):
indices = np.arange(len(batch_transitions))
np.random.shuffle(indices)
@@ -404,14 +410,38 @@ def train(args):
nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
+ batch_loss += loss.item()
+ n_updates += 1
+
+ recent_loss += batch_loss / max(n_updates, 1)
+ recent_loss_count += 1
+
+ # Move model back to CPU for collection
+ if train_device != collect_device:
+ model.to(collect_device)
+
batch_transitions = []
# Logging
if ep % args.log_every == 0:
avg_len = np.mean(game_lengths[-args.log_every:]) if game_lengths else 0
+ avg_loss = recent_loss / max(recent_loss_count, 1)
total_games = sum(win_counts.values())
wr0 = win_counts[0] / max(total_games, 1)
- pbar.set_postfix(avg_len=f"{avg_len:.1f}", wr0=f"{wr0:.1%}", games=total_games)
+ pbar.set_postfix(avg_len=f"{avg_len:.1f}", loss=f"{avg_loss:.3f}",
+ wr0=f"{wr0:.1%}", games=total_games)
+ recent_loss = 0.0
+ recent_loss_count = 0
+
+ # Evaluate vs greedy + write log every eval_every episodes
+ if ep % args.eval_every == 0:
+ avg_len = np.mean(game_lengths[-args.eval_every:]) if game_lengths else 0
+ avg_loss_log = recent_loss / max(recent_loss_count, 1) if recent_loss_count > 0 else 0
+ vs_wr = evaluate_vs_random(model, num_players=args.num_players,
+ num_games=500, device=collect_device)
+ with open(log_path, "a") as f:
+ f.write(f"{ep},{avg_len:.1f},{avg_loss_log:.4f},{vs_wr:.4f}\n")
+ tqdm.write(f" [Eval ep{ep}] avg_len={avg_len:.1f} vs_greedy={vs_wr:.1%}")
# Save checkpoint
if ep % args.save_every == 0:
@@ -432,6 +462,7 @@ def train(args):
"num_players": args.num_players,
}, f"{args.save_path}_final.pt")
print(f"Training complete. Final model saved to {args.save_path}_final.pt")
+ print(f"Training log saved to {log_path}")
return model
@@ -475,6 +506,7 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--update_every", type=int, default=64)
parser.add_argument("--log_every", type=int, default=1000)
+ parser.add_argument("--eval_every", type=int, default=10000)
parser.add_argument("--save_every", type=int, default=10000)
parser.add_argument("--save_path", type=str, default="blazing_ppo")
parser.add_argument("--greedy_warmup", type=int, default=2000,