From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 9 Feb 2026 11:00:39 -0600 Subject: Initial implementation: DAGFormer Phase 1 - olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 --- src/utils/logging.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/utils/logging.py (limited to 'src/utils/logging.py') diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000..a9b3b10 --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,63 @@ +"""Wandb integration for DAGFormer training. + +Logs all metrics from CLAUDE.md §7. +""" + +from __future__ import annotations + +from typing import Any, Optional + +import wandb + + +def init_wandb( + project: str, + run_name: str, + config: dict[str, Any], +) -> Optional[wandb.sdk.wandb_run.Run]: + """Initialize wandb run. + + Args: + project: wandb project name + run_name: run display name + config: full config dict to log + + Returns: + wandb run object (or None if wandb init fails) + """ + try: + run = wandb.init( + project=project, + name=run_name, + config=config, + ) + return run + except Exception as e: + print(f"WARNING: wandb init failed: {e}. Continuing without wandb.") + return None + + +def log_metrics( + metrics: dict[str, float], + step: int, + run: Optional[wandb.sdk.wandb_run.Run] = None, +) -> None: + """Log metrics to wandb. + + Args: + metrics: dict of metric_name → value + step: global training step + run: wandb run (if None, prints to stdout instead) + """ + if run is not None: + wandb.log(metrics, step=step) + else: + # Fallback: print metrics + parts = [f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}" for k, v in metrics.items()] + print(f"[step {step}] {', '.join(parts)}") + + +def finish_wandb(run: Optional[wandb.sdk.wandb_run.Run] = None) -> None: + """Finish wandb run.""" + if run is not None: + wandb.finish() -- cgit v1.2.3