summaryrefslogtreecommitdiff
path: root/src/utils/topology.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:00:39 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:00:39 -0600
commit13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch)
tree073534138604c1c49021ca7e334322262129f6ac /src/utils/topology.py
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/utils/topology.py')
-rw-r--r--src/utils/topology.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/src/utils/topology.py b/src/utils/topology.py
new file mode 100644
index 0000000..3a2e0a8
--- /dev/null
+++ b/src/utils/topology.py
@@ -0,0 +1,87 @@
+"""A matrix analysis utilities for logging and monitoring.
+
+Computes topology metrics per CLAUDE.md §7.
+"""
+
+from __future__ import annotations
+
+import torch
+
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+def compute_topology_metrics(A: torch.Tensor, num_heads: int = 16) -> dict[str, float]:
+ """Compute topology metrics from adjacency matrix A.
+
+ Args:
+ A: [batch, 256, 256] — gate matrix
+ num_heads: heads per layer (16 for OLMo2-1B)
+
+ Returns:
+ dict with metrics: mean_A, seq_gate_frac, hyp_gate_frac, jaccard_var
+ """
+ batch = A.shape[0]
+ num_nodes = A.shape[1]
+ mask = create_block_upper_triangular_mask(num_nodes, num_heads).to(A.device)
+
+ # Mean A over valid entries
+ valid_entries = A[:, mask.bool()] # [batch, 30720]
+ mean_A = valid_entries.mean().item()
+
+ # Classify connections: adjacent (layer diff == 1) vs skip (layer diff > 1)
+ layer_idx = torch.arange(num_nodes, device=A.device) // num_heads
+ layer_diff = layer_idx.unsqueeze(0) - layer_idx.unsqueeze(1) # [256, 256]
+ # layer_diff[i,j] = layer(j) - layer(i)
+
+ adj_mask = (layer_diff == 1) & mask.bool() # adjacent-layer connections
+ skip_mask = (layer_diff > 1) & mask.bool() # skip connections
+
+ # Fraction of gates > 0.5
+ adj_vals = A[:, adj_mask] # [batch, 3840]
+ skip_vals = A[:, skip_mask] # [batch, 26880]
+
+ seq_gate_frac = (adj_vals > 0.5).float().mean().item() if adj_vals.numel() > 0 else 0.0
+ hyp_gate_frac = (skip_vals > 0.5).float().mean().item() if skip_vals.numel() > 0 else 0.0
+
+ # Jaccard variance across batch
+ jaccard_var = _jaccard_variance(A, mask).item() if batch > 1 else 0.0
+
+ return {
+ "topology/mean_A": mean_A,
+ "topology/seq_gate_frac": seq_gate_frac,
+ "topology/hyp_gate_frac": hyp_gate_frac,
+ "topology/jaccard_var": jaccard_var,
+ }
+
+
+def _jaccard_variance(A: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Compute variance of pairwise Jaccard similarity across batch.
+
+ Measures how context-dependent the topologies are.
+ Higher variance = more context-dependent routing.
+ """
+ batch = A.shape[0]
+ if batch < 2:
+ return torch.tensor(0.0)
+
+ # Binarize at 0.5 threshold for Jaccard
+ binary = (A > 0.5).float()
+ valid = mask.bool()
+
+ # Extract valid entries: [batch, num_valid]
+ entries = binary[:, valid]
+
+ # Pairwise Jaccard
+ jaccards = []
+ for i in range(batch):
+ for j in range(i + 1, batch):
+ intersection = (entries[i] * entries[j]).sum()
+ union = ((entries[i] + entries[j]) > 0).float().sum()
+ jaccard = intersection / union.clamp(min=1.0)
+ jaccards.append(jaccard)
+
+ if not jaccards:
+ return torch.tensor(0.0)
+
+ jaccards = torch.stack(jaccards)
+ return jaccards.var()