From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 9 Feb 2026 11:00:39 -0600
Subject: Initial implementation: DAGFormer Phase 1

- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via a 256x256 adjacency matrix A
  - Proportional attribution for post-norm decomposition
  - All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, and checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6
---
 src/utils/__init__.py |  0
 src/utils/logging.py  | 63 +++++++++++++++++++++++++++++++++++++
 src/utils/topology.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 150 insertions(+)
 create mode 100644 src/utils/__init__.py
 create mode 100644 src/utils/logging.py
 create mode 100644 src/utils/topology.py

diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/utils/logging.py b/src/utils/logging.py
new file mode 100644
index 0000000..a9b3b10
--- /dev/null
+++ b/src/utils/logging.py
@@ -0,0 +1,63 @@
+"""Wandb integration for DAGFormer training.
+
+Logs all metrics from CLAUDE.md §7.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import wandb
+
+
+def init_wandb(
+    project: str,
+    run_name: str,
+    config: dict[str, Any],
+) -> Optional[wandb.sdk.wandb_run.Run]:
+    """Initialize a wandb run.
+
+    Args:
+        project: wandb project name
+        run_name: run display name
+        config: full config dict to log
+
+    Returns:
+        wandb run object (or None if wandb init fails)
+    """
+    try:
+        run = wandb.init(
+            project=project,
+            name=run_name,
+            config=config,
+        )
+        return run
+    except Exception as e:
+        print(f"WARNING: wandb init failed: {e}. Continuing without wandb.")
+        return None
+
+
+def log_metrics(
+    metrics: dict[str, float],
+    step: int,
+    run: Optional[wandb.sdk.wandb_run.Run] = None,
+) -> None:
+    """Log metrics to wandb.
+
+    Args:
+        metrics: dict of metric_name → value
+        step: global training step
+        run: wandb run (if None, prints to stdout instead)
+    """
+    if run is not None:
+        # Log through the run handle that was passed in, not the global run
+        run.log(metrics, step=step)
+    else:
+        # Fallback: print metrics to stdout
+        parts = [f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}" for k, v in metrics.items()]
+        print(f"[step {step}] {', '.join(parts)}")
+
+
+def finish_wandb(run: Optional[wandb.sdk.wandb_run.Run] = None) -> None:
+    """Finish the wandb run."""
+    if run is not None:
+        wandb.finish()
diff --git a/src/utils/topology.py b/src/utils/topology.py
new file mode 100644
index 0000000..3a2e0a8
--- /dev/null
+++ b/src/utils/topology.py
@@ -0,0 +1,87 @@
+"""Analysis utilities for the adjacency matrix A, used for logging and monitoring.
+
+Computes topology metrics per CLAUDE.md §7.
+"""
+
+from __future__ import annotations
+
+import torch
+
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+def compute_topology_metrics(A: torch.Tensor, num_heads: int = 16) -> dict[str, float]:
+    """Compute topology metrics from the adjacency matrix A.
+
+    Args:
+        A: [batch, 256, 256] gate matrix
+        num_heads: heads per layer (16 for OLMo2-1B)
+
+    Returns:
+        dict with metrics: mean_A, seq_gate_frac, hyp_gate_frac, jaccard_var
+    """
+    batch = A.shape[0]
+    num_nodes = A.shape[1]
+    mask = create_block_upper_triangular_mask(num_nodes, num_heads).to(A.device)
+
+    # Mean A over valid entries
+    valid_entries = A[:, mask.bool()]  # [batch, 30720]
+    mean_A = valid_entries.mean().item()
+
+    # Classify connections: adjacent (layer diff == 1) vs skip (layer diff > 1)
+    layer_idx = torch.arange(num_nodes, device=A.device) // num_heads
+    layer_diff = layer_idx.unsqueeze(0) - layer_idx.unsqueeze(1)  # [256, 256]
+    # layer_diff[i, j] = layer(j) - layer(i)
+
+    adj_mask = (layer_diff == 1) & mask.bool()  # adjacent-layer connections
+    skip_mask = (layer_diff > 1) & mask.bool()  # skip connections
+
+    # Fraction of gates > 0.5
+    adj_vals = A[:, adj_mask]    # [batch, 3840]
+    skip_vals = A[:, skip_mask]  # [batch, 26880]
+
+    seq_gate_frac = (adj_vals > 0.5).float().mean().item() if adj_vals.numel() > 0 else 0.0
+    hyp_gate_frac = (skip_vals > 0.5).float().mean().item() if skip_vals.numel() > 0 else 0.0
+
+    # Jaccard variance across batch
+    jaccard_var = _jaccard_variance(A, mask).item() if batch > 1 else 0.0
+
+    return {
+        "topology/mean_A": mean_A,
+        "topology/seq_gate_frac": seq_gate_frac,
+        "topology/hyp_gate_frac": hyp_gate_frac,
+        "topology/jaccard_var": jaccard_var,
+    }
+
+
+def _jaccard_variance(A: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Compute the variance of pairwise Jaccard similarity across the batch.
+
+    Measures how context-dependent the topologies are:
+    higher variance = more context-dependent routing.
+    """
+    batch = A.shape[0]
+    if batch < 2:
+        return torch.tensor(0.0)
+
+    # Binarize at 0.5 threshold for Jaccard
+    binary = (A > 0.5).float()
+    valid = mask.bool()
+
+    # Extract valid entries: [batch, num_valid]
+    entries = binary[:, valid]
+
+    # Pairwise Jaccard over all batch pairs
+    jaccards = []
+    for i in range(batch):
+        for j in range(i + 1, batch):
+            intersection = (entries[i] * entries[j]).sum()
+            union = ((entries[i] + entries[j]) > 0).float().sum()
+            jaccard = intersection / union.clamp(min=1.0)
+            jaccards.append(jaccard)
+
+    if not jaccards:
+        return torch.tensor(0.0)
+
+    jaccards = torch.stack(jaccards)
+    return jaccards.var(unbiased=False)  # population variance: sample variance is NaN for a single pair
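
Reviewer note: create_block_upper_triangular_mask is imported from src/model/olmo_graph.py, which is not part of this diff, so the hard-coded shape comments above ([batch, 30720], [batch, 3840], [batch, 26880]) depend on its semantics. Below is a minimal sketch of the behavior the metrics code appears to assume, inferred from how the mask is used here rather than from the actual olmo_graph.py: with 256 nodes = 16 layers x 16 heads, entry (i, j) is valid iff node j sits in a strictly later layer than node i, giving C(16,2) * 16 * 16 = 30720 valid entries, of which 15 * 256 = 3840 are adjacent-layer and 105 * 256 = 26880 are skip connections.

    import torch

    def create_block_upper_triangular_mask(num_nodes: int, num_heads: int) -> torch.Tensor:
        # Inferred semantics: mask[i, j] = 1 iff layer(j) > layer(i), i.e. edges
        # run only from earlier layers to strictly later ones (no within-layer edges)
        layer_idx = torch.arange(num_nodes) // num_heads
        return (layer_idx.unsqueeze(0) > layer_idx.unsqueeze(1)).float()

    mask = create_block_upper_triangular_mask(256, 16)
    assert int(mask.sum()) == 30720  # 120 layer pairs * 256 head pairs each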
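
Also for review, a hypothetical smoke test wiring the two new utilities together; the random A below is only a stand-in for the predictor output described in the commit message:

    import torch

    from src.utils.logging import log_metrics
    from src.utils.topology import compute_topology_metrics

    A = torch.rand(4, 256, 256)  # batch of 4 gate matrices, entries in [0, 1)
    metrics = compute_topology_metrics(A, num_heads=16)
    log_metrics(metrics, step=0)  # run=None, so metrics print to stdout

With uniform random gates, mean_A and both gate fractions should land near 0.5, which gives a quick sanity anchor before wiring in the real predictor.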