summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:00:39 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-09 11:00:39 -0600
commit13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch)
tree073534138604c1c49021ca7e334322262129f6ac /tests
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/test_gumbel.py68
-rw-r--r--tests/test_olmo_graph.py153
-rw-r--r--tests/test_predictor.py206
3 files changed, 427 insertions, 0 deletions
diff --git a/tests/test_gumbel.py b/tests/test_gumbel.py
new file mode 100644
index 0000000..b458948
--- /dev/null
+++ b/tests/test_gumbel.py
@@ -0,0 +1,68 @@
+"""Tests specifically for Gumbel-Sigmoid correctness and mathematical properties."""
+
+import pytest
+import torch
+
+from src.model.predictor import gumbel_sigmoid
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+class TestGumbelSigmoidMath:
+ """Test mathematical properties of Gumbel-Sigmoid."""
+
+ def test_logistic_noise_distribution(self):
+ """Gumbel noise G = log(U) - log(1-U) should follow Logistic(0,1)."""
+ torch.manual_seed(42)
+ n = 100_000
+ U = torch.rand(n).clamp(1e-8, 1 - 1e-8)
+ G = torch.log(U) - torch.log(1 - U)
+ # Logistic(0,1) has mean=0, variance=pi^2/3
+ assert abs(G.mean().item()) < 0.05, f"Logistic noise mean should be ~0, got {G.mean():.4f}"
+ expected_var = torch.pi ** 2 / 3
+ assert abs(G.var().item() - expected_var) < 0.1, \
+ f"Logistic noise var should be ~{expected_var:.4f}, got {G.var():.4f}"
+
+ def test_sigmoid_saturation_at_large_logits(self):
+ """Large positive logits → A ≈ 1, large negative → A ≈ 0."""
+ Z = torch.tensor([[[100.0, -100.0]]])
+ A = gumbel_sigmoid(Z, tau=1.0, mode="eval_soft")
+ assert A[0, 0, 0] > 0.999
+ assert A[0, 0, 1] < 0.001
+
+ def test_zero_logit_gives_half(self):
+ """σ(0/τ) = 0.5 for any τ."""
+ Z = torch.tensor([[[0.0]]])
+ A = gumbel_sigmoid(Z, tau=1.0, mode="eval_soft")
+ assert abs(A[0, 0, 0].item() - 0.5) < 1e-6
+
+ def test_hard_threshold_at_zero(self):
+ """Hard mode thresholds at logit=0 (prob=0.5)."""
+ Z = torch.tensor([[[0.1, -0.1, 0.0]]])
+ A = gumbel_sigmoid(Z, tau=1.0, mode="eval_hard")
+ assert A[0, 0, 0] == 1.0 # > 0
+ assert A[0, 0, 1] == 0.0 # < 0
+ assert A[0, 0, 2] == 0.0 # = 0 → not > 0
+
+ def test_train_mean_converges_to_sigmoid(self):
+ """With many samples, training mode mean should converge to σ(Z/τ)."""
+ torch.manual_seed(0)
+ Z = torch.tensor([[[1.5]]])
+ tau = 2.0
+ n_samples = 10_000
+ samples = torch.stack([gumbel_sigmoid(Z, tau, mode="train") for _ in range(n_samples)])
+ empirical_mean = samples.mean().item()
+ expected = torch.sigmoid(Z / tau).item()
+ assert abs(empirical_mean - expected) < 0.05, \
+ f"Empirical mean {empirical_mean:.4f} != σ(Z/τ) {expected:.4f}"
+
+ def test_masked_positions_stay_zero(self):
+ """After masking, invalid positions should be ~0 regardless of Z values."""
+ mask = create_block_upper_triangular_mask()
+ Z = torch.ones(1, 256, 256) * 10.0 # all high logits
+ Z_masked = Z * mask + (-1e9) * (1 - mask)
+
+ for mode in ["train", "eval_soft", "eval_hard"]:
+ A = gumbel_sigmoid(Z_masked, tau=1.0, mode=mode)
+ invalid = A[0][~mask.bool()]
+ assert (invalid < 1e-6).all(), \
+ f"Invalid positions not zero in {mode}: max={invalid.max():.6f}"
diff --git a/tests/test_olmo_graph.py b/tests/test_olmo_graph.py
new file mode 100644
index 0000000..efeb57b
--- /dev/null
+++ b/tests/test_olmo_graph.py
@@ -0,0 +1,153 @@
+"""Unit tests for olmo_graph.py.
+
+Tests that don't require model download run with synthetic tensors.
+Integration tests (baseline reproduction) require the model and are
+skipped if model is not available.
+"""
+
+import pytest
+import torch
+import torch.nn as nn
+
+import sys
+import os
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+from src.model.olmo_graph import (
+ create_block_upper_triangular_mask,
+ InputNormalizer,
+)
+
+
+class TestBlockUpperTriangularMask:
+ """Test the DAG constraint mask."""
+
+ def test_shape(self):
+ mask = create_block_upper_triangular_mask(256, 16)
+ assert mask.shape == (256, 256)
+
+ def test_dtype(self):
+ mask = create_block_upper_triangular_mask(256, 16)
+ assert mask.dtype == torch.float32
+
+ def test_no_self_connections(self):
+ """Diagonal should be 0 — a node cannot connect to itself."""
+ mask = create_block_upper_triangular_mask(256, 16)
+ assert mask.diag().sum() == 0
+
+ def test_no_same_layer_connections(self):
+ """Nodes in the same layer should NOT be connected."""
+ mask = create_block_upper_triangular_mask(256, 16)
+ for layer in range(16):
+ start = layer * 16
+ end = start + 16
+ block = mask[start:end, start:end]
+ assert block.sum() == 0, f"Layer {layer} has same-layer connections"
+
+ def test_no_backward_connections(self):
+ """No connections from higher layer to lower layer."""
+ mask = create_block_upper_triangular_mask(256, 16)
+ for src_layer in range(16):
+ for tgt_layer in range(src_layer): # tgt < src = backward
+ src_start = src_layer * 16
+ tgt_start = tgt_layer * 16
+ block = mask[src_start:src_start+16, tgt_start:tgt_start+16]
+ assert block.sum() == 0, f"Backward connection from layer {src_layer} to {tgt_layer}"
+
+ def test_forward_connections_exist(self):
+ """Forward connections (higher layer targets) should be 1."""
+ mask = create_block_upper_triangular_mask(256, 16)
+ for src_layer in range(15):
+ for tgt_layer in range(src_layer + 1, 16):
+ src_start = src_layer * 16
+ tgt_start = tgt_layer * 16
+ block = mask[src_start:src_start+16, tgt_start:tgt_start+16]
+ assert block.sum() == 16 * 16, \
+ f"Missing connections from layer {src_layer} to layer {tgt_layer}"
+
+ def test_total_valid_entries(self):
+ """Should have exactly 30,720 valid entries."""
+ mask = create_block_upper_triangular_mask(256, 16)
+ assert mask.sum().item() == 30720
+
+ def test_adjacent_connections_count(self):
+ """Adjacent layer connections: 15 × 16 × 16 = 3840."""
+ mask = create_block_upper_triangular_mask(256, 16)
+ count = 0
+ for src_layer in range(15):
+ tgt_layer = src_layer + 1
+ src_start = src_layer * 16
+ tgt_start = tgt_layer * 16
+ count += mask[src_start:src_start+16, tgt_start:tgt_start+16].sum().item()
+ assert count == 3840
+
+ def test_skip_connections_count(self):
+ """Skip connections: 105 × 16 × 16 = 26880."""
+ mask = create_block_upper_triangular_mask(256, 16)
+ count = 0
+ for src_layer in range(14):
+ for tgt_layer in range(src_layer + 2, 16):
+ src_start = src_layer * 16
+ tgt_start = tgt_layer * 16
+ count += mask[src_start:src_start+16, tgt_start:tgt_start+16].sum().item()
+ assert count == 26880
+
+ def test_not_torch_triu(self):
+ """Verify this is NOT element-upper-triangular.
+
+ torch.triu would set mask[0,15]=1 (both in layer 0), which is wrong.
+ """
+ mask = create_block_upper_triangular_mask(256, 16)
+ # Node 0 (layer 0, head 0) to node 15 (layer 0, head 15)
+ assert mask[0, 15] == 0, "Same-layer connection detected — did you use torch.triu()?"
+ # Node 0 (layer 0, head 0) to node 16 (layer 1, head 0)
+ assert mask[0, 16] == 1, "Adjacent-layer connection should be 1"
+
+
+class TestInputNormalizer:
+ """Test input normalization methods."""
+
+ def test_none(self):
+ norm = InputNormalizer("none")
+ x = torch.randn(2, 16, 32, 2048)
+ out = norm(x)
+ assert torch.allclose(out, x)
+
+ def test_gate_mean(self):
+ norm = InputNormalizer("gate_mean")
+ gated_sum = torch.randn(2, 16, 32, 2048)
+ A_slice = torch.rand(2, 48, 16) # 3 prior layers
+ out = norm(gated_sum, A_slice=A_slice)
+ assert out.shape == gated_sum.shape
+ assert torch.isfinite(out).all()
+
+ def test_rms_post(self):
+ norm = InputNormalizer("rms_post", model_dim=2048)
+ x = torch.randn(2, 16, 32, 2048)
+ out = norm(x)
+ assert out.shape == x.shape
+ assert torch.isfinite(out).all()
+
+ def test_ln_post(self):
+ norm = InputNormalizer("ln_post", model_dim=2048)
+ x = torch.randn(2, 16, 32, 2048)
+ out = norm(x)
+ assert out.shape == x.shape
+ assert torch.isfinite(out).all()
+
+ def test_rms_pre(self):
+ norm = InputNormalizer("rms_pre", model_dim=64, num_nodes=32) # small for test
+ prior = torch.randn(2, 32, 8, 64)
+ A_slice = torch.rand(2, 32, 4)
+ gated_sum = torch.einsum('bih,bisd->bhsd', A_slice, prior)
+ out = norm(gated_sum, A_slice=A_slice, prior_head_outs=prior)
+ assert out.shape == gated_sum.shape
+ assert torch.isfinite(out).all()
+
+ def test_unknown_method_raises(self):
+ with pytest.raises(ValueError, match="Unknown input_norm"):
+ InputNormalizer("unknown_method")
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/test_predictor.py b/tests/test_predictor.py
new file mode 100644
index 0000000..00a4124
--- /dev/null
+++ b/tests/test_predictor.py
@@ -0,0 +1,206 @@
+"""Tests for the structure predictor components (no GPU or model loading required)."""
+
+import pytest
+import torch
+import torch.nn as nn
+
+from src.model.predictor import (
+ PredictorMLP,
+ cascading_gate,
+ gumbel_sigmoid,
+)
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+class TestPredictorMLP:
+ """Test MLP decoder shapes and gradient flow."""
+
+ def setup_method(self):
+ self.batch = 2
+ self.input_dim = 1024 # Qwen embed_dim
+ self.hidden_dim = 256 # small for testing
+ self.rank = 8
+ self.mlp = PredictorMLP(self.input_dim, self.hidden_dim, self.rank)
+
+ def test_output_shape(self):
+ e = torch.randn(self.batch, self.input_dim)
+ Z = self.mlp(e)
+ assert Z.shape == (self.batch, 256, 256)
+
+ def test_low_rank_structure(self):
+ """Z = UV^T should have rank <= r."""
+ e = torch.randn(1, self.input_dim)
+ Z = self.mlp(e)
+ Z_2d = Z.squeeze(0)
+ # SVD to check effective rank
+ S = torch.linalg.svdvals(Z_2d)
+ # Values beyond rank r should be ~0 (up to numerical precision)
+ assert S[self.rank:].abs().max() < 1e-4, \
+ f"Z has effective rank > {self.rank}: max singular value beyond rank = {S[self.rank:].abs().max()}"
+
+ def test_gradient_flow(self):
+ e = torch.randn(self.batch, self.input_dim)
+ Z = self.mlp(e)
+ loss = Z.sum()
+ loss.backward()
+ for name, p in self.mlp.named_parameters():
+ assert p.grad is not None, f"No gradient for {name}"
+ assert p.grad.abs().sum() > 0, f"Zero gradient for {name}"
+
+ def test_batch_independence(self):
+ """Different inputs should produce different outputs."""
+ e1 = torch.randn(1, self.input_dim)
+ e2 = torch.randn(1, self.input_dim)
+ Z1 = self.mlp(e1)
+ Z2 = self.mlp(e2)
+ assert not torch.allclose(Z1, Z2), "Different inputs produced identical Z"
+
+
+class TestGumbelSigmoid:
+ """Test Gumbel-Sigmoid in all 3 modes."""
+
+ def setup_method(self):
+ self.batch = 2
+ mask = create_block_upper_triangular_mask()
+ # Create Z_masked with valid structure
+ Z = torch.randn(self.batch, 256, 256)
+ self.Z_masked = Z * mask.unsqueeze(0) + (-1e9) * (1 - mask.unsqueeze(0))
+ self.tau = 2.0
+
+ def test_train_mode_range(self):
+ A = gumbel_sigmoid(self.Z_masked, self.tau, mode="train")
+ assert A.shape == (self.batch, 256, 256)
+ assert (A >= 0).all() and (A <= 1).all(), "Train mode values out of [0, 1]"
+
+ def test_train_mode_stochastic(self):
+ """Two calls with same input should give different results (stochastic)."""
+ A1 = gumbel_sigmoid(self.Z_masked, self.tau, mode="train")
+ A2 = gumbel_sigmoid(self.Z_masked, self.tau, mode="train")
+ assert not torch.allclose(A1, A2), "Train mode is deterministic (should be stochastic)"
+
+ def test_eval_soft_range(self):
+ A = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_soft")
+ assert (A >= 0).all() and (A <= 1).all(), "Eval soft values out of [0, 1]"
+
+ def test_eval_soft_deterministic(self):
+ A1 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_soft")
+ A2 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_soft")
+ assert torch.allclose(A1, A2), "Eval soft is not deterministic"
+
+ def test_eval_hard_binary(self):
+ A = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_hard")
+ unique_values = A.unique()
+ assert all(v in [0.0, 1.0] for v in unique_values), \
+ f"Eval hard should produce binary 0/1, got {unique_values}"
+
+ def test_eval_hard_deterministic(self):
+ A1 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_hard")
+ A2 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_hard")
+ assert torch.allclose(A1, A2), "Eval hard is not deterministic"
+
+ def test_invalid_positions_zero(self):
+ """Invalid positions (same/backward layer) should be ~0 in all modes."""
+ mask = create_block_upper_triangular_mask()
+ invalid_mask = (1 - mask).bool()
+ for mode in ["train", "eval_soft", "eval_hard"]:
+ A = gumbel_sigmoid(self.Z_masked, self.tau, mode=mode)
+ invalid_vals = A[0][invalid_mask]
+ assert (invalid_vals < 1e-6).all(), \
+ f"Invalid positions not zero in {mode}: max={invalid_vals.max()}"
+
+ def test_unknown_mode_raises(self):
+ with pytest.raises(ValueError):
+ gumbel_sigmoid(self.Z_masked, self.tau, mode="unknown")
+
+ def test_temperature_effect(self):
+ """Lower temperature → sharper distribution (closer to binary)."""
+ A_high_tau = gumbel_sigmoid(self.Z_masked, tau=10.0, mode="eval_soft")
+ A_low_tau = gumbel_sigmoid(self.Z_masked, tau=0.1, mode="eval_soft")
+ mask = create_block_upper_triangular_mask().bool()
+ # Low tau should be more extreme (values closer to 0 or 1)
+ valid_high = A_high_tau[0][mask]
+ valid_low = A_low_tau[0][mask]
+ # Measure "sharpness": distance from 0.5
+ sharp_high = (valid_high - 0.5).abs().mean()
+ sharp_low = (valid_low - 0.5).abs().mean()
+ assert sharp_low > sharp_high, \
+ f"Lower tau should be sharper: sharp_low={sharp_low:.4f}, sharp_high={sharp_high:.4f}"
+
+ def test_gradient_through_train_mode(self):
+ """Gradients should flow through Gumbel-Sigmoid in train mode."""
+ Z = torch.randn(1, 256, 256, requires_grad=True)
+ mask = create_block_upper_triangular_mask()
+ Z_masked = Z * mask + (-1e9) * (1 - mask)
+ A = gumbel_sigmoid(Z_masked, tau=2.0, mode="train")
+ loss = A.sum()
+ loss.backward()
+ assert Z.grad is not None
+ # Gradients should be nonzero at valid positions
+ valid_grads = Z.grad[0][mask.bool()]
+ assert (valid_grads != 0).any(), "No nonzero gradients at valid positions"
+
+
+class TestCascadingGate:
+ """Test cascading activation gate."""
+
+ def setup_method(self):
+ self.batch = 2
+
+ def test_output_shape(self):
+ A = torch.rand(self.batch, 256, 256)
+ A_gated = cascading_gate(A, k=5.0, hard=False)
+ assert A_gated.shape == A.shape
+
+ def test_soft_mode_range(self):
+ A = torch.rand(self.batch, 256, 256)
+ A_gated = cascading_gate(A, k=5.0, hard=False)
+ assert (A_gated >= 0).all() and (A_gated <= 1).all()
+
+ def test_hard_mode_kills_disconnected(self):
+ """Nodes with no incoming edges should have all outgoing edges zeroed."""
+ A = torch.zeros(1, 256, 256)
+ # Only set edges from node 0 to node 16 (layer 0 → layer 1)
+ A[0, 0, 16] = 1.0
+ A_gated = cascading_gate(A, k=5.0, hard=True)
+ # Node 0 has no incoming edges → its outgoing should be zeroed
+ assert A_gated[0, 0, 16] == 0.0, "Node 0 has no incoming but wasn't gated to 0"
+ # Node 16 has incoming from node 0 (but node 0 was gated to 0)
+ # In one-pass mode, inc uses ORIGINAL A, so node 16 has inc > 0
+
+ def test_hard_mode_preserves_connected(self):
+ """Nodes with incoming edges keep their outgoing edges."""
+ A = torch.zeros(1, 256, 256)
+ # Set edges: node 0→16, node 16→32
+ A[0, 0, 16] = 1.0
+ A[0, 16, 32] = 1.0
+ A_gated = cascading_gate(A, k=5.0, hard=True)
+ # Node 16 has incoming (from 0) → g_16 = 1 → outgoing preserved
+ assert A_gated[0, 16, 32] == 1.0
+
+ def test_soft_mode_differentiable(self):
+ A = torch.rand(1, 256, 256, requires_grad=True)
+ A_gated = cascading_gate(A, k=5.0, hard=False)
+ loss = A_gated.sum()
+ loss.backward()
+ assert A.grad is not None
+ assert A.grad.abs().sum() > 0
+
+ def test_all_zeros_all_killed(self):
+ """If A is all zeros, cascading gate should keep it all zeros."""
+ A = torch.zeros(1, 256, 256)
+ A_gated = cascading_gate(A, k=5.0, hard=True)
+ assert (A_gated == 0).all()
+
+ def test_one_pass_uses_original(self):
+ """Verify cascading gate uses original A for incoming sums (one-pass)."""
+ # If it were iterative, node 0 being gated off would affect node 16's incoming
+ # But one-pass uses original A, so node 16's incoming is computed from original
+ A = torch.zeros(1, 256, 256)
+ A[0, 0, 16] = 1.0 # 0 → 16
+ A[0, 16, 32] = 1.0 # 16 → 32
+
+ A_gated = cascading_gate(A, k=5.0, hard=True)
+ # One-pass: inc[16] = A[:,16].sum() = A[0,16] = 1.0 (from original A)
+ # g[16] = (inc[16] > 0) = 1.0
+ # So A_gated[16, 32] = A[16, 32] * g[16] = 1.0 * 1.0 = 1.0
+ assert A_gated[0, 16, 32] == 1.0