Diffstat (limited to 'tests')
-rw-r--r--   tests/test_gumbel.py       68
-rw-r--r--   tests/test_olmo_graph.py  153
-rw-r--r--   tests/test_predictor.py   206
3 files changed, 427 insertions, 0 deletions
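Both test files below exercise the same 256-node DAG mask (16 layers of 16 attention heads, edges allowed only into strictly later layers). For reference, a minimal sketch of a constructor that satisfies every property asserted in test_olmo_graph.py; the function and argument names here are placeholders, and this is not the repository's src/model/olmo_graph.py code, which may be written differently.

import torch

def block_upper_triangular_mask_sketch(num_nodes: int = 256, heads_per_layer: int = 16) -> torch.Tensor:
    # Layer index of each node: nodes 0-15 -> layer 0, nodes 16-31 -> layer 1, and so on.
    layer = torch.arange(num_nodes) // heads_per_layer
    # Allow edge i -> j only when node j sits in a strictly later layer than node i,
    # so diagonal, same-layer, and backward entries all stay 0.
    return (layer.unsqueeze(1) < layer.unsqueeze(0)).float()

With the default arguments this gives 120 forward layer pairs times 16 x 16 entries, i.e. 30,720 ones, matching test_total_valid_entries.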
diff --git a/tests/test_gumbel.py b/tests/test_gumbel.py
new file mode 100644
index 0000000..b458948
--- /dev/null
+++ b/tests/test_gumbel.py
@@ -0,0 +1,68 @@
+"""Tests specifically for Gumbel-Sigmoid correctness and mathematical properties."""
+
+import pytest
+import torch
+
+from src.model.predictor import gumbel_sigmoid
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+class TestGumbelSigmoidMath:
+    """Test mathematical properties of Gumbel-Sigmoid."""
+
+    def test_logistic_noise_distribution(self):
+        """Gumbel noise G = log(U) - log(1-U) should follow Logistic(0,1)."""
+        torch.manual_seed(42)
+        n = 100_000
+        U = torch.rand(n).clamp(1e-8, 1 - 1e-8)
+        G = torch.log(U) - torch.log(1 - U)
+        # Logistic(0,1) has mean=0, variance=pi^2/3
+        assert abs(G.mean().item()) < 0.05, f"Logistic noise mean should be ~0, got {G.mean():.4f}"
+        expected_var = torch.pi ** 2 / 3
+        assert abs(G.var().item() - expected_var) < 0.1, \
+            f"Logistic noise var should be ~{expected_var:.4f}, got {G.var():.4f}"
+
+    def test_sigmoid_saturation_at_large_logits(self):
+        """Large positive logits → A ≈ 1, large negative → A ≈ 0."""
+        Z = torch.tensor([[[100.0, -100.0]]])
+        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_soft")
+        assert A[0, 0, 0] > 0.999
+        assert A[0, 0, 1] < 0.001
+
+    def test_zero_logit_gives_half(self):
+        """σ(0/τ) = 0.5 for any τ."""
+        Z = torch.tensor([[[0.0]]])
+        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_soft")
+        assert abs(A[0, 0, 0].item() - 0.5) < 1e-6
+
+    def test_hard_threshold_at_zero(self):
+        """Hard mode thresholds at logit=0 (prob=0.5)."""
+        Z = torch.tensor([[[0.1, -0.1, 0.0]]])
+        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_hard")
+        assert A[0, 0, 0] == 1.0  # > 0
+        assert A[0, 0, 1] == 0.0  # < 0
+        assert A[0, 0, 2] == 0.0  # = 0 → not > 0
+
+    def test_train_mean_converges_to_sigmoid(self):
+        """With many samples, training mode mean should converge to σ(Z/τ)."""
+        torch.manual_seed(0)
+        Z = torch.tensor([[[1.5]]])
+        tau = 2.0
+        n_samples = 10_000
+        samples = torch.stack([gumbel_sigmoid(Z, tau, mode="train") for _ in range(n_samples)])
+        empirical_mean = samples.mean().item()
+        expected = torch.sigmoid(Z / tau).item()
+        assert abs(empirical_mean - expected) < 0.05, \
+            f"Empirical mean {empirical_mean:.4f} != σ(Z/τ) {expected:.4f}"
+
+    def test_masked_positions_stay_zero(self):
+        """After masking, invalid positions should be ~0 regardless of Z values."""
+        mask = create_block_upper_triangular_mask()
+        Z = torch.ones(1, 256, 256) * 10.0  # all high logits
+        Z_masked = Z * mask + (-1e9) * (1 - mask)
+
+        for mode in ["train", "eval_soft", "eval_hard"]:
+            A = gumbel_sigmoid(Z_masked, tau=1.0, mode=mode)
+            invalid = A[0][~mask.bool()]
+            assert (invalid < 1e-6).all(), \
+                f"Invalid positions not zero in {mode}: max={invalid.max():.6f}"
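For orientation, here is a sketch of a gumbel_sigmoid with the three modes these tests exercise: Logistic(0,1) noise plus a temperature-scaled sigmoid in "train", the deterministic relaxation σ(Z/τ) in "eval_soft", and a hard threshold at logit 0 in "eval_hard". It is only an illustration consistent with the assertions above, not the actual src/model/predictor.py implementation (which may, for example, use a straight-through estimator during training).

import torch

def gumbel_sigmoid_sketch(Z: torch.Tensor, tau: float, mode: str = "train") -> torch.Tensor:
    if mode == "train":
        # Logistic(0,1) noise G = log(U) - log(1 - U) gives a stochastic, differentiable relaxation.
        U = torch.rand_like(Z).clamp(1e-8, 1 - 1e-8)
        G = torch.log(U) - torch.log(1 - U)
        return torch.sigmoid((Z + G) / tau)
    if mode == "eval_soft":
        # Deterministic soft adjacency: sigma(Z / tau).
        return torch.sigmoid(Z / tau)
    if mode == "eval_hard":
        # Deterministic binary adjacency: threshold at logit 0 (probability 0.5).
        return (Z > 0).float()
    raise ValueError(f"Unknown mode: {mode}")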
diff --git a/tests/test_olmo_graph.py b/tests/test_olmo_graph.py
new file mode 100644
index 0000000..efeb57b
--- /dev/null
+++ b/tests/test_olmo_graph.py
@@ -0,0 +1,153 @@
+"""Unit tests for olmo_graph.py.
+
+Tests that don't require model download run with synthetic tensors.
+Integration tests (baseline reproduction) require the model and are
+skipped if model is not available.
+"""
+
+import pytest
+import torch
+import torch.nn as nn
+
+import sys
+import os
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
+
+from src.model.olmo_graph import (
+    create_block_upper_triangular_mask,
+    InputNormalizer,
+)
+
+
+class TestBlockUpperTriangularMask:
+    """Test the DAG constraint mask."""
+
+    def test_shape(self):
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.shape == (256, 256)
+
+    def test_dtype(self):
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.dtype == torch.float32
+
+    def test_no_self_connections(self):
+        """Diagonal should be 0 — a node cannot connect to itself."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.diag().sum() == 0
+
+    def test_no_same_layer_connections(self):
+        """Nodes in the same layer should NOT be connected."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        for layer in range(16):
+            start = layer * 16
+            end = start + 16
+            block = mask[start:end, start:end]
+            assert block.sum() == 0, f"Layer {layer} has same-layer connections"
+
+    def test_no_backward_connections(self):
+        """No connections from higher layer to lower layer."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        for src_layer in range(16):
+            for tgt_layer in range(src_layer):  # tgt < src = backward
+                src_start = src_layer * 16
+                tgt_start = tgt_layer * 16
+                block = mask[src_start:src_start+16, tgt_start:tgt_start+16]
+                assert block.sum() == 0, f"Backward connection from layer {src_layer} to {tgt_layer}"
+
+    def test_forward_connections_exist(self):
+        """Forward connections (higher layer targets) should be 1."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        for src_layer in range(15):
+            for tgt_layer in range(src_layer + 1, 16):
+                src_start = src_layer * 16
+                tgt_start = tgt_layer * 16
+                block = mask[src_start:src_start+16, tgt_start:tgt_start+16]
+                assert block.sum() == 16 * 16, \
+                    f"Missing connections from layer {src_layer} to layer {tgt_layer}"
+
+    def test_total_valid_entries(self):
+        """Should have exactly 30,720 valid entries."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        assert mask.sum().item() == 30720
+
+    def test_adjacent_connections_count(self):
+        """Adjacent layer connections: 15 × 16 × 16 = 3840."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        count = 0
+        for src_layer in range(15):
+            tgt_layer = src_layer + 1
+            src_start = src_layer * 16
+            tgt_start = tgt_layer * 16
+            count += mask[src_start:src_start+16, tgt_start:tgt_start+16].sum().item()
+        assert count == 3840
+
+    def test_skip_connections_count(self):
+        """Skip connections: 105 × 16 × 16 = 26880."""
+        mask = create_block_upper_triangular_mask(256, 16)
+        count = 0
+        for src_layer in range(14):
+            for tgt_layer in range(src_layer + 2, 16):
+                src_start = src_layer * 16
+                tgt_start = tgt_layer * 16
+                count += mask[src_start:src_start+16, tgt_start:tgt_start+16].sum().item()
+        assert count == 26880
+
+    def test_not_torch_triu(self):
+        """Verify this is NOT element-upper-triangular.
+
+        torch.triu would set mask[0,15]=1 (both in layer 0), which is wrong.
+        """
+        mask = create_block_upper_triangular_mask(256, 16)
+        # Node 0 (layer 0, head 0) to node 15 (layer 0, head 15)
+        assert mask[0, 15] == 0, "Same-layer connection detected — did you use torch.triu()?"
+        # Node 0 (layer 0, head 0) to node 16 (layer 1, head 0)
+        assert mask[0, 16] == 1, "Adjacent-layer connection should be 1"
+
+
+class TestInputNormalizer:
+    """Test input normalization methods."""
+
+    def test_none(self):
+        norm = InputNormalizer("none")
+        x = torch.randn(2, 16, 32, 2048)
+        out = norm(x)
+        assert torch.allclose(out, x)
+
+    def test_gate_mean(self):
+        norm = InputNormalizer("gate_mean")
+        gated_sum = torch.randn(2, 16, 32, 2048)
+        A_slice = torch.rand(2, 48, 16)  # 3 prior layers
+        out = norm(gated_sum, A_slice=A_slice)
+        assert out.shape == gated_sum.shape
+        assert torch.isfinite(out).all()
+
+    def test_rms_post(self):
+        norm = InputNormalizer("rms_post", model_dim=2048)
+        x = torch.randn(2, 16, 32, 2048)
+        out = norm(x)
+        assert out.shape == x.shape
+        assert torch.isfinite(out).all()
+
+    def test_ln_post(self):
+        norm = InputNormalizer("ln_post", model_dim=2048)
+        x = torch.randn(2, 16, 32, 2048)
+        out = norm(x)
+        assert out.shape == x.shape
+        assert torch.isfinite(out).all()
+
+    def test_rms_pre(self):
+        norm = InputNormalizer("rms_pre", model_dim=64, num_nodes=32)  # small for test
+        prior = torch.randn(2, 32, 8, 64)
+        A_slice = torch.rand(2, 32, 4)
+        gated_sum = torch.einsum('bih,bisd->bhsd', A_slice, prior)
+        out = norm(gated_sum, A_slice=A_slice, prior_head_outs=prior)
+        assert out.shape == gated_sum.shape
+        assert torch.isfinite(out).all()
+
+    def test_unknown_method_raises(self):
+        with pytest.raises(ValueError, match="Unknown input_norm"):
+            InputNormalizer("unknown_method")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
diff --git a/tests/test_predictor.py b/tests/test_predictor.py
new file mode 100644
index 0000000..00a4124
--- /dev/null
+++ b/tests/test_predictor.py
@@ -0,0 +1,206 @@
+"""Tests for the structure predictor components (no GPU or model loading required)."""
+
+import pytest
+import torch
+import torch.nn as nn
+
+from src.model.predictor import (
+    PredictorMLP,
+    cascading_gate,
+    gumbel_sigmoid,
+)
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+class TestPredictorMLP:
+    """Test MLP decoder shapes and gradient flow."""
+
+    def setup_method(self):
+        self.batch = 2
+        self.input_dim = 1024  # Qwen embed_dim
+        self.hidden_dim = 256  # small for testing
+        self.rank = 8
+        self.mlp = PredictorMLP(self.input_dim, self.hidden_dim, self.rank)
+
+    def test_output_shape(self):
+        e = torch.randn(self.batch, self.input_dim)
+        Z = self.mlp(e)
+        assert Z.shape == (self.batch, 256, 256)
+
+    def test_low_rank_structure(self):
+        """Z = UV^T should have rank <= r."""
+        e = torch.randn(1, self.input_dim)
+        Z = self.mlp(e)
+        Z_2d = Z.squeeze(0)
+        # SVD to check effective rank
+        S = torch.linalg.svdvals(Z_2d)
+        # Values beyond rank r should be ~0 (up to numerical precision)
+        assert S[self.rank:].abs().max() < 1e-4, \
+            f"Z has effective rank > {self.rank}: max singular value beyond rank = {S[self.rank:].abs().max()}"
+
+    def test_gradient_flow(self):
+        e = torch.randn(self.batch, self.input_dim)
+        Z = self.mlp(e)
+        loss = Z.sum()
+        loss.backward()
+        for name, p in self.mlp.named_parameters():
+            assert p.grad is not None, f"No gradient for {name}"
+            assert p.grad.abs().sum() > 0, f"Zero gradient for {name}"
+
+    def test_batch_independence(self):
+        """Different inputs should produce different outputs."""
+        e1 = torch.randn(1, self.input_dim)
+        e2 = torch.randn(1, self.input_dim)
+        Z1 = self.mlp(e1)
+        Z2 = self.mlp(e2)
+        assert not torch.allclose(Z1, Z2), "Different inputs produced identical Z"
+
+
+class TestGumbelSigmoid:
+    """Test Gumbel-Sigmoid in all 3 modes."""
+
+    def setup_method(self):
+        self.batch = 2
+        mask = create_block_upper_triangular_mask()
+        # Create Z_masked with valid structure
+        Z = torch.randn(self.batch, 256, 256)
+        self.Z_masked = Z * mask.unsqueeze(0) + (-1e9) * (1 - mask.unsqueeze(0))
+        self.tau = 2.0
+
+    def test_train_mode_range(self):
+        A = gumbel_sigmoid(self.Z_masked, self.tau, mode="train")
+        assert A.shape == (self.batch, 256, 256)
+        assert (A >= 0).all() and (A <= 1).all(), "Train mode values out of [0, 1]"
+
+    def test_train_mode_stochastic(self):
+        """Two calls with same input should give different results (stochastic)."""
+        A1 = gumbel_sigmoid(self.Z_masked, self.tau, mode="train")
+        A2 = gumbel_sigmoid(self.Z_masked, self.tau, mode="train")
+        assert not torch.allclose(A1, A2), "Train mode is deterministic (should be stochastic)"
+
+    def test_eval_soft_range(self):
+        A = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_soft")
+        assert (A >= 0).all() and (A <= 1).all(), "Eval soft values out of [0, 1]"
+
+    def test_eval_soft_deterministic(self):
+        A1 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_soft")
+        A2 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_soft")
+        assert torch.allclose(A1, A2), "Eval soft is not deterministic"
+
+    def test_eval_hard_binary(self):
+        A = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_hard")
+        unique_values = A.unique()
+        assert all(v in [0.0, 1.0] for v in unique_values), \
+            f"Eval hard should produce binary 0/1, got {unique_values}"
+
+    def test_eval_hard_deterministic(self):
+        A1 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_hard")
+        A2 = gumbel_sigmoid(self.Z_masked, self.tau, mode="eval_hard")
+        assert torch.allclose(A1, A2), "Eval hard is not deterministic"
+
+    def test_invalid_positions_zero(self):
+        """Invalid positions (same/backward layer) should be ~0 in all modes."""
+        mask = create_block_upper_triangular_mask()
+        invalid_mask = (1 - mask).bool()
+        for mode in ["train", "eval_soft", "eval_hard"]:
+            A = gumbel_sigmoid(self.Z_masked, self.tau, mode=mode)
+            invalid_vals = A[0][invalid_mask]
+            assert (invalid_vals < 1e-6).all(), \
+                f"Invalid positions not zero in {mode}: max={invalid_vals.max()}"
+
+    def test_unknown_mode_raises(self):
+        with pytest.raises(ValueError):
+            gumbel_sigmoid(self.Z_masked, self.tau, mode="unknown")
+
+    def test_temperature_effect(self):
+        """Lower temperature → sharper distribution (closer to binary)."""
+        A_high_tau = gumbel_sigmoid(self.Z_masked, tau=10.0, mode="eval_soft")
+        A_low_tau = gumbel_sigmoid(self.Z_masked, tau=0.1, mode="eval_soft")
+        mask = create_block_upper_triangular_mask().bool()
+        # Low tau should be more extreme (values closer to 0 or 1)
+        valid_high = A_high_tau[0][mask]
+        valid_low = A_low_tau[0][mask]
+        # Measure "sharpness": distance from 0.5
+        sharp_high = (valid_high - 0.5).abs().mean()
+        sharp_low = (valid_low - 0.5).abs().mean()
+        assert sharp_low > sharp_high, \
+            f"Lower tau should be sharper: sharp_low={sharp_low:.4f}, sharp_high={sharp_high:.4f}"
+
+    def test_gradient_through_train_mode(self):
+        """Gradients should flow through Gumbel-Sigmoid in train mode."""
+        Z = torch.randn(1, 256, 256, requires_grad=True)
+        mask = create_block_upper_triangular_mask()
+        Z_masked = Z * mask + (-1e9) * (1 - mask)
+        A = gumbel_sigmoid(Z_masked, tau=2.0, mode="train")
+        loss = A.sum()
+        loss.backward()
+        assert Z.grad is not None
+        # Gradients should be nonzero at valid positions
+        valid_grads = Z.grad[0][mask.bool()]
+        assert (valid_grads != 0).any(), "No nonzero gradients at valid positions"
+
+
+class TestCascadingGate:
+    """Test cascading activation gate."""
+
+    def setup_method(self):
+        self.batch = 2
+
+    def test_output_shape(self):
+        A = torch.rand(self.batch, 256, 256)
+        A_gated = cascading_gate(A, k=5.0, hard=False)
+        assert A_gated.shape == A.shape
+
+    def test_soft_mode_range(self):
+        A = torch.rand(self.batch, 256, 256)
+        A_gated = cascading_gate(A, k=5.0, hard=False)
+        assert (A_gated >= 0).all() and (A_gated <= 1).all()
+
+    def test_hard_mode_kills_disconnected(self):
+        """Nodes with no incoming edges should have all outgoing edges zeroed."""
+        A = torch.zeros(1, 256, 256)
+        # Only set edges from node 0 to node 16 (layer 0 → layer 1)
+        A[0, 0, 16] = 1.0
+        A_gated = cascading_gate(A, k=5.0, hard=True)
+        # Node 0 has no incoming edges → its outgoing should be zeroed
+        assert A_gated[0, 0, 16] == 0.0, "Node 0 has no incoming but wasn't gated to 0"
+        # Node 16 has incoming from node 0 (but node 0 was gated to 0)
+        # In one-pass mode, inc uses ORIGINAL A, so node 16 has inc > 0
+
+    def test_hard_mode_preserves_connected(self):
+        """Nodes with incoming edges keep their outgoing edges."""
+        A = torch.zeros(1, 256, 256)
+        # Set edges: node 0→16, node 16→32
+        A[0, 0, 16] = 1.0
+        A[0, 16, 32] = 1.0
+        A_gated = cascading_gate(A, k=5.0, hard=True)
+        # Node 16 has incoming (from 0) → g_16 = 1 → outgoing preserved
+        assert A_gated[0, 16, 32] == 1.0
+
+    def test_soft_mode_differentiable(self):
+        A = torch.rand(1, 256, 256, requires_grad=True)
+        A_gated = cascading_gate(A, k=5.0, hard=False)
+        loss = A_gated.sum()
+        loss.backward()
+        assert A.grad is not None
+        assert A.grad.abs().sum() > 0
+
+    def test_all_zeros_all_killed(self):
+        """If A is all zeros, cascading gate should keep it all zeros."""
+        A = torch.zeros(1, 256, 256)
+        A_gated = cascading_gate(A, k=5.0, hard=True)
+        assert (A_gated == 0).all()
+
+    def test_one_pass_uses_original(self):
+        """Verify cascading gate uses original A for incoming sums (one-pass)."""
+        # If it were iterative, node 0 being gated off would affect node 16's incoming
+        # But one-pass uses original A, so node 16's incoming is computed from original
+        A = torch.zeros(1, 256, 256)
+        A[0, 0, 16] = 1.0  # 0 → 16
+        A[0, 16, 32] = 1.0  # 16 → 32
+
+        A_gated = cascading_gate(A, k=5.0, hard=True)
+        # One-pass: inc[16] = A[:,16].sum() = A[0,16] = 1.0 (from original A)
+        # g[16] = (inc[16] > 0) = 1.0
+        # So A_gated[16, 32] = A[16, 32] * g[16] = 1.0 * 1.0 = 1.0
+        assert A_gated[0, 16, 32] == 1.0
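The TestCascadingGate cases above pin down a one-pass rule: each node's incoming total is computed from the original A, and the node's outgoing row is scaled by a gate on that total (hard mode: inc > 0). Below is a sketch consistent with those assertions, assuming the soft gate is sigmoid(k * inc); the tests only check the soft gate's range and differentiability, so that exact form is a guess, not the repository's code.

import torch

def cascading_gate_sketch(A: torch.Tensor, k: float = 5.0, hard: bool = False) -> torch.Tensor:
    # A[b, i, j] is the edge weight i -> j; incoming totals are taken from the ORIGINAL A (one pass).
    inc = A.sum(dim=1)                                   # (batch, num_nodes): incoming sum per node
    gate = (inc > 0).float() if hard else torch.sigmoid(k * inc)
    # Scale each node's outgoing edges (row i) by its gate; in hard mode this zeroes the
    # rows of nodes that receive no incoming edges.
    return A * gate.unsqueeze(-1)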