From 93d77b197d457b1fdfa7341ecd59fc460b20d6b1 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 9 Feb 2026 11:23:15 -0600 Subject: =?UTF-8?q?Fix=20init=20state:=20add=20logit=5Fbias=20so=20A?= =?UTF-8?q?=E2=89=881=20at=20init=20(dense=20connectivity)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add learnable logit_bias=15.0 to PredictorMLP, so σ(15/τ_init) ≈ 0.95 at init, reproducing dense connectivity instead of random A≈0.25 - Fix dtype mismatch: cast A to model dtype (bfloat16) in DAGFormerOLMo.forward - Fix YAML lr parsing: add type coercion in TrainConfig.from_yaml - Fix device mismatch: call self.to(device) in StructurePredictor.__init__ - Add python -u for unbuffered SLURM output, TOKENIZERS_PARALLELISM=false - Delete stale eval_cache.pt (built with buggy MLP input code) Co-Authored-By: Claude Opus 4.6 --- src/training/trainer.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'src/training/trainer.py') diff --git a/src/training/trainer.py b/src/training/trainer.py index 6be949e..de0eb96 100644 --- a/src/training/trainer.py +++ b/src/training/trainer.py @@ -44,6 +44,7 @@ class TrainConfig: cascading_gate_k: float = 5.0 input_norm: str = "none" qwen_input_prefix: str = "" + init_logit: float = 15.0 # bias on Z logits so A≈1 at init (dense connectivity) # Data dataset: str = "allenai/dolma" @@ -185,6 +186,7 @@ class Trainer: rank=config.predictor_rank, cascading_gate_k=config.cascading_gate_k, qwen_input_prefix=config.qwen_input_prefix, + init_logit=config.init_logit, device=self.device, ) -- cgit v1.2.3