-rw-r--r--  scripts/slurm_train.sh  | 13
-rw-r--r--  src/model/predictor.py  | 12
-rw-r--r--  src/training/trainer.py |  2

3 files changed, 15 insertions, 12 deletions
diff --git a/scripts/slurm_train.sh b/scripts/slurm_train.sh
index 6b283ea..e1df687 100644
--- a/scripts/slurm_train.sh
+++ b/scripts/slurm_train.sh
@@ -1,18 +1,9 @@
 #!/bin/bash
-#SBATCH --partition=gpuA40x4
-#SBATCH --account=bfqt-delta-gpu
-#SBATCH --nodes=1
-#SBATCH --gpus-per-node=1
-#SBATCH --time=02:00:00
-#SBATCH --mem=64g
-#SBATCH --job-name=dagformer-sanity
-#SBATCH --output=logs/sanity_%j.out
-#SBATCH --error=logs/sanity_%j.err
-
 export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache
 export TRANSFORMERS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/transformers
 export HF_HUB_CACHE=/projects/bfqt/users/yurenh2/hf_cache/hub
 export HF_DATASETS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/datasets
+export TOKENIZERS_PARALLELISM=false
 export PYTHONPATH=/projects/bfqt/users/yurenh2/ml-projects/DAGFormer:$PYTHONPATH
 export PATH=$HOME/.local/bin:$PATH
 
@@ -27,4 +18,4 @@
 echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader)"
 echo ""
 echo "=== Starting training ==="
-python3 scripts/train.py --config configs/sanity_check.yaml
+python3 -u scripts/train.py --config configs/sanity_check.yaml
diff --git a/src/model/predictor.py b/src/model/predictor.py
index 0bc0ae3..ed243ad 100644
--- a/src/model/predictor.py
+++ b/src/model/predictor.py
@@ -83,7 +83,8 @@ class PredictorMLP(nn.Module):
     See CLAUDE.md §2.3 for architecture.
     """
 
-    def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256):
+    def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256,
+                 init_logit: float = 15.0):
         super().__init__()
         self.rank = rank
         self.num_nodes = num_nodes
@@ -97,6 +98,12 @@ class PredictorMLP(nn.Module):
         self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
         self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
 
+        # Learnable bias added to Z logits. Initialized positive so that
+        # σ(init_logit / τ_init) ≈ 1, reproducing dense connectivity (A≈1)
+        # at init. With τ_init=5.0: σ(15/5) = σ(3) ≈ 0.95.
+        # Training can decrease this to enable sparsity.
+        self.logit_bias = nn.Parameter(torch.tensor(init_logit))
+
     def forward(self, e: torch.Tensor) -> torch.Tensor:
         """Map embedding to logit matrix.
 
@@ -110,6 +117,7 @@ class PredictorMLP(nn.Module):
         U = self.head_U(h).view(-1, self.num_nodes, self.rank)  # [B, 256, r]
         V = self.head_V(h).view(-1, self.num_nodes, self.rank)  # [B, 256, r]
         Z = torch.bmm(U, V.transpose(-1, -2))  # [B, 256, 256]
+        Z = Z + self.logit_bias  # shift logits positive → A≈1 at init
         return Z
 
 
@@ -197,6 +205,7 @@ class StructurePredictor(nn.Module):
         rank: int = 32,
         cascading_gate_k: float = 5.0,
         qwen_input_prefix: str = "",
+        init_logit: float = 15.0,
         num_nodes: int = 256,
         heads_per_layer: int = 16,
         device: Optional[torch.device] = None,
@@ -215,6 +224,7 @@ class StructurePredictor(nn.Module):
             input_dim=self.qwen_encoder.embed_dim,
             hidden_dim=hidden_dim,
             rank=rank,
+            init_logit=init_logit,
             num_nodes=num_nodes,
         )
 
diff --git a/src/training/trainer.py b/src/training/trainer.py
index 6be949e..de0eb96 100644
--- a/src/training/trainer.py
+++ b/src/training/trainer.py
@@ -44,6 +44,7 @@ class TrainConfig:
     cascading_gate_k: float = 5.0
     input_norm: str = "none"
     qwen_input_prefix: str = ""
+    init_logit: float = 15.0  # bias on Z logits so A≈1 at init (dense connectivity)
 
     # Data
     dataset: str = "allenai/dolma"
@@ -185,6 +186,7 @@ class Trainer:
             rank=config.predictor_rank,
             cascading_gate_k=config.cascading_gate_k,
             qwen_input_prefix=config.qwen_input_prefix,
+            init_logit=config.init_logit,
             device=self.device,
         )
 
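A minimal sketch of the dense-at-init behavior the new comments describe. It assumes, as the comments state, that the adjacency is obtained downstream of PredictorMLP as A = σ(Z / τ) with τ_init = 5.0; the batch size and the 0.02 weight scale below are illustrative assumptions, not values taken from the repo.

import torch

# At init, the U·Vᵀ logits are near zero, so Z ≈ logit_bias = init_logit.
init_logit, tau_init = 15.0, 5.0
B, N, r = 2, 256, 32                       # batch, num_nodes, rank (diff defaults)
U = 0.02 * torch.randn(B, N, r)            # small random head outputs, as at init
V = 0.02 * torch.randn(B, N, r)
Z = torch.bmm(U, V.transpose(-1, -2)) + init_logit   # [B, 256, 256], shifted positive
A = torch.sigmoid(Z / tau_init)
print(A.mean().item())                     # ≈ σ(3) ≈ 0.95 → near-dense connectivity

Because logit_bias is a single scalar nn.Parameter, the optimizer can lower every logit uniformly during training, so sparsity can emerge globally rather than being forced edge by edge at init.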

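A hypothetical end-to-end override, assuming TrainConfig is a dataclass whose remaining fields all have defaults (it is normally populated from the YAML config); the value 10.0 is illustrative only.

from src.training.trainer import TrainConfig

config = TrainConfig(init_logit=10.0)      # σ(10/5) = σ(2) ≈ 0.88 at init
# Trainer(config) forwards init_logit → StructurePredictor → PredictorMLP.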