summaryrefslogtreecommitdiff
path: root/metrics/credit_metrics.py
blob: 516dca2f6436399753242dacb6114bcefd8f63a3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Credit assignment diagnostic metrics:
1. Exact costate cosine (for toy LQ)
2. Local perturbation correlation rho_l
3. Nudging test Delta_l^nudge
4. Offline BP cosine Gamma_l
5. Bridge residual R_l
6. Feature drift M_l
"""
import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import pearsonr


def cosine_similarity_batch(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the batch-averaged cosine similarity between ``a`` and ``b``.

    Both tensors are viewed as ``(batch, -1)``: every non-batch dimension is
    flattened into a single feature vector per sample. The per-sample cosine
    similarities are then averaged over the batch.

    Args:
        a: tensor whose first dim is the batch dim.
        b: tensor with the same batch size and per-sample element count as ``a``.

    Returns:
        Mean per-sample cosine similarity as a Python float.
    """
    a_flat = a.reshape(a.shape[0], -1)
    b_flat = b.reshape(b.shape[0], -1)
    cos = F.cosine_similarity(a_flat, b_flat, dim=-1)  # (batch,)
    return cos.mean().item()


def perturbation_correlation(h_l, a_l, forward_fn, epsilon=1e-3, M=32, generator=None):
    """
    Compute local perturbation correlation rho_l.

    Each of ``M`` rounds does the following:
    - Draw one random direction per sample and scale it to unit norm over all
      non-batch dims.
    - Predict the loss change from the credit signal, ``<a_l, epsilon * v>``.
    - Measure the actual change ``forward_fn(h_l + epsilon * v) - forward_fn(h_l)``.

    The Pearson correlation between predicted and actual changes is returned.

    Args:
        h_l: (batch, ...) hidden state at layer l. Non-batch dims are
            flattened into one feature vector per sample; (batch, d) is the
            common case.
        a_l: credit signal at layer l, with the same number of elements per
            sample as h_l.
        forward_fn: callable h -> loss. It may return either one loss per
            sample (``batch`` elements) or a single scalar loss. In the scalar
            case each round contributes one data point, and the prediction is
            summed over the batch. A positive constant rescaling of the loss
            (e.g. a batch mean) does not change the Pearson correlation.
        epsilon: perturbation magnitude.
        M: number of random directions.
        generator: optional ``torch.Generator`` for reproducible directions.

    Returns:
        rho: Pearson correlation between predicted and true loss changes, or
        0.0 if either series is (numerically) constant.

    Raises:
        ValueError: if forward_fn returns neither a scalar nor ``batch`` values.
    """
    batch_size = h_l.shape[0]
    a_flat = a_l.reshape(batch_size, -1)

    pred_list = []
    true_list = []

    # The loss changes are only consumed as numbers, so no autograd graph is needed.
    with torch.no_grad():
        base_loss = forward_fn(h_l).reshape(-1)
        scalar_loss = base_loss.numel() == 1
        if not scalar_loss and base_loss.numel() != batch_size:
            raise ValueError(
                f"forward_fn must return a scalar or {batch_size} per-sample losses, "
                f"got {base_loss.numel()} values"
            )

        for _ in range(M):
            # Draw in h_l's dtype/device so the perturbed input keeps h_l's dtype.
            v = torch.randn(h_l.shape, generator=generator, device=h_l.device, dtype=h_l.dtype)
            v = v.reshape(batch_size, -1)
            v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
            delta_h = epsilon * v

            # Predicted (first-order) change: <a_l, epsilon * v> per sample.
            delta_pred = (a_flat * delta_h).sum(dim=-1)  # (batch,)
            if scalar_loss:
                # A scalar loss changes by the sum of the per-sample contributions.
                delta_pred = delta_pred.sum().reshape(1)

            # True change: forward from perturbed h.
            perturbed_loss = forward_fn(h_l + delta_h.reshape(h_l.shape)).reshape(-1)
            delta_true = perturbed_loss - base_loss

            pred_list.append(delta_pred.detach().cpu().numpy())
            true_list.append(delta_true.detach().cpu().numpy())

    pred_arr = np.concatenate(pred_list)
    true_arr = np.concatenate(true_list)

    # pearsonr is undefined for constant input; report "no correlation".
    if np.std(pred_arr) < 1e-12 or np.std(true_arr) < 1e-12:
        return 0.0

    rho, _ = pearsonr(pred_arr, true_arr)
    return float(rho)


