summaryrefslogtreecommitdiff
path: root/src/model/predictor.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/model/predictor.py')
-rw-r--r--src/model/predictor.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/model/predictor.py b/src/model/predictor.py
index ed243ad..b5f9674 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -98,6 +98,16 @@ class PredictorMLP(nn.Module):
self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
+ # Initialize head_U and head_V with small weights so UV^T ≈ 0 at init.
+ # Default Kaiming init gives UV^T with std≈√rank≈5.7 which overwhelms
+ # the logit_bias. Small init ensures Z ≈ logit_bias ± small noise.
+ # std=0.01 gives UV^T std≈0.6 (with hidden_dim=1024, rank=32),
+ # small vs logit_bias=15 but enough for input-dependent gradients.
+ nn.init.normal_(self.head_U.weight, std=0.01)
+ nn.init.normal_(self.head_V.weight, std=0.01)
+ nn.init.zeros_(self.head_U.bias)
+ nn.init.zeros_(self.head_V.bias)
+
# Learnable bias added to Z logits. Initialized positive so that
# σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1)
# at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95.