diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:00:39 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:00:39 -0600 |
| commit | 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch) | |
| tree | 073534138604c1c49021ca7e334322262129f6ac /src/utils/logging.py | |
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A
- Proportional attribution for post-norm decomposition
- All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/utils/logging.py')
| -rw-r--r-- | src/utils/logging.py | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000..a9b3b10 --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,63 @@ +"""Wandb integration for DAGFormer training. + +Logs all metrics from CLAUDE.md §7. +""" + +from __future__ import annotations + +from typing import Any, Optional + +import wandb + + +def init_wandb( + project: str, + run_name: str, + config: dict[str, Any], +) -> Optional[wandb.sdk.wandb_run.Run]: + """Initialize wandb run. + + Args: + project: wandb project name + run_name: run display name + config: full config dict to log + + Returns: + wandb run object (or None if wandb init fails) + """ + try: + run = wandb.init( + project=project, + name=run_name, + config=config, + ) + return run + except Exception as e: + print(f"WARNING: wandb init failed: {e}. Continuing without wandb.") + return None + + +def log_metrics( + metrics: dict[str, float], + step: int, + run: Optional[wandb.sdk.wandb_run.Run] = None, +) -> None: + """Log metrics to wandb. + + Args: + metrics: dict of metric_name → value + step: global training step + run: wandb run (if None, prints to stdout instead) + """ + if run is not None: + wandb.log(metrics, step=step) + else: + # Fallback: print metrics + parts = [f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}" for k, v in metrics.items()] + print(f"[step {step}] {', '.join(parts)}") + + +def finish_wandb(run: Optional[wandb.sdk.wandb_run.Run] = None) -> None: + """Finish wandb run.""" + if run is not None: + wandb.finish() |
