summaryrefslogtreecommitdiff
path: root/experiments/ep_synthetic.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/ep_synthetic.py')
-rw-r--r--experiments/ep_synthetic.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/experiments/ep_synthetic.py b/experiments/ep_synthetic.py
index 7daecde..a2f24df 100644
--- a/experiments/ep_synthetic.py
+++ b/experiments/ep_synthetic.py
@@ -115,7 +115,9 @@ def compute_diagnostics(model, teacher, dev, d, C, L, beta=0.5, T_nudge=20, alph
gammas,rhos=[],[]
with torch.no_grad():_,hi=model(x,return_hidden=True)
for l in range(L):
- a_ep=(h_nudge[l+1].detach()-h_free[l+1])/beta
+ # EP nudge moves h toward lower loss, so (h_nudge - h_free) points opposite to BP grad.
+ # Negate to align with BP gradient convention (pointing toward loss increase).
+ a_ep=-(h_nudge[l+1].detach()-h_free[l+1])/beta
gammas.append(cosine_similarity_batch(a_ep,bp[l+1]))
def mk(sl):
def f(h):