diff options
Diffstat (limited to 'src/model')
| -rw-r--r-- | src/model/predictor.py | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/src/model/predictor.py b/src/model/predictor.py index 0bc0ae3..ed243ad 100644 --- a/src/model/predictor.py +++ b/src/model/predictor.py @@ -83,7 +83,8 @@ class PredictorMLP(nn.Module): See CLAUDE.md §2.3 for architecture. """ - def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256): + def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256, + init_logit: float = 15.0): super().__init__() self.rank = rank self.num_nodes = num_nodes @@ -97,6 +98,12 @@ class PredictorMLP(nn.Module): self.head_U = nn.Linear(hidden_dim, num_nodes * rank) self.head_V = nn.Linear(hidden_dim, num_nodes * rank) + # Learnable bias added to Z logits. Initialized positive so that + # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1) + # at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95. + # Training can decrease this to enable sparsity. + self.logit_bias = nn.Parameter(torch.tensor(init_logit)) + def forward(self, e: torch.Tensor) -> torch.Tensor: """Map embedding to logit matrix. @@ -110,6 +117,7 @@ class PredictorMLP(nn.Module): U = self.head_U(h).view(-1, self.num_nodes, self.rank) # [B, 256, r] V = self.head_V(h).view(-1, self.num_nodes, self.rank) # [B, 256, r] Z = torch.bmm(U, V.transpose(-1, -2)) # [B, 256, 256] + Z = Z + self.logit_bias # shift logits positive → A≈1 at init return Z @@ -197,6 +205,7 @@ class StructurePredictor(nn.Module): rank: int = 32, cascading_gate_k: float = 5.0, qwen_input_prefix: str = "", + init_logit: float = 15.0, num_nodes: int = 256, heads_per_layer: int = 16, device: Optional[torch.device] = None, @@ -215,6 +224,7 @@ class StructurePredictor(nn.Module): input_dim=self.qwen_encoder.embed_dim, hidden_dim=hidden_dim, rank=rank, + init_logit=init_logit, num_nodes=num_nodes, ) |
