-rw-r--r--  configs/a12_init_logit_3.yaml  53
-rw-r--r--  configs/a13_init_logit_0.yaml  53
-rw-r--r--  configs/a14_init_logit_1.yaml  53
-rw-r--r--  experiments/results.md         89
-rw-r--r--  scripts/slurm_a12.sh           22
-rw-r--r--  scripts/slurm_a13.sh           22
-rw-r--r--  scripts/slurm_a14.sh           22
-rw-r--r--  src/data/dolma.py              93
8 files changed, 366 insertions(+), 41 deletions(-)
diff --git a/configs/a12_init_logit_3.yaml b/configs/a12_init_logit_3.yaml
new file mode 100644
index 0000000..bf7f6db
--- /dev/null
+++ b/configs/a12_init_logit_3.yaml
@@ -0,0 +1,53 @@
+# A12 — Init logit = 3.0 (moderate, sigmoid not saturated)
+# Purpose: A starts at σ(1.5)≈0.82; the gate gradient is ~250× larger than with init_logit=15.
+# Does the predictor learn useful topology when the sigmoid is not saturated?
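+# (dA/dZ = A(1-A)/tau: ~0.074 at init_logit=3 vs ~0.0003 at init_logit=15, for tau=2)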
+# Run: python scripts/train.py --config configs/a12_init_logit_3.yaml
+
+# Model
+olmo_model_id: "allenai/OLMo-2-0425-1B"
+qwen_model_id: "Qwen/Qwen3-Embedding-0.6B"
+
+# Predictor
+predictor_hidden_dim: 1024
+predictor_rank: 32
+cascading_gate_k: 5.0
+input_norm: "none"
+init_logit: 3.0
+
+# Data
+dataset: "allenai/dolma"
+dataset_name: "v1_7"
+seq_len: 1024
+batch_size: 4
+micro_batch_size: 2
+qwen_input_prefix: ""
+
+# Eval
+eval_skip: 10000
+eval_size: 50
+
+# Training — ~50M tokens = 12500 steps
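+# (batch_size 4 × seq_len 1024 = 4096 tokens/step; 4096 × 12500 ≈ 51.2M tokens)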
+total_steps: 12500
+lr: 3e-4
+weight_decay: 0.01
+optimizer: "adamw"
+
+# Schedules — constant tau=2 (same as S2), no sparsity
+tau_init: 2.0
+tau_final: 2.0
+tau_schedule: "cosine"
+lambda_max: 0.0
+lambda_warmup_frac: 0.2
+
+# Logging
+wandb_project: "dagformer"
+wandb_run_name: "a12-init-logit-3"
+log_every: 10
+eval_every: 500
+
+# Checkpointing
+save_every: 2500
+save_dir: "checkpoints/a12/"
+
+# Hardware
+num_gpus: 1
diff --git a/configs/a13_init_logit_0.yaml b/configs/a13_init_logit_0.yaml
new file mode 100644
index 0000000..5bd356d
--- /dev/null
+++ b/configs/a13_init_logit_0.yaml
@@ -0,0 +1,53 @@
+# A13 — Init logit = 0.0 (maximum gradient, A starts at 0.5)
+# Purpose: σ(0)=0.5 gives the maximum gradient signal. Initial NLL will be poor,
+# but the learning space is large. Tests whether the predictor can learn from scratch.
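+# (dA/dZ = A(1-A)/tau = 0.25/2 = 0.125 at tau=2: the sigmoid's maximum slope)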
+# Run: python scripts/train.py --config configs/a13_init_logit_0.yaml
+
+# Model
+olmo_model_id: "allenai/OLMo-2-0425-1B"
+qwen_model_id: "Qwen/Qwen3-Embedding-0.6B"
+
+# Predictor
+predictor_hidden_dim: 1024
+predictor_rank: 32
+cascading_gate_k: 5.0
+input_norm: "none"
+init_logit: 0.0
+
+# Data
+dataset: "allenai/dolma"
+dataset_name: "v1_7"
+seq_len: 1024
+batch_size: 4
+micro_batch_size: 2
+qwen_input_prefix: ""
+
+# Eval
+eval_skip: 10000
+eval_size: 50
+
+# Training — ~50M tokens = 12500 steps
+total_steps: 12500
+lr: 3e-4
+weight_decay: 0.01
+optimizer: "adamw"
+
+# Schedules — constant tau=2 (same as S2), no sparsity
+tau_init: 2.0
+tau_final: 2.0
+tau_schedule: "cosine"
+lambda_max: 0.0
+lambda_warmup_frac: 0.2
+
+# Logging
+wandb_project: "dagformer"
+wandb_run_name: "a13-init-logit-0"
+log_every: 10
+eval_every: 500
+
+# Checkpointing
+save_every: 2500
+save_dir: "checkpoints/a13/"
+
+# Hardware
+num_gpus: 1
diff --git a/configs/a14_init_logit_1.yaml b/configs/a14_init_logit_1.yaml
new file mode 100644
index 0000000..f3278cf
--- /dev/null
+++ b/configs/a14_init_logit_1.yaml
@@ -0,0 +1,53 @@
+# A14 — Init logit = 1.0 (compromise: A≈0.62, strong gradient, mild dense bias)
+# Purpose: σ(0.5)≈0.62, a strong gradient with a slight bias toward connectivity.
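+# (dA/dZ = A(1-A)/tau ≈ 0.62*0.38/2 ≈ 0.118 at tau=2)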
+# Middle ground between A12 (too high?) and A13 (no bias at all).
+# Run: python scripts/train.py --config configs/a14_init_logit_1.yaml
+
+# Model
+olmo_model_id: "allenai/OLMo-2-0425-1B"
+qwen_model_id: "Qwen/Qwen3-Embedding-0.6B"
+
+# Predictor
+predictor_hidden_dim: 1024
+predictor_rank: 32
+cascading_gate_k: 5.0
+input_norm: "none"
+init_logit: 1.0
+
+# Data
+dataset: "allenai/dolma"
+dataset_name: "v1_7"
+seq_len: 1024
+batch_size: 4
+micro_batch_size: 2
+qwen_input_prefix: ""
+
+# Eval
+eval_skip: 10000
+eval_size: 50
+
+# Training — ~50M tokens = 12500 steps
+total_steps: 12500
+lr: 3e-4
+weight_decay: 0.01
+optimizer: "adamw"
+
+# Schedules — constant tau=2 (same as S2), no sparsity
+tau_init: 2.0
+tau_final: 2.0
+tau_schedule: "cosine"
+lambda_max: 0.0
+lambda_warmup_frac: 0.2
+
+# Logging
+wandb_project: "dagformer"
+wandb_run_name: "a14-init-logit-1"
+log_every: 10
+eval_every: 500
+
+# Checkpointing
+save_every: 2500
+save_dir: "checkpoints/a14/"
+
+# Hardware
+num_gpus: 1
diff --git a/experiments/results.md b/experiments/results.md
index d362f7e..b6f33ea 100644
--- a/experiments/results.md
+++ b/experiments/results.md
@@ -84,13 +84,21 @@ Observations (attempt 1):
- No checkpoint saved (save_every=2500, crashed at 1860)
- Crashed due to Dolma streaming HTTP error, not code bug
-**Attempt 2** — Job 15798568 (fresh start, no checkpoint from attempt 1)
+**Attempt 2** — Job 15798568, crashed at step ~5190 (same Dolma HTTP error)
-| Metric | Value |
-|--------|-------|
-| eval/nll_soft | |
-| eval/nll_hard | |
-| topology/mean_A | |
+| Step | nll_soft | nll_hard | nll_baseline | mean_A |
+|------|----------|----------|--------------|--------|
+| 500 | 2.4578 | 2.4597 | 2.4569 | 0.992 |
+| 1000 | 2.4566 | 2.4591 | 2.4569 | 0.989 |
+| 2000 | 2.4531 | 2.4537 | 2.4569 | 0.991 |
+| 3000 | 2.4484 | 2.4491 | 2.4569 | 0.992 |
+| 3500 | **2.4475** | 2.4487 | 2.4569 | 0.993 |
+| 4000 | 2.4569 | 2.4569 | 2.4569 | 0.999 |
+| 5000 | 2.4550 | 2.4578 | 2.4569 | 0.980 |
+
+**S2 conclusion**: Predictor stuck near init (mean_A≈0.99). Sigmoid saturation confirmed —
+init_logit=15 + τ=2 gives ∂A/∂Z≈0.0003, insufficient gradient. Moving to A12-A14 (lower init_logit).
+Status: **DONE** (no need to re-run, hypothesis confirmed across 2 attempts, ~7K total steps)
---
@@ -165,6 +173,75 @@ Observations (attempt 1):
| P1 | k=5 fixed | | | | (reference) |
| A11 | k=5 learnable | | | | |
+### A12–A14: Init logit ablation (sigmoid saturation fix)
+
+**Problem diagnosis (from S1 & S2):**
+
+Neither the S1 (τ=5) nor the S2 (τ=2) predictor learned any meaningful deviation in topology (eval NLL ≈ baseline, mean_A ≈ 0.99).
+Initial hypothesis: sigmoid saturation causes vanishing gradients (init_logit=15 gives ∂A/∂Z ≈ 0.0003 at τ=2; verified numerically below).
+
+| ID | init_logit | Init A (τ=2) | ∂A/∂Z (τ=2) | Tokens | Purpose |
+|----|-----------|--------------|-------------|--------|---------|
+| A12 | 3.0 | σ(1.5) ≈ 0.82 | 0.074 | 50M | Moderate: A starts high but not saturated. |
+| A13 | 0.0 | σ(0) = 0.50 | 0.125 | 50M | Maximum gradient signal. |
+| A14 | 1.0 | σ(0.5) ≈ 0.62 | 0.118 | 50M | Compromise. |
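+
+These values follow directly from A = σ(Z/τ), which gives ∂A/∂Z = A(1−A)/τ. A minimal numeric
+check (a sketch that assumes the gate is a plain sigmoid over Z/τ, as the ∂A/∂Z column implies):
+
+```python
+import math
+
+tau = 2.0
+for z in (15.0, 3.0, 1.0, 0.0):
+    a = 1.0 / (1.0 + math.exp(-z / tau))  # A = sigmoid(Z / tau)
+    grad = a * (1.0 - a) / tau            # dA/dZ
+    print(f"init_logit={z:4.1f}  A={a:.3f}  dA/dZ={grad:.4f}")
+```
+
+At init_logit=15 this prints dA/dZ ≈ 0.0003, matching the saturation diagnosis from S2.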
+
+**A12** — Job 15803742 (**DONE**, 12500/12500 steps)
+
+| Step | nll_soft | nll_hard | nll_baseline | mean_A |
+|------|----------|----------|--------------|--------|
+| 500 | 2.4566 | 2.4566 | 2.4569 | 0.985 |
+| 1500 | 2.8781 | 2.8781 | 2.4569 | 0.884 |
+| 3000 | 2.6844 | 2.6844 | 2.4569 | 0.903 |
+| 5000 | 2.7062 | 2.7062 | 2.4569 | 0.897 |
+| 8500 | 2.7556 | 2.7556 | 2.4569 | 0.894 |
+| **12500** | **2.7563** | **2.7563** | **2.4569** | **0.894** |
+
+**A13** — Job 15803743 (**DONE**, 12500/12500 steps)
+
+| Step | nll_soft | nll_hard | nll_baseline | mean_A |
+|------|----------|----------|--------------|--------|
+| 500 | 2.9356 | 2.9362 | 2.4569 | 0.834 |
+| 1500 | 3.6700 | 3.6694 | 2.4569 | 0.722 |
+| 3000 | 3.4731 | 3.4731 | 2.4569 | 0.694 |
+| 5000 | 3.6150 | 3.6150 | 2.4569 | 0.678 |
+| 8500 | 3.5187 | 3.5187 | 2.4569 | 0.676 |
+| **12500** | **3.5100** | **3.5100** | **2.4569** | **0.677** |
+
+**A14** — Job 15803744 (**DONE**, 12500/12500 steps)
+
+| Step | nll_soft | nll_hard | nll_baseline | mean_A |
+|------|----------|----------|--------------|--------|
+| 500 | 2.4553 | 2.4553 | 2.4569 | 0.992 |
+| 1500 | 3.7050 | 3.7050 | 2.4569 | 0.734 |
+| 3000 | 3.8919 | 3.8925 | 2.4569 | 0.721 |
+| 5000 | 3.4019 | 3.4019 | 2.4569 | 0.726 |
+| 8500 | 3.2725 | 3.2725 | 2.4569 | 0.733 |
+| **12500** | **3.2550** | **3.2550** | **2.4569** | **0.734** |
+
+**A12–A14 combined conclusions:**
+
+| ID | init_logit | Final nll_soft | vs baseline | Final mean_A | Convergence |
+|----|-----------|---------------|-------------|-------------|------|
+| S2 | 15.0 | 2.4569 | +0.00 | 0.99 | Stuck at init (sigmoid saturation) |
+| A12 | 3.0 | 2.7563 | **+0.30** | 0.894 | Fully converged (flat over the final 4K steps) |
+| A14 | 1.0 | 3.2550 | **+0.80** | 0.734 | Fully converged |
+| A13 | 0.0 | 3.5100 | **+1.05** | 0.677 | Fully converged |
+
+| Finding | Details |
+|------|------|
+| Gradient-saturation hypothesis | **Partially correct**: with a lower init_logit the gates do move, so gradient is flowing |
+| NLL trend | **All worse**: the lower the init_logit, the farther from dense and the worse the NLL |
+| Context dependence | **None**: jaccard_var=NaN; the predictor outputs the same static A for every context (see the sketch below) |
+| Convergence behavior | All three stall completely after ~8K steps, learning a fixed, context-independent topology |
+| nll_soft vs nll_hard | Identical; at τ=2, soft ≈ hard |
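+
+On the jaccard_var entry above: a hypothetical sketch of how such a context-dependence measure
+could be computed (the repo's actual implementation is not shown in this diff):
+
+```python
+import torch
+
+def jaccard_var(hard_masks: torch.Tensor) -> torch.Tensor:
+    """Variance of pairwise Jaccard similarity across per-context hard masks.
+
+    hard_masks: (n_contexts, n_edges) bool, e.g. (A > 0.5) for each eval context.
+    Degenerate cases yield NaN: 0/0 for two empty masks, or unbiased variance
+    over a single pair.
+    """
+    n = hard_masks.shape[0]
+    sims = []
+    for i in range(n):
+        for j in range(i + 1, n):
+            inter = (hard_masks[i] & hard_masks[j]).sum().float()
+            union = (hard_masks[i] | hard_masks[j]).sum().float()
+            sims.append(inter / union)  # 0/0 -> NaN if both masks are empty
+    return torch.stack(sims).var()
+```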
+
+**Root conclusion**: OLMo's weights were pretrained under a dense topology. **With the model frozen, A=1 is the global optimum.**
+Any topology that deviates from dense deletes information the model expects, so NLL necessarily degrades. This is not a gradient
+problem, an init problem, or a data-volume problem: **the loss landscape itself does not allow a frozen model to benefit from a sparse topology**.
+
+The limitation of Phase 1 (frozen OLMo) is confirmed. Phase 2 (unfreezing OLMo) is needed so the model can adapt to the new topology.
+
---
## Analysis Experiments
diff --git a/scripts/slurm_a12.sh b/scripts/slurm_a12.sh
new file mode 100644
index 0000000..941caae
--- /dev/null
+++ b/scripts/slurm_a12.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+#SBATCH --signal=SIGUSR1@120
+export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache
+export TRANSFORMERS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/transformers
+export HF_HUB_CACHE=/projects/bfqt/users/yurenh2/hf_cache/hub
+export HF_DATASETS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/datasets
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=/projects/bfqt/users/yurenh2/ml-projects/DAGFormer:$PYTHONPATH
+export PATH=$HOME/.local/bin:$PATH
+
+cd /projects/bfqt/users/yurenh2/ml-projects/DAGFormer
+mkdir -p logs checkpoints/a12
+
+echo "=== Job Info ==="
+echo "Job ID: $SLURM_JOB_ID"
+echo "Node: $SLURM_NODELIST"
+echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader)"
+echo ""
+
+echo "=== Starting A12: init_logit=3.0 ==="
+python3 -u scripts/train.py --config configs/a12_init_logit_3.yaml
diff --git a/scripts/slurm_a13.sh b/scripts/slurm_a13.sh
new file mode 100644
index 0000000..8b4e5c0
--- /dev/null
+++ b/scripts/slurm_a13.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+#SBATCH --signal=SIGUSR1@120
+export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache
+export TRANSFORMERS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/transformers
+export HF_HUB_CACHE=/projects/bfqt/users/yurenh2/hf_cache/hub
+export HF_DATASETS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/datasets
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=/projects/bfqt/users/yurenh2/ml-projects/DAGFormer:$PYTHONPATH
+export PATH=$HOME/.local/bin:$PATH
+
+cd /projects/bfqt/users/yurenh2/ml-projects/DAGFormer
+mkdir -p logs checkpoints/a13
+
+echo "=== Job Info ==="
+echo "Job ID: $SLURM_JOB_ID"
+echo "Node: $SLURM_NODELIST"
+echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader)"
+echo ""
+
+echo "=== Starting A13: init_logit=0.0 ==="
+python3 -u scripts/train.py --config configs/a13_init_logit_0.yaml
diff --git a/scripts/slurm_a14.sh b/scripts/slurm_a14.sh
new file mode 100644
index 0000000..a5daabe
--- /dev/null
+++ b/scripts/slurm_a14.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+#SBATCH --signal=SIGUSR1@120
+export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache
+export TRANSFORMERS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/transformers
+export HF_HUB_CACHE=/projects/bfqt/users/yurenh2/hf_cache/hub
+export HF_DATASETS_CACHE=/projects/bfqt/users/yurenh2/hf_cache/datasets
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=/projects/bfqt/users/yurenh2/ml-projects/DAGFormer:$PYTHONPATH
+export PATH=$HOME/.local/bin:$PATH
+
+cd /projects/bfqt/users/yurenh2/ml-projects/DAGFormer
+mkdir -p logs checkpoints/a14
+
+echo "=== Job Info ==="
+echo "Job ID: $SLURM_JOB_ID"
+echo "Node: $SLURM_NODELIST"
+echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader)"
+echo ""
+
+echo "=== Starting A14: init_logit=1.0 ==="
+python3 -u scripts/train.py --config configs/a14_init_logit_1.yaml
diff --git a/src/data/dolma.py b/src/data/dolma.py
index 4e2baaf..ed8c13b 100644
--- a/src/data/dolma.py
+++ b/src/data/dolma.py
@@ -7,6 +7,7 @@ See CLAUDE.md §3.1.1 for sequence packing specification.
from __future__ import annotations
import os
+import time
from typing import Iterator, Optional
import torch
@@ -14,6 +15,9 @@ from datasets import load_dataset
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer
+MAX_RETRIES = 10
+RETRY_WAIT = 30 # seconds
+
class DolmaPackedDataset(IterableDataset):
"""Streaming Dolma dataset with sequence packing.
@@ -49,8 +53,8 @@ class DolmaPackedDataset(IterableDataset):
self.eos_id = olmo_tokenizer.eos_token_id
assert self.eos_id is not None, "OLMo tokenizer must have an EOS token"
- def __iter__(self) -> Iterator[dict]:
- """Yield packed sequences from Dolma stream."""
+ def _load_stream(self):
+ """Load Dolma streaming dataset with fallback."""
try:
dataset = load_dataset(
self.dataset_name,
@@ -60,7 +64,6 @@ class DolmaPackedDataset(IterableDataset):
trust_remote_code=True,
)
except Exception:
- # Fallback if specific version not available
dataset = load_dataset(
self.dataset_name,
split="train",
@@ -68,43 +71,63 @@ class DolmaPackedDataset(IterableDataset):
trust_remote_code=True,
)
- # Shard for multi-GPU
if self.world_size > 1:
dataset = dataset.shard(num_shards=self.world_size, index=self.rank)
+ return dataset
+
+ def __iter__(self) -> Iterator[dict]:
+ """Yield packed sequences from Dolma stream with retry on HTTP errors."""
buffer: list[int] = []
sample_count = 0
-
- for doc in dataset:
- if self.max_samples is not None and sample_count >= self.max_samples:
- break
-
- text = doc.get("text", "")
- if not text.strip():
- continue
-
- tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
- buffer.extend(tokens)
- buffer.append(self.eos_id)
-
- # Yield packed sequences as buffer fills
- while len(buffer) >= self.seq_len + 1:
- chunk = buffer[:self.seq_len + 1]
- buffer = buffer[self.seq_len + 1:]
-
- olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
- olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
- raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
-
- yield {
- "olmo_ids": olmo_ids,
- "olmo_labels": olmo_labels,
- "raw_text": raw_text,
- }
- sample_count += 1
-
- if self.max_samples is not None and sample_count >= self.max_samples:
- break
+ retries = 0
+
+ while retries <= MAX_RETRIES:
+ try:
+ dataset = self._load_stream()
+
+ for doc in dataset:
+ if self.max_samples is not None and sample_count >= self.max_samples:
+ return
+
+ text = doc.get("text", "")
+ if not text.strip():
+ continue
+
+ tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
+ buffer.extend(tokens)
+ buffer.append(self.eos_id)
+
+ # Yield packed sequences as buffer fills
+ while len(buffer) >= self.seq_len + 1:
+ chunk = buffer[:self.seq_len + 1]
+ buffer = buffer[self.seq_len + 1:]
+
+ olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
+ olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
+ raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
+
+ yield {
+ "olmo_ids": olmo_ids,
+ "olmo_labels": olmo_labels,
+ "raw_text": raw_text,
+ }
+ sample_count += 1
+
+ if self.max_samples is not None and sample_count >= self.max_samples:
+ return
+
+ # Stream exhausted normally
+ return
+
+ except Exception as e:
+ retries += 1
+ if retries > MAX_RETRIES:
+ raise RuntimeError(f"Dolma stream failed after {MAX_RETRIES} retries: {e}") from e
+ print(f"[DolmaDataset] Stream error (retry {retries}/{MAX_RETRIES}): {e}")
+ print(f"[DolmaDataset] Waiting {RETRY_WAIT}s before reconnecting...")
+ time.sleep(RETRY_WAIT)
+ buffer = [] # reset buffer, data order doesn't matter for training
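+            # sample_count is deliberately NOT reset, so max_samples is still
+            # honored across reconnects; retries accumulate over the stream's
+            # lifetime instead of resetting after each successful reconnect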
def build_train_dataloader(