summaryrefslogtreecommitdiff
path: root/experiments/frozen_init_identity_check.py
blob: 3f58d7d7e6d7cfe407dd394e23fd740d6409895a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Frozen-init identity check (supporting measurement for the depth-utility ladder).

Quantifies how close a randomly-initialized, frozen ResidualMLP block stack is to
the identity map. This grounds the footnote explaining why the k=0 rung of the
ladder (all blocks frozen at init) already sits well above chance: the trained
embedding + readout are composed with a fixed, near-norm-preserving random feature
map, i.e. effectively a trained (near-)linear classifier on pixels.

Reports, at random init, on a CIFAR-10 test batch (mean over seeds):
  - per-block residual ratio   ||f_l(h_l)|| / ||h_l||         (median over batch)
  - whole-stack deviation      ||h_L - h_0|| / ||h_0||        (median over batch)
  - whole-stack direction      cos(h_L, h_0)                  (median over batch)

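Usage (run from the repository root: the './data' dataset cache and the
'results/' output directory are resolved relative to the current working
directory):
    CUDA_VISIBLE_DEVICES=2 python experiments/frozen_init_identity_check.py

Per-seed and aggregate numbers are printed and also written to
results/depth_ladder/frozen_init_identity.json.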
"""
import os, sys, json
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP


def main():
    """Run the frozen-init identity measurement and write the results to JSON.

    Takes the first 256 CIFAR-10 test images (normalised, flattened to one
    vector per image) and, for each seed, builds a fresh ``ResidualMLP`` in
    eval mode. Under ``torch.no_grad()`` the batch is passed through
    ``model.embed`` and then through every entry of ``model.blocks``, with the
    running state updated as ``state + block(state)``.

    NOTE(review): this treats each block's return value as the residual
    branch only (the skip connection is added here) -- confirm against
    models/residual_mlp.py.

    Per seed it records the batch-median of:
      * ``||block(state)|| / ||state||`` for every block,
      * ``||h_L - h_0|| / ||h_0||`` for the whole stack,
      * ``cos(h_L, h_0)`` for the whole stack.

    The per-seed rows plus their means (and sample standard deviations,
    ``ddof=1``) are printed and dumped to
    ``results/depth_ladder/frozen_init_identity.json`` (relative to the
    current working directory).
    """
    hidden_dim, depth, num_classes, batch_size = 256, 4, 10, 256
    seeds = [42, 123, 456]

    # Fixed evaluation batch, shared by every seed.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2470, 0.2435, 0.2616))
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    test_set = torchvision.datasets.CIFAR10('./data', train=False, download=True,
                                            transform=preprocess)
    images = [test_set[idx][0] for idx in range(batch_size)]
    flat_batch = torch.stack(images).view(batch_size, -1)

    measurements = []  # one (ratios, rel, cos) tuple per seed, in seed order
    seed_rows = {}
    for seed in seeds:
        # Seed before constructing the model so its random init is reproducible.
        torch.manual_seed(seed)
        np.random.seed(seed)
        model = ResidualMLP(32 * 32 * 3, hidden_dim, num_classes, depth).eval()

        ratios = []
        with torch.no_grad():
            embedded = model.embed(flat_batch)
            state = embedded
            for block in model.blocks:
                update = block(state)
                per_sample_ratio = update.norm(dim=-1) / state.norm(dim=-1)
                # torch's median of an even-sized batch is the lower middle value.
                ratios.append(float(per_sample_ratio.median()))
                # Out-of-place add: `embedded` must keep pointing at h_0.
                state = state + update
            per_sample_drift = (state - embedded).norm(dim=-1) / embedded.norm(dim=-1)
            rel = float(per_sample_drift.median())
            per_sample_cos = F.cosine_similarity(state, embedded, dim=-1)
            cos = float(per_sample_cos.median())

        measurements.append((ratios, rel, cos))
        seed_rows[str(seed)] = {'per_block_ratio': ratios, 'rel_dev': rel, 'cos': cos}
        print(f"seed {seed}: per-block ||f||/||h|| = "
              f"{['%.4f' % r for r in ratios]}  "
              f"||h_L-h_0||/||h_0|| = {rel:.3f}  cos(h_L,h_0) = {cos:.4f}", flush=True)

    # Aggregate across seeds.
    ratio_matrix = np.array([row[0] for row in measurements])  # (seed, block)
    rel_values = [row[1] for row in measurements]
    cos_values = [row[2] for row in measurements]
    config = {'d_hidden': hidden_dim, 'L': depth, 'num_classes': num_classes,
              'batch': batch_size, 'dataset': 'cifar10-test', 'seeds': seeds}
    summary = {
        'config': config,
        'per_seed': seed_rows,
        'per_block_ratio_mean': ratio_matrix.mean(axis=0).tolist(),
        'per_block_ratio_grand_mean': float(ratio_matrix.mean()),
        'rel_dev_mean': float(np.mean(rel_values)),
        'rel_dev_std': float(np.std(rel_values, ddof=1)),
        'cos_mean': float(np.mean(cos_values)),
        'cos_std': float(np.std(cos_values, ddof=1)),
    }
    print(f"\nMEAN over {len(seeds)} seeds: "
          f"per-block ratio ≈ {summary['per_block_ratio_grand_mean']:.3f}, "
          f"||h_L-h_0||/||h_0|| = {summary['rel_dev_mean']:.3f} ± {summary['rel_dev_std']:.3f}, "
          f"cos = {summary['cos_mean']:.4f} ± {summary['cos_std']:.4f}", flush=True)

    out = 'results/depth_ladder/frozen_init_identity.json'
    out_dir = os.path.dirname(out)
    os.makedirs(out_dir, exist_ok=True)
    with open(out, 'w') as handle:
        json.dump(summary, handle, indent=2)
    print(f"Saved -> {out}", flush=True)


# Run the measurement only when executed as a script, not when imported.
if __name__ == '__main__':
    main()