diff options
Diffstat (limited to 'scripts/sweep_d_and_multi.py')
| -rw-r--r-- | scripts/sweep_d_and_multi.py | 165 |
1 files changed, 165 insertions, 0 deletions
diff --git a/scripts/sweep_d_and_multi.py b/scripts/sweep_d_and_multi.py new file mode 100644 index 0000000..004f6cf --- /dev/null +++ b/scripts/sweep_d_and_multi.py @@ -0,0 +1,165 @@ +"""Sweep d values and test multi-basis CVH.""" + +import sys +import os +import time +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from eval.metrics import evaluate_all + + +class MultiBasisCVH(torch.nn.Module): + """Two-basis CVH: h'_t = h_t + a1*B1(theta⊙A1*h) + a2*B2(theta⊙A2*h)""" + + def __init__(self, hidden_size, d=64, alpha=0.1, basis_seed=42): + super().__init__() + self.hidden_size = hidden_size + self.d = d + self.alpha = alpha + + gen1 = torch.Generator() + gen1.manual_seed(basis_seed) + gen2 = torch.Generator() + gen2.manual_seed(basis_seed + 500) + + scale_a = 1.0 / (hidden_size ** 0.5) + scale_b = 1.0 / (d ** 0.5) + + self.register_buffer('A1', torch.randn(d, hidden_size, generator=gen1) * scale_a) + self.register_buffer('B1', torch.randn(hidden_size, d, generator=gen1) * scale_b) + self.register_buffer('A2', torch.randn(d, hidden_size, generator=gen2) * scale_a) + self.register_buffer('B2', torch.randn(hidden_size, d, generator=gen2) * scale_b) + + def forward(self, h, theta): + proj1 = (self.A1.float() @ h.T).T + gated1 = theta.unsqueeze(0) * proj1 + res1 = (self.B1.float() @ gated1.T).T + + proj2 = (self.A2.float() @ h.T).T + gated2 = theta.unsqueeze(0) * proj2 + res2 = (self.B2.float() @ gated2.T).T + + return h + self.alpha * (res1 + res2) + + def forward_fn(self, h, theta): + return self.forward(h, theta) + + +def run_head(wrapper, examples, support_sets, head_module, d=64, alpha=0.1, + beta=0.05, steps=30, lr=0.05, max_new_tokens=512): + device = 'cuda:1' + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + predictions = [] + theta_norms = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens) + predictions.append(pred) + continue + + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head_module, + d=d, lr=lr, steps=steps, beta=beta, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=False, + ) + theta_norms.append(theta.norm().item()) + + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_with_head( + prompt, theta, head_module.forward_fn, + max_new_tokens=max_new_tokens, temperature=0.0, + ) + predictions.append(pred) + + del cached_h, theta + torch.cuda.empty_cache() + + if (i + 1) % 20 == 0: + print(f" {i+1}/{len(examples)}") + + avg_norm = sum(theta_norms) / max(len(theta_norms), 1) + return predictions, avg_norm + + +def main(): + print("Loading data...") + examples = load_longlamp('product_review_user', split='val')[:50] + K = 4 + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + print("Loading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + H = wrapper.hidden_size + device = 'cuda:1' + + # Base + print("\n=== Base ===") + base_preds = [] + for ex in examples: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0) + base_preds.append(pred) + base_r = evaluate_all(base_preds, references, support_texts) + print(f" ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}") + + results = {'Base': base_r} + configs = [ + ('CVH d=64', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ('CVH d=128', CVHHead(H, d=128, alpha=0.1, basis_seed=42).to(device), 128), + ('CVH d=256', CVHHead(H, d=256, alpha=0.1, basis_seed=42).to(device), 256), + ('Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ('Uncond d=128', UnconditionalHead(H, d=128, alpha=0.1, basis_seed=42).to(device), 128), + ('MultiBasis d=64', MultiBasisCVH(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + # Higher beta to preserve content + ('CVH d=64 b=0.1', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ('CVH d=64 b=0.2', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ] + + betas = { + 'CVH d=64 b=0.1': 0.1, + 'CVH d=64 b=0.2': 0.2, + } + + for name, head, d in configs: + beta = betas.get(name, 0.05) + print(f"\n=== {name} (beta={beta}) ===") + t0 = time.time() + preds, avg_norm = run_head( + wrapper, examples, support_sets, head, d=d, + alpha=0.1, beta=beta, steps=30, lr=0.05, max_new_tokens=512, + ) + elapsed = time.time() - t0 + r = evaluate_all(preds, references, support_texts) + results[name] = r + print(f" ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {elapsed:.0f}s") + + # Summary + print("\n" + "=" * 90) + print(f"{'Config':<25} {'ROUGE-1':<10} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}") + print("-" * 90) + for name, r in results.items(): + print(f"{name:<25} {r['rouge1']:<10.4f} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}") + + +if __name__ == '__main__': + main() |
