From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 9 Feb 2026 11:00:39 -0600
Subject: Initial implementation: DAGFormer Phase 1

- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via a
  256x256 adjacency matrix A
- Proportional attribution for post-norm decomposition
- All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid +
  cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval,
  checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6
---
 tests/test_gumbel.py | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 68 insertions(+)
 create mode 100644 tests/test_gumbel.py

diff --git a/tests/test_gumbel.py b/tests/test_gumbel.py
new file mode 100644
index 0000000..b458948
--- /dev/null
+++ b/tests/test_gumbel.py
@@ -0,0 +1,68 @@
+"""Tests specifically for Gumbel-Sigmoid correctness and mathematical properties."""
+
+import pytest
+import torch
+
+from src.model.predictor import gumbel_sigmoid
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+class TestGumbelSigmoidMath:
+    """Test mathematical properties of Gumbel-Sigmoid."""
+
+    def test_logistic_noise_distribution(self):
+        """Noise G = log(U) - log(1-U) (a difference of two Gumbels) is Logistic(0,1)."""
+        torch.manual_seed(42)
+        n = 100_000
+        U = torch.rand(n).clamp(1e-8, 1 - 1e-8)
+        G = torch.log(U) - torch.log(1 - U)
+        # Logistic(0,1) has mean=0, variance=pi^2/3
+        assert abs(G.mean().item()) < 0.05, f"Logistic noise mean should be ~0, got {G.mean():.4f}"
+        expected_var = torch.pi ** 2 / 3
+        assert abs(G.var().item() - expected_var) < 0.1, \
+            f"Logistic noise var should be ~{expected_var:.4f}, got {G.var():.4f}"
+
+    def test_sigmoid_saturation_at_large_logits(self):
+        """Large positive logits → A ≈ 1, large negative → A ≈ 0."""
+        Z = torch.tensor([[[100.0, -100.0]]])
+        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_soft")
+        assert A[0, 0, 0] > 0.999
+        assert A[0, 0, 1] < 0.001
+
+    def test_zero_logit_gives_half(self):
+        """σ(0/τ) = 0.5 for any τ."""
+        Z = torch.tensor([[[0.0]]])
+        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_soft")
+        assert abs(A[0, 0, 0].item() - 0.5) < 1e-6
+
+    def test_hard_threshold_at_zero(self):
+        """Hard mode thresholds at logit=0 (prob=0.5)."""
+        Z = torch.tensor([[[0.1, -0.1, 0.0]]])
+        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_hard")
+        assert A[0, 0, 0] == 1.0  # > 0
+        assert A[0, 0, 1] == 0.0  # < 0
+        assert A[0, 0, 2] == 0.0  # = 0 → not > 0
+
+    def test_train_mean_converges_to_sigmoid(self):
+        """With many samples, the training-mode mean should be close to σ(Z/τ)."""
+        torch.manual_seed(0)
+        Z = torch.tensor([[[1.5]]])
+        tau = 2.0
+        n_samples = 10_000
+        samples = torch.stack([gumbel_sigmoid(Z, tau, mode="train") for _ in range(n_samples)])
+        empirical_mean = samples.mean().item()
+        expected = torch.sigmoid(Z / tau).item()
+        assert abs(empirical_mean - expected) < 0.05, \
+            f"Empirical mean {empirical_mean:.4f} != σ(Z/τ) {expected:.4f}"
+
+    def test_masked_positions_stay_zero(self):
+        """After masking, invalid positions should be ~0 regardless of Z values."""
+        mask = create_block_upper_triangular_mask()
+        Z = torch.ones(1, 256, 256) * 10.0  # all high logits
+        Z_masked = Z * mask + (-1e9) * (1 - mask)
+
+        for mode in ["train", "eval_soft", "eval_hard"]:
+            A = gumbel_sigmoid(Z_masked, tau=1.0, mode=mode)
+            invalid = A[0][~mask.bool()]
+            assert (invalid < 1e-6).all(), \
+                f"Invalid positions not zero in {mode}: max={invalid.max():.6f}"
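
Reviewer note: the tests pin down the contract of gumbel_sigmoid without
showing its body. Below is a minimal sketch consistent with the properties
tested above; the signature and mode names are taken from the test file, and
the actual implementation in src/model/predictor.py may differ in detail.

import torch

def gumbel_sigmoid(Z: torch.Tensor, tau: float = 1.0, mode: str = "train") -> torch.Tensor:
    """Relaxed Bernoulli gate over edge logits Z (sketch, not the committed code).

    train:     sigmoid((Z + G) / tau), with noise G = log(U) - log(1-U) ~ Logistic(0,1)
    eval_soft: sigmoid(Z / tau), deterministic relaxation
    eval_hard: 1[Z > 0], i.e. threshold at probability 0.5
    """
    if mode == "train":
        # Clamping U bounds |G| by log((1 - 1e-8) / 1e-8) ~ 18.4, so logits
        # masked to -1e9 stay at ~0 even after noise is added.
        U = torch.rand_like(Z).clamp(1e-8, 1 - 1e-8)
        G = torch.log(U) - torch.log(1 - U)  # difference of two Gumbel variates
        return torch.sigmoid((Z + G) / tau)
    if mode == "eval_soft":
        return torch.sigmoid(Z / tau)
    if mode == "eval_hard":
        return (Z > 0).to(Z.dtype)
    raise ValueError(f"unknown mode: {mode}")

For Logistic(0,1) noise, P(Z + G > 0) = sigmoid(Z) holds exactly, so hard
thresholding of noised logits yields unbiased Bernoulli samples; the soft
train-mode mean only approximates sigmoid(Z/tau), which is why
test_train_mean_converges_to_sigmoid allows a 0.05 tolerance.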
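
The mask geometry is likewise only implied by the tests (a 256x256 float mask
applied elementwise to Z). Assuming the 256 nodes are the 16 layers x 16 heads
of OLMo2-1B, and that an edge is valid only from a head in a strictly earlier
layer to a head in a later layer, one plausible construction is:

import torch

N_LAYERS, N_HEADS = 16, 16  # assumed OLMo2-1B geometry; 16 * 16 = 256 head nodes

def create_block_upper_triangular_mask() -> torch.Tensor:
    """Float {0,1} mask with mask[src, dst] = 1 iff layer(src) < layer(dst) (sketch)."""
    layer_of = torch.arange(N_LAYERS * N_HEADS) // N_HEADS  # node index -> layer index
    return (layer_of[:, None] < layer_of[None, :]).float()

With both sketches saved under the imported module paths, pytest
tests/test_gumbel.py exercises all six tests.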