diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:23:15 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:23:15 -0600 |
| commit | 93d77b197d457b1fdfa7341ecd59fc460b20d6b1 (patch) | |
| tree | 0becc0a9c122ddd80a2f88431546a59b3915e0e3 /src | |
| parent | 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (diff) | |
Fix init state: add logit_bias so A≈1 at init (dense connectivity)
- Add learnable logit_bias=15.0 to PredictorMLP, so σ(15/τ_init) ≈ 0.95
at init, reproducing dense connectivity instead of random A≈0.25
- Fix dtype mismatch: cast A to model dtype (bfloat16) in DAGFormerOLMo.forward
- Fix YAML lr parsing: add type coercion in TrainConfig.from_yaml
- Fix device mismatch: call self.to(device) in StructurePredictor.__init__
- Add python -u for unbuffered SLURM output, TOKENIZERS_PARALLELISM=false
- Delete stale eval_cache.pt (built with buggy MLP input code)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/model/predictor.py | 12 | ||||
| -rw-r--r-- | src/training/trainer.py | 2 |
2 files changed, 13 insertions, 1 deletions
diff --git a/src/model/predictor.py b/src/model/predictor.py index 0bc0ae3..ed243ad 100644 --- a/src/model/predictor.py +++ b/src/model/predictor.py @@ -83,7 +83,8 @@ class PredictorMLP(nn.Module): See CLAUDE.md §2.3 for architecture. """ - def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256): + def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256, + init_logit: float = 15.0): super().__init__() self.rank = rank self.num_nodes = num_nodes @@ -97,6 +98,12 @@ class PredictorMLP(nn.Module): self.head_U = nn.Linear(hidden_dim, num_nodes * rank) self.head_V = nn.Linear(hidden_dim, num_nodes * rank) + # Learnable bias added to Z logits. Initialized positive so that + # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1) + # at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95. + # Training can decrease this to enable sparsity. + self.logit_bias = nn.Parameter(torch.tensor(init_logit)) + def forward(self, e: torch.Tensor) -> torch.Tensor: """Map embedding to logit matrix. @@ -110,6 +117,7 @@ class PredictorMLP(nn.Module): U = self.head_U(h).view(-1, self.num_nodes, self.rank) # [B, 256, r] V = self.head_V(h).view(-1, self.num_nodes, self.rank) # [B, 256, r] Z = torch.bmm(U, V.transpose(-1, -2)) # [B, 256, 256] + Z = Z + self.logit_bias # shift logits positive → A≈1 at init return Z @@ -197,6 +205,7 @@ class StructurePredictor(nn.Module): rank: int = 32, cascading_gate_k: float = 5.0, qwen_input_prefix: str = "", + init_logit: float = 15.0, num_nodes: int = 256, heads_per_layer: int = 16, device: Optional[torch.device] = None, @@ -215,6 +224,7 @@ class StructurePredictor(nn.Module): input_dim=self.qwen_encoder.embed_dim, hidden_dim=hidden_dim, rank=rank, + init_logit=init_logit, num_nodes=num_nodes, ) diff --git a/src/training/trainer.py b/src/training/trainer.py index 6be949e..de0eb96 100644 --- a/src/training/trainer.py +++ b/src/training/trainer.py @@ -44,6 +44,7 @@ class TrainConfig: cascading_gate_k: float = 5.0 input_norm: str = "none" qwen_input_prefix: str = "" + init_logit: float = 15.0 # bias on Z logits so A≈1 at init (dense connectivity) # Data dataset: str = "allenai/dolma" @@ -185,6 +186,7 @@ class Trainer: rank=config.predictor_rank, cascading_gate_k=config.cascading_gate_k, qwen_input_prefix=config.qwen_input_prefix, + init_logit=config.init_logit, device=self.device, ) |
