| field | value | date |
|---|---|---|
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 15:12:34 -0500 |
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 15:12:34 -0500 |
| commit | 8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch) | |
| tree | c92a92184fb2f46f265ab84c1f754c3d5d6597bc /scripts/test_length_fix.py | |
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs,
and experiment results for the User Prior Head personalization method.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts/test_length_fix.py')
| mode | file | lines |
|---|---|---|
| -rw-r--r-- | scripts/test_length_fix.py | 203 |

1 file changed, 203 insertions, 0 deletions
diff --git a/scripts/test_length_fix.py b/scripts/test_length_fix.py
new file mode 100644
index 0000000..ad3f358
--- /dev/null
+++ b/scripts/test_length_fix.py
@@ -0,0 +1,203 @@
+"""Test different strategies to fix the output length issue."""
+
+import sys
+import os
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
+def run_with_blend(wrapper, examples, support_sets, head_module, d=64,
+                   beta=0.05, steps=30, lr=0.05, blend_gamma=0.5,
+                   min_new_tokens=64, max_new_tokens=512):
+    """Run CVH with logit blending: logits = (1-gamma)*base + gamma*cvh."""
+    device = 'cuda:1'
+    lm_head_bias = None
+    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+        lm_head_bias = wrapper.model.lm_head.bias.data
+
+    predictions = []
+    theta_norms = []
+
+    for i, (ex, support) in enumerate(zip(examples, support_sets)):
+        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+        if not cached_h:
+            # No usable support items: fall back to the base model.
+            prompt = build_query_prompt(ex['query_input'], ex['task'])
+            pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens)
+            predictions.append(pred)
+            continue
+
+        theta = fit_theta(
+            cached_h=cached_h,
+            lm_head_weight=wrapper.lm_head_weight,
+            lm_head_bias=lm_head_bias,
+            head_module=head_module,
+            d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
+            max_grad_norm=5.0, device=device, verbose=False,
+        )
+        theta_norms.append(theta.norm().item())
+
+        # Generate with blending
+        prompt = build_query_prompt(ex['query_input'], ex['task'])
+        pred = generate_with_blend(
+            wrapper, prompt, theta, head_module,
+            gamma=blend_gamma, max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+        )
+        predictions.append(pred)
+
+        del cached_h, theta
+        torch.cuda.empty_cache()
+
+        if (i + 1) % 20 == 0:
+            print(f"  {i+1}/{len(examples)}")
+
+    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
+    avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+    return predictions, avg_norm, avg_len
+
+
+def generate_with_blend(wrapper, input_text, theta, head_module,
+                        gamma=0.5, max_new_tokens=512, min_new_tokens=64):
+    """Generate greedily with blended base + CVH logits."""
+    chat_messages = [
+        {"role": "system", "content": "You are a helpful writing assistant."},
+        {"role": "user", "content": input_text},
+    ]
+    prompt_text = wrapper.tokenizer.apply_chat_template(
+        chat_messages, tokenize=False, add_generation_prompt=True
+    )
+    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+    lm_head_bias = (wrapper.model.lm_head.bias
+                    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None
+                    else None)
+
+    generated_ids = []
+    past_key_values = None
+
+    for step in range(max_new_tokens):
+        if step == 0:
+            cur_input = input_ids
+        else:
+            cur_input = torch.tensor([[generated_ids[-1]]], device=wrapper.device)
+
+        with torch.no_grad():
+            outputs = wrapper.model(
+                input_ids=cur_input,
+                past_key_values=past_key_values,
+                output_hidden_states=True,
+                use_cache=True,
+                return_dict=True,
+            )
+
+        past_key_values = outputs.past_key_values
+        last_hidden = outputs.hidden_states[-1][:, -1, :]  # (1, H)
+
+        # Base logits
+        base_logits = torch.nn.functional.linear(
+            last_hidden.to(wrapper.lm_head_weight.dtype),
+            wrapper.lm_head_weight,
+            lm_head_bias,
+        ).float()
+
+        # CVH logits
+        h_prime = head_module.forward_fn(last_hidden.float(), theta)
+        cvh_logits = torch.nn.functional.linear(
+            h_prime.to(wrapper.lm_head_weight.dtype),
+            wrapper.lm_head_weight,
+            lm_head_bias,
+        ).float()
+
+        # Blend logits
+        logits = (1 - gamma) * base_logits + gamma * cvh_logits
+
+        # Suppress EOS before min_new_tokens
+        if step < min_new_tokens and wrapper.tokenizer.eos_token_id is not None:
+            logits[0, wrapper.tokenizer.eos_token_id] = float('-inf')
+
+        next_token = logits.argmax(dim=-1).item()
+
+        if next_token == wrapper.tokenizer.eos_token_id:
+            break
+
+        generated_ids.append(next_token)
+
+    return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
+def main():
+    N = 50
+    print(f"Loading data ({N} examples)...")
+    examples = load_longlamp('product_review_user', split='val')[:N]
+    K = 4
+    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+    references = [ex['target_output'] for ex in examples]
+    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+    avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+    print(f"Avg reference length: {avg_ref_len:.0f} words")
+
+    print("Loading model...")
+    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+    H = wrapper.hidden_size
+    device = 'cuda:1'
+
+    # Base model: greedy decoding, no minimum-length constraint
+    print("\n=== Base ===")
+    base_preds = []
+    for ex in examples:
+        prompt = build_query_prompt(ex['query_input'], ex['task'])
+        pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
+        base_preds.append(pred)
+    base_r = evaluate_all(base_preds, references, support_texts)
+    base_len = sum(len(p.split()) for p in base_preds) / len(base_preds)
+    print(f"  R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}")
+
+    results = {'Base': {**base_r, 'avg_len': base_len}}
+
+    head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+    uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+    configs = [
+        # Standard CVH (blend=1.0 uses pure CVH logits) with higher min_new_tokens
+        ('CVH min=200', head, 1.0, 200),
+        # Blended CVH with different gammas
+        ('CVH blend=0.3 min=64', head, 0.3, 64),
+        ('CVH blend=0.5 min=64', head, 0.5, 64),
+        ('CVH blend=0.7 min=64', head, 0.7, 64),
+        ('CVH blend=0.5 min=128', head, 0.5, 128),
+        # Uncond blend
+        ('Uncond blend=0.5 min=64', uncond, 0.5, 64),
+    ]
+
+    for name, head_mod, gamma, min_tok in configs:
+        print(f"\n=== {name} ===")
+        t0 = time.time()
+        preds, avg_norm, avg_len = run_with_blend(
+            wrapper, examples, support_sets, head_mod, d=64,
+            beta=0.05, steps=30, lr=0.05, blend_gamma=gamma,
+            min_new_tokens=min_tok, max_new_tokens=512,
+        )
+        elapsed = time.time() - t0
+        r = evaluate_all(preds, references, support_texts)
+        results[name] = {**r, 'avg_len': avg_len}
+        print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, "
+              f"|theta|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s")
+
+    # Summary
+    print("\n" + "=" * 100)
+    print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}")
+    print("-" * 100)
+    for name, r in results.items():
+        print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
+              f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}")
+
+
+if __name__ == '__main__':
+    main()
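
For readers skimming the diff: the core of `generate_with_blend` is just a convex combination of two logit rows plus EOS masking until a minimum length is reached. Below is a minimal, self-contained sketch of that one step; the helper name `blend_and_pick`, the toy logit values, and the 5-token vocabulary are invented for illustration, and none of the UPH/CVH modules are needed to run it.

```python
import torch

def blend_and_pick(base_logits: torch.Tensor,
                   cvh_logits: torch.Tensor,
                   gamma: float,
                   step: int,
                   min_new_tokens: int,
                   eos_token_id: int) -> int:
    """Blend two (1, vocab) logit rows and greedily pick the next token id."""
    # Convex combination in logit space: (1 - gamma)*base + gamma*cvh.
    logits = (1 - gamma) * base_logits + gamma * cvh_logits
    # Mask EOS until the minimum length is reached, as the script does.
    if step < min_new_tokens:
        logits = logits.clone()
        logits[0, eos_token_id] = float('-inf')
    return logits.argmax(dim=-1).item()

# Toy example: 5-token vocabulary where token 4 plays the role of EOS.
base = torch.tensor([[0.1, 0.2, 0.3, 0.1, 2.0]])  # base model wants to stop
cvh = torch.tensor([[2.0, 0.1, 0.1, 0.1, 0.1]])   # personalized head prefers token 0
print(blend_and_pick(base, cvh, gamma=0.5, step=0, min_new_tokens=64, eos_token_id=4))    # -> 0 (EOS masked)
print(blend_and_pick(base, cvh, gamma=0.0, step=100, min_new_tokens=64, eos_token_id=4))  # -> 4 (base alone stops)
```

Because the blend happens in logit space, the post-softmax distribution is a renormalized geometric mixture of the base and CVH token distributions, which is why intermediate values of gamma interpolate smoothly between the two decoders.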
