author     YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
committer  YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
commit     8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
tree       c92a92184fb2f46f265ab84c1f754c3d5d6597bc /scripts/test_length_fix.py
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs, and
experiment results for the User Prior Head personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts/test_length_fix.py')
-rw-r--r--  scripts/test_length_fix.py  203
1 file changed, 203 insertions, 0 deletions
diff --git a/scripts/test_length_fix.py b/scripts/test_length_fix.py
new file mode 100644
index 0000000..ad3f358
--- /dev/null
+++ b/scripts/test_length_fix.py
@@ -0,0 +1,203 @@
+"""Test different strategies to fix the output length issue."""
+
+import sys
+import os
+import time
+import torch
+
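+# Make the repo root importable when this script is run directly from scripts/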
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
+def run_with_blend(wrapper, examples, support_sets, head_module, d=64,
+ beta=0.05, steps=30, lr=0.05, blend_gamma=0.5,
+ min_new_tokens=64, max_new_tokens=512):
+ """Run CVH with logit blending: logits = (1-gamma)*base + gamma*cvh"""
+    device = wrapper.device  # fit theta on the same device as the model, rather than hardcoding
+    lm_head_bias = None
+    if getattr(wrapper.model.lm_head, 'bias', None) is not None:
+        lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ theta_norms = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
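+        # No usable support hidden states: fall back to plain base generation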
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens)
+ predictions.append(pred)
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+ theta_norms.append(theta.norm().item())
+
+ # Generate with blending
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_with_blend(
+ wrapper, prompt, theta, head_module,
+ gamma=blend_gamma, max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ )
+ predictions.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 20 == 0:
+ print(f" {i+1}/{len(examples)}")
+
+ avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
+ avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+ return predictions, avg_norm, avg_len
+
+
+def generate_with_blend(wrapper, input_text, theta, head_module,
+ gamma=0.5, max_new_tokens=512, min_new_tokens=64):
+ """Generate with blended base + CVH logits."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+    # Resolve the LM head bias once rather than re-checking on every decode step
+    lm_bias = wrapper.model.lm_head.bias if getattr(wrapper.model.lm_head, 'bias', None) is not None else None
+
+    generated_ids = []
+    past_key_values = None
+
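+    # Manual greedy decode: feed the full prompt once, then one token per step,
+    # reusing the KV cache so each forward pass only processes the new token.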
+ for step in range(max_new_tokens):
+ if step == 0:
+ cur_input = input_ids
+ else:
+ cur_input = torch.tensor([[generated_ids[-1]]], device=wrapper.device)
+
+ with torch.no_grad():
+ outputs = wrapper.model(
+ input_ids=cur_input,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ past_key_values = outputs.past_key_values
+ last_hidden = outputs.hidden_states[-1][:, -1, :] # (1, H)
+
+        # Base logits from the unmodified hidden state
+        base_logits = torch.nn.functional.linear(
+            last_hidden.to(wrapper.lm_head_weight.dtype),
+            wrapper.lm_head_weight,
+            lm_bias,
+        ).float()
+
+        # CVH logits from the theta-adapted hidden state
+        h_prime = head_module.forward_fn(last_hidden.float(), theta)
+        cvh_logits = torch.nn.functional.linear(
+            h_prime.to(wrapper.lm_head_weight.dtype),
+            wrapper.lm_head_weight,
+            lm_bias,
+        ).float()
+
+        # Blend: gamma=0 recovers the base model, gamma=1 is pure CVH
+        logits = (1 - gamma) * base_logits + gamma * cvh_logits
+
+ # Suppress EOS before min_new_tokens
+ if step < min_new_tokens and wrapper.tokenizer.eos_token_id is not None:
+ logits[0, wrapper.tokenizer.eos_token_id] = float('-inf')
+
+ next_token = logits.argmax(dim=-1).item()
+
+ if next_token == wrapper.tokenizer.eos_token_id:
+ break
+
+ generated_ids.append(next_token)
+
+ return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
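+
+# Usage sketch (hypothetical values; run_with_blend above shows the real call site):
+#   pred = generate_with_blend(wrapper, "Write a review of ...", theta, head,
+#                              gamma=0.5, max_new_tokens=512, min_new_tokens=64)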
+
+
+def main():
+ N = 50
+ print(f"Loading data ({N} examples)...")
+ examples = load_longlamp('product_review_user', split='val')[:N]
+ K = 4
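+    # Fixed seed so each example gets the same K support items across runs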
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+ print(f"Avg reference length: {avg_ref_len:.0f} words")
+
+ print("Loading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+    # Base model baseline (greedy decoding, no CVH)
+ print("\n=== Base ===")
+ base_preds = []
+ for ex in examples:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
+ base_preds.append(pred)
+ base_r = evaluate_all(base_preds, references, support_texts)
+ base_len = sum(len(p.split()) for p in base_preds) / len(base_preds)
+ print(f" R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}")
+
+ results = {'Base': {**base_r, 'avg_len': base_len}}
+
+ head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
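+    # Both heads share d, alpha, and basis_seed, so the configs below differ
+    # only in the head's conditioning and the decoding settings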
+
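+    # Sweep the blend weight gamma and the EOS-suppression floor min_new_tokens;
+    # each entry: (label, head module, gamma, min_new_tokens)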
+ configs = [
+ # Standard CVH with higher min_new_tokens
+ ('CVH min=200', head, 1.0, 200),
+ # Blended CVH with different gammas
+ ('CVH blend=0.3 min=64', head, 0.3, 64),
+ ('CVH blend=0.5 min=64', head, 0.5, 64),
+ ('CVH blend=0.7 min=64', head, 0.7, 64),
+ ('CVH blend=0.5 min=128', head, 0.5, 128),
+ # Uncond blend
+ ('Uncond blend=0.5 min=64', uncond, 0.5, 64),
+ ]
+
+ for name, head_mod, gamma, min_tok in configs:
+ print(f"\n=== {name} ===")
+ t0 = time.time()
+ preds, avg_norm, avg_len = run_with_blend(
+ wrapper, examples, support_sets, head_mod, d=64,
+ beta=0.05, steps=30, lr=0.05, blend_gamma=gamma,
+ min_new_tokens=min_tok, max_new_tokens=512,
+ )
+ elapsed = time.time() - t0
+ r = evaluate_all(preds, references, support_texts)
+ results[name] = {**r, 'avg_len': avg_len}
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, "
+ f"|theta|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s")
+
+ # Summary
+ print("\n" + "=" * 100)
+ print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}")
+ print("-" * 100)
+ for name, r in results.items():
+ print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
+ f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}")
+
+
+if __name__ == '__main__':
+ main()