summaryrefslogtreecommitdiff
path: root/protocol/fa_protocol.py
blob: 8d1293931c938a852c47ca5c92b4cd6694514bc0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Reference implementation of the three-diagnostic FA evaluation protocol.

Usage:
    from protocol.fa_protocol import FAProtocol

    protocol = FAProtocol(model, x_eval, y_eval)
    report = protocol.run(frozen_baseline_acc=0.349)
    print(report['verdict'])
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional


class FAProtocol:
    """Three-diagnostic evaluation protocol for feedback alignment methods.

    Diagnostics:
        D1 (Scale stability): max per-block residual growth rho = max_l ||h_{l+1}|| / ||h_l||.
            Flags if rho > threshold (default 50).
        D2 (Reference validity): BP gradient norm at the deepest hidden state.
            Flags if ||g_L|| < 10 * eps, where eps is the cosine clamp floor.
        D3 (Depth utility): test accuracy vs frozen-blocks baseline.
            Flags if trained acc < frozen_acc + margin (default 2 pp).

    Every norm reported here is the median over the batch of per-sample L2 norms.
    Before taking the norm, 4-D (B, C, H, W) tensors are global-average-pooled to
    (B, C) and 3-D (B, T, D) tensors are reduced to their first token (the CLS
    token for transformers); other tensors are used as-is.

    The protocol requires:
        - A trained model with .blocks (nn.ModuleList), .out_head (and optionally
          .out_ln) and forward(x, return_hidden=True) -> (logits, hiddens), where
          hiddens[0] is the input to the first block and hiddens[i + 1] is the
          output of block i
        - A test batch (x_eval, y_eval)
        - A frozen-blocks baseline accuracy (must be computed separately)
    """

    def __init__(
        self,
        model: nn.Module,
        x_eval: torch.Tensor,
        y_eval: torch.Tensor,
        d1_threshold: float = 50.0,
        d2_eps: float = 1e-8,
        d2_factor: float = 10.0,
        d3_margin_pp: float = 2.0,
    ):
        """Store the model, evaluation batch and diagnostic thresholds.

        Args:
            model: model to evaluate (see the class docstring for the required interface).
            x_eval: evaluation inputs.
            y_eval: class-index targets for ``x_eval`` (compared against ``argmax`` of
                the logits and used as ``cross_entropy`` targets).
            d1_threshold: D1 fires when the max hidden-norm growth ratio exceeds this.
            d2_eps: cosine clamp floor; D2 fires when ||g_L|| < d2_factor * d2_eps.
            d2_factor: multiplier applied to ``d2_eps`` to form the D2 floor.
            d3_margin_pp: required accuracy gain over the frozen baseline, in
                percentage points (stored internally as a fraction).
        """
        self.model = model
        self.x_eval = x_eval
        self.y_eval = y_eval
        self.d1_threshold = d1_threshold
        self.d2_floor = d2_factor * d2_eps
        self.d3_margin = d3_margin_pp / 100.0

    @staticmethod
    def _pool(t):
        """Reduce a tensor to (B, features) so a per-sample norm can be taken."""
        if t.dim() == 4:  # conv: (B, C, H, W) -> global average pool to (B, C)
            return F.adaptive_avg_pool2d(t, 1).flatten(1)
        if t.dim() == 3:  # transformer: (B, T, D) -> cls token
            return t[:, 0]
        return t

    @classmethod
    def _median_norm(cls, t):
        """Median over the batch of the per-sample L2 norm of the pooled tensor."""
        return float(cls._pool(t).norm(dim=-1).median().item())

    def _compute_hidden_norms(self, hiddens):
        """Compute median per-sample L2 norm at each hidden layer."""
        return [self._median_norm(h) for h in hiddens]

    def _compute_bp_grad_norms(self, hiddens):
        """Compute BP gradient norms at each hidden layer via a manual re-forward.

        The forward pass is rebuilt from ``hiddens[0]`` through ``model.blocks`` and
        the head so that autograd can differentiate the cross-entropy loss with
        respect to every intermediate hidden state.

        Returns:
            list of float with one entry per reconstructed hidden state
            (``len(model.blocks) + 1`` entries); 0.0 where the loss does not depend
            on that state.
        """
        model = self.model
        # Gradients are required here even if the caller invoked run() under
        # torch.no_grad(); otherwise autograd.grad would raise.
        with torch.enable_grad():
            hs = [hiddens[0].detach().clone().requires_grad_(True)]
            for i, block in enumerate(model.blocks):
                ref = hiddens[i + 1] if i + 1 < len(hiddens) else None
                hs.append(self._apply_block(block, hs[-1], ref))

            # Forward through head (same pooling as the hidden-norm diagnostic).
            h_final = self._pool(hs[-1])
            out_ln = getattr(model, 'out_ln', None)
            if out_ln is not None:
                h_final = out_ln(h_final)
            logits = model.out_head(h_final)
            loss = F.cross_entropy(logits, self.y_eval)
            grads = torch.autograd.grad(loss, hs, allow_unused=True)

        return [0.0 if g is None else self._median_norm(g) for g in grads]

    def _apply_block(self, block, h, h_ref=None):
        """Re-run one block on ``h``, adding the residual skip when appropriate.

        A shape-preserving block is assumed to compute only the residual branch
        f(h), so ``h + f(h)`` is returned. The raw block output is returned instead
        when ``_block_has_internal_skip`` says so, when the shape changes, or when
        the recorded hidden state ``h_ref`` is reproduced by the raw output but not
        by the skip-added version (i.e. the block applies its own skip internally).
        """
        out = block(h)
        if out.shape != h.shape or self._block_has_internal_skip(block):
            return out
        with_skip = h + out
        if h_ref is not None and h_ref.shape == out.shape:
            # Only a comparison; keep it out of the autograd graph.
            with torch.no_grad():
                if self._matches(out, h_ref) and not self._matches(with_skip, h_ref):
                    return out
        return with_skip

    @staticmethod
    def _matches(a, b, rtol=1e-4):
        """True if ``a`` equals ``b`` up to relative L2 error ``rtol``."""
        b = b.to(device=a.device, dtype=a.dtype)
        err = (a - b).norm()
        scale = b.norm().clamp_min(1e-12)
        return bool(err <= rtol * scale)

    @staticmethod
    def _block_has_internal_skip(block):
        """Return True if the block's forward already includes a residual skip.

        Conservative default: always False. Override in a subclass for block types
        known to compute ``x + f(x)`` internally (e.g. transformer blocks).
        Independently of this hook, ``_apply_block`` detects such blocks by comparing
        its reconstruction against the recorded hidden states.
        """
        return False

    def run(self, frozen_baseline_acc: Optional[float] = None, test_acc: Optional[float] = None):
        """Run all three diagnostics.

        The model is evaluated in eval mode, and every submodule's train/eval flag
        is restored afterwards. This makes the call safe to use in the middle of
        training.

        Args:
            frozen_baseline_acc: accuracy of the frozen-blocks baseline (required for D3).
            test_acc: test accuracy of the trained model. If None, computed from x_eval/y_eval.

        Returns:
            dict with 'verdict' ('PASS' or 'FAIL(<flags>)'), 'test_acc', 'mode1'
            (True when both D1 and D2 fire) and 'diagnostics' holding the raw values
            of each diagnostic.
        """
        # Record per-module flags (not just the root's) so that submodules that were
        # deliberately left in eval mode stay that way after restoring.
        prev_modes = [(m, m.training) for m in self.model.modules()]
        self.model.eval()
        try:
            # Forward pass to get hidden states (no graph needed here).
            with torch.no_grad():
                logits, hiddens = self.model(self.x_eval, return_hidden=True)

            if test_acc is None:
                test_acc = float((logits.argmax(-1) == self.y_eval).float().mean().item())

            h_norms = self._compute_hidden_norms(hiddens)
            bp_grad_norms = self._compute_bp_grad_norms(hiddens)
        finally:
            for module, was_training in prev_modes:
                module.training = was_training

        # D1: Scale stability -- ratio of consecutive hidden-state norms.
        growth_ratios = [nxt / max(cur, 1e-12) for cur, nxt in zip(h_norms, h_norms[1:])]
        max_growth = max(growth_ratios) if growth_ratios else 1.0
        d1_fires = max_growth > self.d1_threshold

        # D2: Reference validity -- BP gradient at the deepest hidden state.
        g_L = bp_grad_norms[-1] if bp_grad_norms else 0.0
        d2_fires = g_L < self.d2_floor

        # D3: Depth utility (only evaluated when a frozen baseline is supplied).
        if frozen_baseline_acc is not None:
            margin = test_acc - frozen_baseline_acc
            d3_fires = margin < self.d3_margin
        else:
            margin = None
            d3_fires = None

        # Verdict
        mode1 = d1_fires and d2_fires
        flags = [
            name
            for name, fired in (('D1', d1_fires), ('D2', d2_fires), ('D3', d3_fires))
            if fired
        ]
        verdict = 'FAIL(' + '+'.join(flags) + ')' if flags else 'PASS'

        return {
            'verdict': verdict,
            'test_acc': test_acc,
            'mode1': mode1,
            'diagnostics': {
                'D1_scale_growth': {
                    'max_growth': max_growth,
                    'per_block_growth': growth_ratios,
                    'hidden_norms': h_norms,
                    'threshold': self.d1_threshold,
                    'fires': d1_fires,
                },
                'D2_ref_validity': {
                    'g_L': g_L,
                    'bp_grad_norms': bp_grad_norms,
                    'floor': self.d2_floor,
                    'fires': d2_fires,
                },
                'D3_depth_utility': {
                    'test_acc': test_acc,
                    'frozen_baseline_acc': frozen_baseline_acc,
                    'margin': margin,
                    'margin_threshold': self.d3_margin,
                    'fires': d3_fires,
                },
            },
        }

    def summary(self, report: dict) -> str:
        """Return a human-readable, multi-line summary of a ``run()`` report."""
        diags = report['diagnostics']
        d1 = diags['D1_scale_growth']
        d2 = diags['D2_ref_validity']
        d3 = diags['D3_depth_utility']

        def status(fired):
            return 'FIRE' if fired else 'pass'

        lines = [
            f"Verdict: {report['verdict']}",
            f"Test accuracy: {report['test_acc']:.4f}",
            f"D1 Scale stability: max growth = {d1['max_growth']:.1f}x "
            f"(threshold {d1['threshold']}x) -> {status(d1['fires'])}",
            f"D2 Reference validity: ||g_L|| = {d2['g_L']:.2e} "
            f"(floor {d2['floor']:.0e}) -> {status(d2['fires'])}",
        ]
        if d3['fires'] is None:
            lines.append("D3 Depth utility: not evaluated (no frozen baseline provided)")
        else:
            lines.append(
                f"D3 Depth utility: margin = {d3['margin'] * 100:+.1f} pp "
                f"(threshold {d3['margin_threshold'] * 100:.0f} pp) -> {status(d3['fires'])}"
            )
        return '\n'.join(lines)