def nudging_test(h_l, a_l, forward_fn, eta=0.01):
    """
    Nudging test Delta_l^nudge: step h_l against the credit signal and
    measure how the loss responds.

    Args:
        h_l: (batch, d) hidden state
        a_l: (batch, d) credit signal
        forward_fn: callable h -> per-sample loss (batch,)
        eta: step size applied to the RMS-normalised credit signal

    Returns:
        Mean over all elements of forward_fn(h_l - eta * a_hat) - forward_fn(h_l),
        as a Python float, where a_hat is a_l divided by its RMS over the last
        dim. A negative value means the nudge lowered the loss.
    """
    # RMS over the last dim (plus a small floor): eta is a step size that
    # does not depend on the overall magnitude of a_l.
    scale = a_l.pow(2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
    step = eta * (a_l / scale)

    loss_before = forward_fn(h_l)
    loss_after = forward_fn(h_l - step)
    return (loss_after - loss_before).mean().item()


def offline_bp_cosine(a_l, bp_grad_l):
    """
    Offline BP cosine Gamma_l.

    Batch-averaged cosine similarity between the credit signal and the true
    backprop gradient at the same layer.

    Args:
        a_l: (batch, d) credit signal
        bp_grad_l: (batch, d) true BP gradient at layer l

    Returns:
        Mean per-sample cosine similarity, as returned by
        ``cosine_similarity_batch`` (a Python float).
    """
    gamma = cosine_similarity_batch(a_l, bp_grad_l)
    return gamma


def bridge_residual(V_phi, V_bar_phi, h_l, t_l, s, h_l_next_noisy_list, t_l_next, lam=0.1):
    """
    Bridge residual R_l.

    Compares the value network at (h_l, t_l) with a soft-min target. The
    target is built from the EMA target network evaluated at each of the K
    noisy next states:

        V_target = -lam * log( (1/K) * sum_k exp(-V_bar_phi(h_next_k, t_l_next, s) / lam) )

    Everything is evaluated under ``torch.no_grad()``.

    Args:
        V_phi: value network, called as V_phi(h, t, s)
        V_bar_phi: EMA target value network, same call signature
        h_l: (batch, d)
        t_l: (batch,)
        s: (batch, s_dim)
        h_l_next_noisy_list: list of K tensors (batch, d), noisy next states
        t_l_next: (batch,)
        lam: temperature of the soft-min

    Returns:
        Mean absolute residual |V_phi(h_l, t_l, s) - V_target| as a Python float.
    """
    with torch.no_grad():
        v_now = V_phi(h_l, t_l, s)  # (batch,)

        num_samples = len(h_l_next_noisy_list)
        # One column per noisy next state: (batch, K)
        scaled_next = torch.stack(
            [-V_bar_phi(h_next, t_l_next, s) / lam for h_next in h_l_next_noisy_list],
            dim=-1,
        )
        # Adding lam * log(K) turns logsumexp into a log-mean-exp (soft-min of the mean).
        v_target = -lam * torch.logsumexp(scaled_next, dim=-1) + lam * np.log(num_samples)

        return (v_now - v_target).abs().mean().item()


def _as_normable(t):
    """Return ``t`` unchanged if floating/complex, else promoted to float64 (``Tensor.norm`` rejects integer/bool dtypes)."""
    if t.is_floating_point() or t.is_complex():
        return t
    return t.double()


def feature_drift(model_init_params, model_final_params):
    """
    Compute per-layer feature drift M_l = ||W_final - W_init||_F / ||W_init||_F.

    Args:
        model_init_params: dict of {name: tensor} initial parameters
        model_final_params: dict of {name: tensor} final parameters

    Returns:
        dict of {name: drift_ratio}. It covers every name present in both
        dicts whose initial norm exceeds 1e-8. Names missing from the final
        dict, or with a (near-)zero initial norm, are omitted.

    Notes:
        Integer/bool tensors (e.g. buffers such as BatchNorm's
        ``num_batches_tracked`` if the dicts come from ``state_dict()``) are
        promoted to float64 so the norm is defined. The final tensor is moved
        to the initial tensor's device, so a CPU snapshot can be compared
        against a GPU model.
    """
    drifts = {}
    for name, w_init in model_init_params.items():
        if name not in model_final_params:
            continue
        w_init = _as_normable(w_init)
        w_final = _as_normable(model_final_params[name]).to(w_init.device)
        init_norm = w_init.norm().item()
        if init_norm > 1e-8:
            drifts[name] = (w_final - w_init).norm().item() / init_norm
    return drifts