1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
"""Evaluate a trained DAGFormer checkpoint.
Usage:
python scripts/eval.py --config configs/sanity_check.yaml --checkpoint checkpoints/checkpoint_step1000.pt
"""
from __future__ import annotations
import argparse
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.data.dolma import build_eval_dataloader
from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
from src.model.predictor import StructurePredictor
from src.training.checkpointing import load_checkpoint
from src.training.trainer import TrainConfig
from src.utils.topology import compute_topology_metrics
def _shifted_nll(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> float:
    """Next-token negative log-likelihood.

    Logits at position t are scored against the label at position t+1
    (standard causal-LM shift). Returns the mean cross-entropy as a float.
    """
    return F.cross_entropy(
        logits[:, :-1].contiguous().view(-1, vocab_size),
        labels[:, 1:].contiguous().view(-1),
    ).item()


def main():
    """Evaluate a trained DAGFormer checkpoint.

    Scores the frozen OLMo backbone under three adjacency structures —
    the predictor's soft output, its hard (discretized) output, and an
    all-ones baseline — and prints batch-averaged NLLs plus topology
    metrics averaged over all eval batches.
    """
    parser = argparse.ArgumentParser(description="Evaluate DAGFormer")
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--device", type=str, default="cuda")
    args = parser.parse_args()

    config = TrainConfig.from_yaml(args.config)
    device = torch.device(args.device)

    # Load the frozen OLMo backbone: eval mode, gradients disabled -- this
    # script only scores, it never trains the backbone.
    print(f"Loading {config.olmo_model_id}...")
    olmo = AutoModelForCausalLM.from_pretrained(
        config.olmo_model_id, torch_dtype=torch.bfloat16
    ).to(device).eval()
    for p in olmo.parameters():
        p.requires_grad_(False)
    olmo_tokenizer = AutoTokenizer.from_pretrained(config.olmo_model_id)
    olmo_wrapper = DAGFormerOLMo(model=olmo, input_norm=config.input_norm).to(device)

    # Build the structure predictor and restore its trained weights.
    print(f"Loading {config.qwen_model_id}...")
    predictor = StructurePredictor(
        qwen_model_id=config.qwen_model_id,
        hidden_dim=config.predictor_hidden_dim,
        rank=config.predictor_rank,
        cascading_gate_k=config.cascading_gate_k,
        qwen_input_prefix=config.qwen_input_prefix,
        device=device,
    )
    load_checkpoint(args.checkpoint, predictor, device=device)
    predictor.eval()

    # Build eval data; cached on disk so repeated evals skip re-tokenization.
    cache_path = f"{config.save_dir}/eval_cache.pt"
    # NOTE(review): `dataset_name=config.dataset` paired with
    # `dataset_version=config.dataset_name` looks swapped -- confirm against
    # the TrainConfig fields and build_eval_dataloader's signature.
    eval_batches = build_eval_dataloader(
        olmo_tokenizer=olmo_tokenizer,
        seq_len=config.seq_len,
        batch_size=config.micro_batch_size,
        dataset_name=config.dataset,
        dataset_version=config.dataset_name,
        eval_skip=config.eval_skip,
        eval_size=config.eval_size,
        cache_path=cache_path,
    )

    vocab_size = olmo.config.vocab_size
    tau = config.tau_final  # use final (fully annealed) temperature for eval

    nll_soft_sum = 0.0
    nll_hard_sum = 0.0
    nll_baseline_sum = 0.0
    topo_sums: dict[str, float] = {}  # running sums of topology metrics
    n = 0
    with torch.no_grad():
        for batch in eval_batches:
            olmo_ids = batch["olmo_ids"].to(device)
            olmo_labels = batch["olmo_labels"].to(device)
            raw_texts = batch["raw_text"]

            # Soft adjacency from the predictor.
            A_soft = predictor(raw_texts, tau=tau, mode="eval_soft")
            nll_soft_sum += _shifted_nll(
                olmo_wrapper(olmo_ids, A_soft), olmo_labels, vocab_size
            )

            # Hard (discretized) adjacency.
            A_hard = predictor(raw_texts, tau=tau, mode="eval_hard")
            nll_hard_sum += _shifted_nll(
                olmo_wrapper(olmo_ids, A_hard), olmo_labels, vocab_size
            )

            # Baseline: all-ones adjacency (no learned structure).
            A_ones = create_all_ones_A(olmo_ids.shape[0]).to(device)
            nll_baseline_sum += _shifted_nll(
                olmo_wrapper(olmo_ids, A_ones), olmo_labels, vocab_size
            )

            # Accumulate topology metrics so the summary reflects the whole
            # eval set (previously only the final batch's metrics were shown).
            for key, val in compute_topology_metrics(A_soft).items():
                topo_sums[key] = topo_sums.get(key, 0.0) + float(val)
            n += 1

    if n == 0:
        # Guard: an empty eval set would otherwise divide by zero below.
        print("No eval batches produced; nothing to report.")
        return

    print(f"\n{'='*50}")
    print(f"Evaluation Results ({n} batches)")
    print(f"{'='*50}")
    print(f" eval/nll_soft: {nll_soft_sum / n:.4f}")
    print(f" eval/nll_hard: {nll_hard_sum / n:.4f}")
    print(f" eval/nll_baseline: {nll_baseline_sum / n:.4f}")
    print(f" topology/mean_A: {topo_sums['topology/mean_A'] / n:.4f}")
    print(f" topology/seq_gate: {topo_sums['topology/seq_gate_frac'] / n:.4f}")
    print(f" topology/hyp_gate: {topo_sums['topology/hyp_gate_frac'] / n:.4f}")
# Script entry point: run evaluation only when invoked directly, not on import.
if __name__ == "__main__":
    main()
|