| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:00:39 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:00:39 -0600 |
| commit | 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch) | |
| tree | 073534138604c1c49021ca7e334322262129f6ac /scripts/eval.py | |
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward pass with per-head routing via a 256x256 adjacency matrix A (see the routing sketch below)
- Proportional attribution for post-norm decomposition (see the attribution sketch below)
- All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate (see the Gumbel-Sigmoid sketch below)
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing (see the packing sketch below)
- 43/43 unit tests pass
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
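
A minimal sketch of the per-head routing, assuming the 256 nodes of A index OLMo2-1B's 16 layers x 16 attention heads and that A[i, j] gates how much node i's residual write feeds node j's input. Names and shapes here are illustrative, not the repo's actual API:

```python
# Hypothetical sketch of per-head routing via a 256x256 adjacency matrix A.
# Assumes 16 layers x 16 heads = 256 head nodes; `contribs` holds each
# node's residual-stream write (zeros for nodes not yet computed).
import torch

N_LAYERS, N_HEADS = 16, 16
N_NODES = N_LAYERS * N_HEADS  # 256

def node_id(layer: int, head: int) -> int:
    """Flatten a (layer, head) pair into a node index in [0, 256)."""
    return layer * N_HEADS + head

def gated_input(contribs: torch.Tensor, A: torch.Tensor, j: int) -> torch.Tensor:
    """Input read for node j: sum_i A[i, j] * contribs[i].

    contribs: [N_NODES, batch, seq, hidden]; A: [N_NODES, N_NODES].
    """
    gates = A[:, j].view(N_NODES, 1, 1, 1)  # column j = incoming edge gates
    return (gates * contribs).sum(dim=0)
```

With A set to all ones every node reads the full residual stream, which is the all-ones baseline (`create_all_ones_A` in the diff below) that the sanity checks compare against.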
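The post-norm decomposition is the subtle part: OLMo2 applies RMSNorm to a block's summed output before it re-enters the residual stream, so per-head contributions are entangled by the norm. One plausible reading of "proportional attribution", sketched under that assumption (not taken from the repo): normalize with the statistics of the full sum, then give each head its proportionally scaled share, which is exact because RMSNorm is homogeneous in its input.

```python
# Sketch of proportional attribution through a post-block RMSNorm
# (an assumption about what the commit means, not the repo's code).
# Since RMSNorm(x) = g * x / rms(x), scaling every per-head write by the
# rms of the *summed* output decomposes the normed result exactly:
#   sum_i g * h_i / rms(sum_j h_j) == RMSNorm(sum_j h_j)
import torch

def rmsnorm_attribute(head_outputs: torch.Tensor,
                      weight: torch.Tensor,
                      eps: float = 1e-6) -> torch.Tensor:
    """head_outputs: [n_heads, ..., hidden]; weight: [hidden].
    Returns per-head shares that sum to RMSNorm(head_outputs.sum(0))."""
    total = head_outputs.sum(dim=0)
    rms = total.pow(2).mean(dim=-1, keepdim=True).add(eps).sqrt()
    return weight * head_outputs / rms  # shared rms => proportional split
```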
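For the Gumbel-Sigmoid, a generic sketch of the standard binary-Concrete relaxation (the repo's exact variant and the cascading gate are not shown):

```python
# Generic Gumbel-Sigmoid (binary Concrete) relaxation: differentiable
# approximately-Bernoulli edge samples, with an optional straight-through
# hard mode. A sketch of the standard trick, not DAGFormer's code.
import torch

def gumbel_sigmoid(logits: torch.Tensor, tau: float, hard: bool = False) -> torch.Tensor:
    # Logistic noise (difference of two Gumbel samples).
    u = torch.rand_like(logits).clamp(1e-6, 1.0 - 1e-6)
    noise = torch.log(u) - torch.log1p(-u)
    y_soft = torch.sigmoid((logits + noise) / tau)  # tau -> 0 sharpens to {0, 1}
    if hard:
        # Straight-through: hard 0/1 forward, soft gradient backward.
        y_hard = (y_soft > 0.5).to(y_soft.dtype)
        return y_hard + y_soft - y_soft.detach()
    return y_soft
```

The eval script's `eval_soft` / `eval_hard` modes plausibly correspond to the soft branch and the thresholded branch (typically with noise disabled at eval time).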
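And for dolma.py's sequence packing, a minimal greedy packer, assuming the common recipe of concatenating tokenized documents with an EOS separator and slicing fixed seq_len rows; the real streaming loader may differ in separator and buffering details:

```python
# Minimal greedy sequence packing for a streaming corpus: concatenate
# tokenized documents with an EOS separator and slice fixed-length rows.
# Illustrative only; the actual dolma.py streams Dolma v1.7.
from typing import Iterable, Iterator, List
import torch

def pack_sequences(docs: Iterable[List[int]], seq_len: int, eos_id: int) -> Iterator[torch.Tensor]:
    buf: List[int] = []
    for tokens in docs:
        buf.extend(tokens)
        buf.append(eos_id)  # document boundary marker
        while len(buf) >= seq_len:
            yield torch.tensor(buf[:seq_len], dtype=torch.long)
            del buf[:seq_len]  # keep the remainder for the next row
```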
Diffstat (limited to 'scripts/eval.py')
| -rw-r--r-- | scripts/eval.py | 131 |
1 file changed, 131 insertions, 0 deletions
```diff
diff --git a/scripts/eval.py b/scripts/eval.py
new file mode 100644
index 0000000..bc471dc
--- /dev/null
+++ b/scripts/eval.py
@@ -0,0 +1,131 @@
+"""Evaluate a trained DAGFormer checkpoint.
+
+Usage:
+    python scripts/eval.py --config configs/sanity_check.yaml --checkpoint checkpoints/checkpoint_step1000.pt
+"""
+
+from __future__ import annotations
+
+import argparse
+
+import torch
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from src.data.dolma import build_eval_dataloader
+from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
+from src.model.predictor import StructurePredictor
+from src.training.checkpointing import load_checkpoint
+from src.training.trainer import TrainConfig
+from src.utils.topology import compute_topology_metrics
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Evaluate DAGFormer")
+    parser.add_argument("--config", type=str, required=True)
+    parser.add_argument("--checkpoint", type=str, required=True)
+    parser.add_argument("--device", type=str, default="cuda")
+    args = parser.parse_args()
+
+    config = TrainConfig.from_yaml(args.config)
+    device = torch.device(args.device)
+
+    # Load models
+    print(f"Loading {config.olmo_model_id}...")
+    olmo = AutoModelForCausalLM.from_pretrained(
+        config.olmo_model_id, torch_dtype=torch.bfloat16
+    ).to(device).eval()
+    for p in olmo.parameters():
+        p.requires_grad_(False)
+
+    olmo_tokenizer = AutoTokenizer.from_pretrained(config.olmo_model_id)
+
+    olmo_wrapper = DAGFormerOLMo(model=olmo, input_norm=config.input_norm).to(device)
+
+    print(f"Loading {config.qwen_model_id}...")
+    predictor = StructurePredictor(
+        qwen_model_id=config.qwen_model_id,
+        hidden_dim=config.predictor_hidden_dim,
+        rank=config.predictor_rank,
+        cascading_gate_k=config.cascading_gate_k,
+        qwen_input_prefix=config.qwen_input_prefix,
+        device=device,
+    )
+
+    # Load checkpoint
+    load_checkpoint(args.checkpoint, predictor, device=device)
+    predictor.eval()
+
+    # Build eval data
+    cache_path = f"{config.save_dir}/eval_cache.pt"
+    eval_batches = build_eval_dataloader(
+        olmo_tokenizer=olmo_tokenizer,
+        seq_len=config.seq_len,
+        batch_size=config.micro_batch_size,
+        dataset_name=config.dataset,
+        dataset_version=config.dataset_name,
+        eval_skip=config.eval_skip,
+        eval_size=config.eval_size,
+        cache_path=cache_path,
+    )
+
+    vocab_size = olmo.config.vocab_size
+    tau = config.tau_final  # use final temperature for eval
+
+    # Evaluate
+    nll_soft_sum = 0.0
+    nll_hard_sum = 0.0
+    nll_baseline_sum = 0.0
+    n = 0
+
+    with torch.no_grad():
+        for batch in eval_batches:
+            olmo_ids = batch["olmo_ids"].to(device)
+            olmo_labels = batch["olmo_labels"].to(device)
+            raw_texts = batch["raw_text"]
+
+            # Soft
+            A_soft = predictor(raw_texts, tau=tau, mode="eval_soft")
+            logits_soft = olmo_wrapper(olmo_ids, A_soft)
+            nll_soft = F.cross_entropy(
+                logits_soft[:, :-1].contiguous().view(-1, vocab_size),
+                olmo_labels[:, 1:].contiguous().view(-1),
+            )
+            nll_soft_sum += nll_soft.item()
+
+            # Hard
+            A_hard = predictor(raw_texts, tau=tau, mode="eval_hard")
+            logits_hard = olmo_wrapper(olmo_ids, A_hard)
+            nll_hard = F.cross_entropy(
+                logits_hard[:, :-1].contiguous().view(-1, vocab_size),
+                olmo_labels[:, 1:].contiguous().view(-1),
+            )
+            nll_hard_sum += nll_hard.item()
+
+            # Baseline
+            A_ones = create_all_ones_A(olmo_ids.shape[0]).to(device)
+            logits_base = olmo_wrapper(olmo_ids, A_ones)
+            nll_base = F.cross_entropy(
+                logits_base[:, :-1].contiguous().view(-1, vocab_size),
+                olmo_labels[:, 1:].contiguous().view(-1),
+            )
+            nll_baseline_sum += nll_base.item()
+
+            # Topology
+            topo = compute_topology_metrics(A_soft)
+
+            n += 1
+
+    print(f"\n{'='*50}")
+    print(f"Evaluation Results ({n} batches)")
+    print(f"{'='*50}")
+    print(f"  eval/nll_soft:     {nll_soft_sum / n:.4f}")
+    print(f"  eval/nll_hard:     {nll_hard_sum / n:.4f}")
+    print(f"  eval/nll_baseline: {nll_baseline_sum / n:.4f}")
+    print(f"  topology/mean_A:   {topo['topology/mean_A']:.4f}")
+    print(f"  topology/seq_gate: {topo['topology/seq_gate_frac']:.4f}")
+    print(f"  topology/hyp_gate: {topo['topology/hyp_gate_frac']:.4f}")
+
+
+if __name__ == "__main__":
+    main()
```
