summaryrefslogtreecommitdiff
path: root/src/training
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:23:15 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:23:15 -0600
commit93d77b197d457b1fdfa7341ecd59fc460b20d6b1 (patch)
tree0becc0a9c122ddd80a2f88431546a59b3915e0e3 /src/training
parent13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (diff)
Fix init state: add logit_bias so A≈1 at init (dense connectivity)
- Add learnable logit_bias=15.0 to PredictorMLP, so σ(15/τ_init) ≈ 0.95 at init, reproducing dense connectivity instead of random A≈0.25 - Fix dtype mismatch: cast A to model dtype (bfloat16) in DAGFormerOLMo.forward - Fix YAML lr parsing: add type coercion in TrainConfig.from_yaml - Fix device mismatch: call self.to(device) in StructurePredictor.__init__ - Add python -u for unbuffered SLURM output, TOKENIZERS_PARALLELISM=false - Delete stale eval_cache.pt (built with buggy MLP input code) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src/training')
-rw-r--r--src/training/trainer.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/src/training/trainer.py b/src/training/trainer.py
index 6be949e..de0eb96 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -44,6 +44,7 @@ class TrainConfig:
cascading_gate_k: float = 5.0
input_norm: str = "none"
qwen_input_prefix: str = ""
+ init_logit: float = 15.0 # bias on Z logits so A≈1 at init (dense connectivity)
# Data
dataset: str = "allenai/dolma"
@@ -185,6 +186,7 @@ class Trainer:
rank=config.predictor_rank,
cascading_gate_k=config.cascading_gate_k,
qwen_input_prefix=config.qwen_input_prefix,
+ init_logit=config.init_logit,
device=self.device,
)