Diffstat (limited to 'src/model/predictor.py')
 src/model/predictor.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/src/model/predictor.py b/src/model/predictor.py
index 0bc0ae3..ed243ad 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -83,7 +83,8 @@ class PredictorMLP(nn.Module):
     See CLAUDE.md §2.3 for architecture.
     """
-    def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256):
+    def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256,
+                 init_logit: float = 15.0):
         super().__init__()
         self.rank = rank
         self.num_nodes = num_nodes
@@ -97,6 +98,12 @@ class PredictorMLP(nn.Module):
         self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
         self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
+        # Learnable bias added to the Z logits. Initialized positive so that
+        # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A ≈ 1)
+        # at init. With τ_init = 5.0: σ(15/5) = σ(3) ≈ 0.95.
+        # Training can decrease this to enable sparsity.
+        self.logit_bias = nn.Parameter(torch.tensor(init_logit))
+
     def forward(self, e: torch.Tensor) -> torch.Tensor:
         """Map embedding to logit matrix.
@@ -110,6 +117,7 @@ class PredictorMLP(nn.Module):
         U = self.head_U(h).view(-1, self.num_nodes, self.rank)  # [B, 256, r]
         V = self.head_V(h).view(-1, self.num_nodes, self.rank)  # [B, 256, r]
         Z = torch.bmm(U, V.transpose(-1, -2))                   # [B, 256, 256]
+        Z = Z + self.logit_bias  # shift logits positive → A ≈ 1 at init
         return Z
@@ -197,6 +205,7 @@ class StructurePredictor(nn.Module):
         rank: int = 32,
         cascading_gate_k: float = 5.0,
         qwen_input_prefix: str = "",
+        init_logit: float = 15.0,
         num_nodes: int = 256,
         heads_per_layer: int = 16,
         device: Optional[torch.device] = None,
@@ -215,6 +224,7 @@ class StructurePredictor(nn.Module):
             input_dim=self.qwen_encoder.embed_dim,
             hidden_dim=hidden_dim,
             rank=rank,
+            init_logit=init_logit,
             num_nodes=num_nodes,
         )