summaryrefslogtreecommitdiff
path: root/src/utils/logging.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/utils/logging.py')
-rw-r--r--src/utils/logging.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/src/utils/logging.py b/src/utils/logging.py
new file mode 100644
index 0000000..a9b3b10
--- /dev/null
+++ b/src/utils/logging.py
@@ -0,0 +1,63 @@
+"""Wandb integration for DAGFormer training.
+
+Logs all metrics from CLAUDE.md §7.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import wandb
+
+
+def init_wandb(
+ project: str,
+ run_name: str,
+ config: dict[str, Any],
+) -> Optional[wandb.sdk.wandb_run.Run]:
+ """Initialize wandb run.
+
+ Args:
+ project: wandb project name
+ run_name: run display name
+ config: full config dict to log
+
+ Returns:
+ wandb run object (or None if wandb init fails)
+ """
+ try:
+ run = wandb.init(
+ project=project,
+ name=run_name,
+ config=config,
+ )
+ return run
+ except Exception as e:
+ print(f"WARNING: wandb init failed: {e}. Continuing without wandb.")
+ return None
+
+
+def log_metrics(
+ metrics: dict[str, float],
+ step: int,
+ run: Optional[wandb.sdk.wandb_run.Run] = None,
+) -> None:
+ """Log metrics to wandb.
+
+ Args:
+ metrics: dict of metric_name → value
+ step: global training step
+ run: wandb run (if None, prints to stdout instead)
+ """
+ if run is not None:
+ wandb.log(metrics, step=step)
+ else:
+ # Fallback: print metrics
+ parts = [f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}" for k, v in metrics.items()]
+ print(f"[step {step}] {', '.join(parts)}")
+
+
+def finish_wandb(run: Optional[wandb.sdk.wandb_run.Run] = None) -> None:
+ """Finish wandb run."""
+ if run is not None:
+ wandb.finish()