"""Tests specifically for Gumbel-Sigmoid correctness and mathematical properties."""
import pytest
import torch
from src.model.predictor import gumbel_sigmoid
from src.model.olmo_graph import create_block_upper_triangular_mask
class TestGumbelSigmoidMath:
    """Test mathematical properties of Gumbel-Sigmoid.

    Covers the logistic noise formula, saturation and thresholding in the
    deterministic ``eval_soft`` / ``eval_hard`` modes, the sample mean of the
    stochastic ``train`` mode, and behaviour on masked (invalid) positions.
    """

    def test_logistic_noise_distribution(self):
        """Gumbel noise G = log(U) - log(1-U) should follow Logistic(0,1).

        This checks the sampling formula directly on ``torch.rand`` draws; it
        does not call ``gumbel_sigmoid`` itself.
        """
        torch.manual_seed(42)
        n = 100_000
        # torch.rand draws from [0, 1); the clamp keeps log(U) finite if U == 0.
        U = torch.rand(n).clamp(1e-8, 1 - 1e-8)
        G = torch.log(U) - torch.log(1 - U)

        # Logistic(0,1) has mean=0, variance=pi^2/3
        assert abs(G.mean().item()) < 0.05, f"Logistic noise mean should be ~0, got {G.mean():.4f}"
        expected_var = torch.pi ** 2 / 3
        assert abs(G.var().item() - expected_var) < 0.1, \
            f"Logistic noise var should be ~{expected_var:.4f}, got {G.var():.4f}"

    def test_sigmoid_saturation_at_large_logits(self):
        """Large positive logits → A ≈ 1, large negative → A ≈ 0."""
        Z = torch.tensor([[[100.0, -100.0]]])
        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_soft")
        assert A[0, 0, 0] > 0.999
        assert A[0, 0, 1] < 0.001

    def test_zero_logit_gives_half(self):
        """σ(0/τ) = 0.5 for any τ."""
        Z = torch.tensor([[[0.0]]])
        # Sweep several temperatures so the "for any τ" claim is actually exercised,
        # rather than only the τ = 1 case.
        for tau in (0.1, 0.5, 1.0, 2.0, 10.0):
            A = gumbel_sigmoid(Z, tau=tau, mode="eval_soft")
            assert abs(A[0, 0, 0].item() - 0.5) < 1e-6, \
                f"Expected 0.5 at tau={tau}, got {A[0, 0, 0].item():.6f}"

    def test_hard_threshold_at_zero(self):
        """Hard mode thresholds at logit=0 (prob=0.5); a logit of exactly 0 maps to 0."""
        Z = torch.tensor([[[0.1, -0.1, 0.0]]])
        A = gumbel_sigmoid(Z, tau=1.0, mode="eval_hard")
        assert A[0, 0, 0] == 1.0  # > 0
        assert A[0, 0, 1] == 0.0  # < 0
        assert A[0, 0, 2] == 0.0  # = 0 → not > 0

    def test_train_mean_converges_to_sigmoid(self):
        """With many samples, the training-mode mean should land within 0.05 of σ(Z/τ).

        NOTE(review): if train mode computes σ((Z + G)/τ) with logistic noise G,
        its expectation only approximates σ(Z/τ) for Z != 0; the 0.05 tolerance
        is what absorbs that gap — confirm against the implementation before
        tightening it.
        """
        torch.manual_seed(0)
        Z = torch.tensor([[[1.5]]])
        tau = 2.0
        n_samples = 10_000
        samples = torch.stack([gumbel_sigmoid(Z, tau, mode="train") for _ in range(n_samples)])
        empirical_mean = samples.mean().item()
        expected = torch.sigmoid(Z / tau).item()
        assert abs(empirical_mean - expected) < 0.05, \
            f"Empirical mean {empirical_mean:.4f} != σ(Z/τ) {expected:.4f}"

    def test_masked_positions_stay_zero(self):
        """After masking, invalid positions should be ~0 regardless of Z values."""
        # "train" mode is stochastic; seed it like the other sampling tests so a
        # failure is reproducible.
        torch.manual_seed(0)
        # Assumes the mask broadcasts against a (1, 256, 256) logit tensor and
        # uses 1 for valid / 0 for invalid entries — TODO confirm against
        # create_block_upper_triangular_mask.
        mask = create_block_upper_triangular_mask()
        Z = torch.ones(1, 256, 256) * 10.0  # all high logits
        # Keep valid logits and push invalid ones to -1e9 so their output is
        # expected to collapse to ~0 in every mode.
        Z_masked = Z * mask + (-1e9) * (1 - mask)

        invalid_positions = ~mask.bool()
        # Guard against a vacuous pass: .all() over an empty selection is True.
        assert invalid_positions.any(), "Mask has no invalid positions to check"

        for mode in ["train", "eval_soft", "eval_hard"]:
            A = gumbel_sigmoid(Z_masked, tau=1.0, mode=mode)
            invalid = A[0][invalid_positions]
            assert (invalid < 1e-6).all(), \
                f"Invalid positions not zero in {mode}: max={invalid.max():.6f}"
|