summaryrefslogtreecommitdiff
path: root/scripts/sweep_d_and_multi.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/sweep_d_and_multi.py')
-rw-r--r--scripts/sweep_d_and_multi.py165
1 files changed, 165 insertions, 0 deletions
diff --git a/scripts/sweep_d_and_multi.py b/scripts/sweep_d_and_multi.py
new file mode 100644
index 0000000..004f6cf
--- /dev/null
+++ b/scripts/sweep_d_and_multi.py
@@ -0,0 +1,165 @@
+"""Sweep d values and test multi-basis CVH."""
+
+import sys
+import os
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
class MultiBasisCVH(torch.nn.Module):
    """CVH variant with two independent frozen low-rank bases.

    Update rule:
        h'_t = h_t + alpha * (B1 @ (theta * (A1 @ h_t)) + B2 @ (theta * (A2 @ h_t)))

    A1/B1 and A2/B2 are fixed random projections (registered as buffers, so
    they move with ``.to(device)`` but are never trained); the single adapted
    vector ``theta`` of size ``d`` gates both branches.
    """

    def __init__(self, hidden_size, d=64, alpha=0.1, basis_seed=42):
        super().__init__()
        self.hidden_size = hidden_size
        self.d = d
        self.alpha = alpha

        # One dedicated generator per basis pair: the two pairs are
        # independent of each other yet fully reproducible from basis_seed.
        g_first = torch.Generator()
        g_first.manual_seed(basis_seed)
        g_second = torch.Generator()
        g_second.manual_seed(basis_seed + 500)

        # 1/sqrt(fan-in) scaling keeps both projections roughly norm-preserving.
        down_scale = 1.0 / (hidden_size ** 0.5)
        up_scale = 1.0 / (d ** 0.5)

        self.register_buffer('A1', torch.randn(d, hidden_size, generator=g_first) * down_scale)
        self.register_buffer('B1', torch.randn(hidden_size, d, generator=g_first) * up_scale)
        self.register_buffer('A2', torch.randn(d, hidden_size, generator=g_second) * down_scale)
        self.register_buffer('B2', torch.randn(hidden_size, d, generator=g_second) * up_scale)

    def forward(self, h, theta):
        """Add both gated low-rank residual branches to hidden states.

        h: (N, hidden_size) hidden states; theta: (d,) gate vector.
        Returns a tensor of the same shape as ``h``.
        """
        gate = theta.unsqueeze(0)
        branch1 = (self.B1.float() @ (gate * (self.A1.float() @ h.T).T).T).T
        branch2 = (self.B2.float() @ (gate * (self.A2.float() @ h.T).T).T).T
        return h + self.alpha * (branch1 + branch2)

    def forward_fn(self, h, theta):
        # Thin alias kept because the generation path looks up a
        # ``forward_fn`` attribute rather than calling the module directly.
        return self.forward(h, theta)
