1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
|
"""A matrix analysis utilities for logging and monitoring.
Computes topology metrics per CLAUDE.md §7.
"""
from __future__ import annotations
import torch
from src.model.olmo_graph import create_block_upper_triangular_mask
def compute_topology_metrics(A: torch.Tensor, num_heads: int = 16) -> dict[str, float]:
    """Compute topology metrics from the gate (adjacency) matrix A.

    Nodes are assigned to layers by ``node_index // num_heads``. Only entries
    selected by ``create_block_upper_triangular_mask(num_nodes, num_heads)``
    contribute to any metric.

    Args:
        A: [batch, num_nodes, num_nodes] gate matrix (e.g. [batch, 256, 256]).
        num_heads: heads per layer (16 for OLMo2-1B).

    Returns:
        dict of Python floats with the keys:
            "topology/mean_A": mean gate value over the masked entries.
            "topology/seq_gate_frac": fraction of masked adjacent-layer gates
                (layer diff == 1) that exceed 0.5.
            "topology/hyp_gate_frac": fraction of masked skip gates
                (layer diff > 1) that exceed 0.5.
            "topology/jaccard_var": variance of the pairwise Jaccard
                similarity of the binarized gates across the batch; 0.0 when
                the batch has fewer than 3 samples.
        Any metric whose selection is empty is reported as 0.0.
    """
    batch, num_nodes = A.shape[0], A.shape[1]

    # Convert the mask to bool once; it is reused for every selection below.
    # NOTE(review): the mask's exact contents come from the helper in
    # src.model.olmo_graph, which is not visible here.
    valid = create_block_upper_triangular_mask(num_nodes, num_heads).to(A.device).bool()

    # Mean A over valid entries ([batch, 30720] for 256 nodes / 16 heads).
    valid_entries = A[:, valid]
    mean_A = valid_entries.mean().item() if valid_entries.numel() > 0 else 0.0

    # Classify connections: adjacent (layer diff == 1) vs skip (layer diff > 1).
    layer_idx = torch.arange(num_nodes, device=A.device) // num_heads
    # layer_diff[i, j] = layer(j) - layer(i)
    layer_diff = layer_idx.unsqueeze(0) - layer_idx.unsqueeze(1)
    adj_mask = (layer_diff == 1) & valid   # adjacent-layer connections
    skip_mask = (layer_diff > 1) & valid   # skip connections

    # Fraction of gates > 0.5 in each class
    # ([batch, 3840] and [batch, 26880] for 256 nodes / 16 heads).
    adj_vals = A[:, adj_mask]
    skip_vals = A[:, skip_mask]
    seq_gate_frac = (adj_vals > 0.5).float().mean().item() if adj_vals.numel() > 0 else 0.0
    hyp_gate_frac = (skip_vals > 0.5).float().mean().item() if skip_vals.numel() > 0 else 0.0

    # Jaccard variance across the batch. It needs at least two sample pairs,
    # i.e. three samples: with batch == 2 there is a single Jaccard value and
    # its sample variance is undefined (NaN), so report 0.0 instead.
    jaccard_var = _jaccard_variance(A, valid).item() if batch > 2 else 0.0

    return {
        "topology/mean_A": mean_A,
        "topology/seq_gate_frac": seq_gate_frac,
        "topology/hyp_gate_frac": hyp_gate_frac,
        "topology/jaccard_var": jaccard_var,
    }
def _jaccard_variance(A: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""Compute variance of pairwise Jaccard similarity across batch.
Measures how context-dependent the topologies are.
Higher variance = more context-dependent routing.
"""
batch = A.shape[0]
if batch < 2:
return torch.tensor(0.0)
# Binarize at 0.5 threshold for Jaccard
binary = (A > 0.5).float()
valid = mask.bool()
# Extract valid entries: [batch, num_valid]
entries = binary[:, valid]
# Pairwise Jaccard
jaccards = []
for i in range(batch):
for j in range(i + 1, batch):
intersection = (entries[i] * entries[j]).sum()
union = ((entries[i] + entries[j]) > 0).float().sum()
jaccard = intersection / union.clamp(min=1.0)
jaccards.append(jaccard)
if not jaccards:
return torch.tensor(0.0)
jaccards = torch.stack(jaccards)
return jaccards.var()
|