+
+
def run_head(wrapper, examples, support_sets, head_module, d=64, alpha=0.1,
             beta=0.05, steps=30, lr=0.05, max_new_tokens=512,
             device='cuda:1'):
    """Fit a per-example theta on its support set, then generate with the head.

    Args:
        wrapper: model wrapper exposing ``model.lm_head``, ``lm_head_weight``,
            ``generate_base`` and ``generate_with_head``.
        examples: list of dicts with at least 'query_input' and 'task' keys.
        support_sets: per-example support items for hidden-state caching.
        head_module: module providing ``forward_fn(h, theta)``.
        d: dimensionality of the theta vector fitted by ``fit_theta``.
        alpha: unused here — the head is constructed with its alpha baked in;
            kept so existing call sites passing ``alpha=`` keep working.
        beta, steps, lr: ``fit_theta`` optimization hyperparameters.
        max_new_tokens: generation budget per example.
        device: torch device for theta fitting (previously hard-coded to
            'cuda:1'; the default preserves the old behavior).

    Returns:
        (predictions, avg_theta_norm) — one prediction string per example and
        the mean L2 norm of the fitted thetas (0.0 if none were fitted).
    """
    # The LM head may or may not carry a bias depending on the checkpoint.
    lm_head_bias = None
    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
        lm_head_bias = wrapper.model.lm_head.bias.data

    predictions = []
    theta_norms = []

    for i, (ex, support) in enumerate(zip(examples, support_sets)):
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        if not cached_h:
            # No usable support hidden states: fall back to the base model.
            # temperature=0.0 makes the fallback deterministic, matching the
            # greedy decoding used everywhere else in this script.
            prompt = build_query_prompt(ex['query_input'], ex['task'])
            pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens,
                                         temperature=0.0)
            predictions.append(pred)
            continue

        theta = fit_theta(
            cached_h=cached_h,
            lm_head_weight=wrapper.lm_head_weight,
            lm_head_bias=lm_head_bias,
            head_module=head_module,
            d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
            max_grad_norm=5.0, device=device, verbose=False,
        )
        theta_norms.append(theta.norm().item())

        prompt = build_query_prompt(ex['query_input'], ex['task'])
        pred = wrapper.generate_with_head(
            prompt, theta, head_module.forward_fn,
            max_new_tokens=max_new_tokens, temperature=0.0,
        )
        predictions.append(pred)

        # Free the per-example tensors eagerly to keep peak VRAM low.
        del cached_h, theta
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        if (i + 1) % 20 == 0:
            print(f"  {i+1}/{len(examples)}")

    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    return predictions, avg_norm
+
+
def main():
    """Sweep rank d, an unconditional head, a multi-basis head, and larger
    beta values on 50 LongLaMP validation examples; print a comparison table."""
    print("Loading data...")
    examples = load_longlamp('product_review_user', split='val')[:50]
    K = 4

    # Per-example support selection plus references / support texts for metrics.
    support_sets = []
    references = []
    support_texts = []
    for ex in examples:
        picked = select_k_profile_items(ex['profile_items'], K, seed=0)
        support_sets.append(picked)
        references.append(ex['target_output'])
        support_texts.append([item['support_output'] for item in picked])

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
    H = wrapper.hidden_size
    device = 'cuda:1'

    # Unadapted baseline, greedy decoding.
    print("\n=== Base ===")
    base_preds = [
        wrapper.generate_base(build_query_prompt(ex['query_input'], ex['task']),
                              max_new_tokens=512, temperature=0.0)
        for ex in examples
    ]
    base_r = evaluate_all(base_preds, references, support_texts)
    print(f"  ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}")

    results = {'Base': base_r}

    def cvh(dim):
        # Fresh CVH head on the target device; shared settings for the sweep.
        return CVHHead(H, d=dim, alpha=0.1, basis_seed=42).to(device)

    configs = [
        ('CVH d=64', cvh(64), 64),
        ('CVH d=128', cvh(128), 128),
        ('CVH d=256', cvh(256), 256),
        ('Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
        ('Uncond d=128', UnconditionalHead(H, d=128, alpha=0.1, basis_seed=42).to(device), 128),
        ('MultiBasis d=64', MultiBasisCVH(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
        # Higher beta to preserve content
        ('CVH d=64 b=0.1', cvh(64), 64),
        ('CVH d=64 b=0.2', cvh(64), 64),
    ]

    # Configs whose fitting beta deviates from the 0.05 default.
    betas = {
        'CVH d=64 b=0.1': 0.1,
        'CVH d=64 b=0.2': 0.2,
    }

    for name, head, d in configs:
        beta = betas.get(name, 0.05)
        print(f"\n=== {name} (beta={beta}) ===")
        started = time.time()
        preds, avg_norm = run_head(
            wrapper, examples, support_sets, head, d=d,
            alpha=0.1, beta=beta, steps=30, lr=0.05, max_new_tokens=512,
        )
        elapsed = time.time() - started
        r = evaluate_all(preds, references, support_texts)
        results[name] = r
        print(f"  ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
              f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {elapsed:.0f}s")

    # Summary
    print("\n" + "=" * 90)
    print(f"{'Config':<25} {'ROUGE-1':<10} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}")
    print("-" * 90)
    for name, r in results.items():
        print(f"{name:<25} {r['rouge1']:<10.4f} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}")
+
+
# Run the full sweep only when executed as a script, not on import.
if __name__ == '__main__':
    main()