"""Test different strategies to fix the output length issue.

Compares base-model decoding against CVH decoding, in which the LM-head
logits are blended as ``(1 - gamma) * base + gamma * cvh`` and stop tokens
are masked for the first ``min_new_tokens`` steps.
"""
import os
import sys
import time

import torch

# Make the project root (the parent of this script's directory) importable
# before pulling in the project-local modules below.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.longlamp import load_longlamp, select_k_profile_items  # noqa: E402
from data.templates import build_query_prompt  # noqa: E402
from models.qwen_wrapper import QwenWrapper  # noqa: E402
from models.cvh import CVHHead, UnconditionalHead  # noqa: E402
from adapt.cache_hidden import cache_support_hidden_states  # noqa: E402
from adapt.fit_theta import fit_theta  # noqa: E402
from eval.metrics import evaluate_all  # noqa: E402


def _lm_head_bias(wrapper):
    """Return ``wrapper.model.lm_head.bias``, or None when it is absent/None."""
    return getattr(wrapper.model.lm_head, 'bias', None)


def _stop_token_ids(wrapper):
    """Return the sorted list of token ids that should end generation.

    Starts from ``wrapper.tokenizer.eos_token_id`` and adds any ids declared
    in ``wrapper.model.generation_config.eos_token_id``. In Hugging Face
    ``GenerationConfig`` that field may be a single int or a list of ints.
    """
    stop_ids = set()
    tok_eos = wrapper.tokenizer.eos_token_id
    if tok_eos is not None:
        stop_ids.add(int(tok_eos))
    gen_cfg = getattr(wrapper.model, 'generation_config', None)
    cfg_eos = getattr(gen_cfg, 'eos_token_id', None)
    if cfg_eos is not None:
        if isinstance(cfg_eos, int):
            cfg_eos = [cfg_eos]
        stop_ids.update(int(t) for t in cfg_eos)
    return sorted(stop_ids)


def _mean_word_count(texts):
    """Average whitespace-split word count of ``texts`` (0.0 when empty)."""
    return sum(len(t.split()) for t in texts) / max(len(texts), 1)


def run_with_blend(wrapper, examples, support_sets, head_module, d=64, beta=0.05,
                   steps=30, lr=0.05, blend_gamma=0.5, min_new_tokens=64,
                   max_new_tokens=512, device=None):
    """Run CVH with logit blending: logits = (1-gamma)*base + gamma*cvh

    For each example, support-set hidden states are cached and a per-example
    ``theta`` is fitted with ``fit_theta``. The query is then decoded with
    ``generate_with_blend``. Examples for which no hidden states were cached
    fall back to ``wrapper.generate_base``. That path has no
    ``min_new_tokens`` floor.

    Args:
        wrapper: model wrapper; uses ``model``, ``lm_head_weight``,
            ``device`` and ``generate_base``.
        examples: dicts read for ``'query_input'`` and ``'task'``.
        support_sets: one support collection per example, passed to
            ``cache_support_hidden_states``.
        head_module: head passed to ``fit_theta`` and ``generate_with_blend``.
        d, beta, steps, lr: forwarded to ``fit_theta``.
        blend_gamma: weight on the CVH logits (1.0 = pure CVH).
        min_new_tokens, max_new_tokens: decoding length bounds.
        device: device handed to ``fit_theta``. Defaults to
            ``wrapper.device``, the device ``generate_with_blend`` also uses.

    Returns:
        ``(predictions, avg_theta_norm, avg_word_len)``. The theta norm is
        averaged only over examples that actually fitted a theta.
    """
    if device is None:
        device = wrapper.device
    bias = _lm_head_bias(wrapper)
    lm_head_bias = bias.data if bias is not None else None

    predictions = []
    theta_norms = []
    for i, (ex, support) in enumerate(zip(examples, support_sets)):
        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
        prompt = build_query_prompt(ex['query_input'], ex['task'])

        if cached_h:
            theta = fit_theta(
                cached_h=cached_h,
                lm_head_weight=wrapper.lm_head_weight,
                lm_head_bias=lm_head_bias,
                head_module=head_module,
                d=d, lr=lr, steps=steps, beta=beta,
                lam=1e-4, max_grad_norm=5.0,
                device=device, verbose=False,
            )
            theta_norms.append(theta.norm().item())
            pred = generate_with_blend(
                wrapper, prompt, theta, head_module,
                gamma=blend_gamma,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
            )
            del theta
        else:
            # Nothing to fit theta on: use the base model. temperature=0.0
            # matches main()'s Base run and the greedy argmax used by
            # generate_with_blend (generate_base's own default is not visible
            # from this file).
            pred = wrapper.generate_base(
                prompt, max_new_tokens=max_new_tokens, temperature=0.0
            )
        predictions.append(pred)

        del cached_h
        torch.cuda.empty_cache()

        # Report progress for every example, including fallback ones.
        if (i + 1) % 20 == 0:
            print(f" {i+1}/{len(examples)}")

    avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
    avg_len = _mean_word_count(predictions)
    return predictions, avg_norm, avg_len


def generate_with_blend(wrapper, input_text, theta, head_module, gamma=0.5,
                        max_new_tokens=512, min_new_tokens=64):
    """Generate with blended base + CVH logits.

    Greedy, KV-cached decoding. At each step the last position's final
    hidden state is projected through the LM head twice:
    - as-is, giving the base logits;
    - after ``head_module.forward_fn(hidden, theta)``, giving the CVH logits.

    The next token is the argmax of ``(1 - gamma) * base + gamma * cvh``.
    All stop tokens (see ``_stop_token_ids``) are masked for the first
    ``min_new_tokens`` steps. After that, emitting any of them ends decoding.

    Args:
        wrapper: model wrapper; uses ``model``, ``tokenizer``, ``device`` and
            ``lm_head_weight``.
        input_text: user message inserted into the chat template.
        theta: per-example parameters produced by ``fit_theta``.
        head_module: object exposing ``forward_fn(hidden, theta)``.
        gamma: blend weight on the CVH logits.
        max_new_tokens: hard cap on the number of decoding steps.
        min_new_tokens: number of initial steps during which stop tokens
            are forbidden.

    Returns:
        The decoded continuation with special tokens skipped.
    """
    tokenizer = wrapper.tokenizer
    chat_messages = [
        {"role": "system", "content": "You are a helpful writing assistant."},
        {"role": "user", "content": input_text},
    ]
    prompt_text = tokenizer.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
    input_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)

    # Loop invariants, resolved once instead of twice per generated token.
    weight = wrapper.lm_head_weight
    bias = _lm_head_bias(wrapper)
    stop_ids = _stop_token_ids(wrapper)

    generated_ids = []
    past_key_values = None
    cur_input = input_ids  # full prompt first, then one token per step
    # Keep the CVH head and blending inside no_grad as well. Only the model
    # forward needs to run, and theta / head parameters might otherwise
    # record an autograd graph on every step.
    with torch.no_grad():
        for step in range(max_new_tokens):
            outputs = wrapper.model(
                input_ids=cur_input,
                past_key_values=past_key_values,
                output_hidden_states=True,
                use_cache=True,
                return_dict=True,
            )
            past_key_values = outputs.past_key_values
            last_hidden = outputs.hidden_states[-1][:, -1, :]  # (1, H)

            base_logits = torch.nn.functional.linear(
                last_hidden.to(weight.dtype), weight, bias
            ).float()
            h_prime = head_module.forward_fn(last_hidden.float(), theta)
            cvh_logits = torch.nn.functional.linear(
                h_prime.to(weight.dtype), weight, bias
            ).float()
            logits = (1 - gamma) * base_logits + gamma * cvh_logits

            # Forbid every stop token until min_new_tokens have been produced.
            if step < min_new_tokens:
                for tok in stop_ids:
                    logits[0, tok] = float('-inf')

            next_token = logits.argmax(dim=-1).item()
            if next_token in stop_ids:
                break
            generated_ids.append(next_token)
            cur_input = torch.tensor([[next_token]], device=wrapper.device)

    return tokenizer.decode(generated_ids, skip_special_tokens=True)


def main():
    """Evaluate Base vs. several CVH / blended-CVH configs and print a table.

    Uses the first ``N`` validation examples of ``product_review_user``, with
    ``K`` support items per example (seed 0). Scores come from
    ``evaluate_all`` against references and support texts.
    """
    N = 50
    K = 4
    device = 'cuda:1'

    print(f"Loading data ({N} examples)...")
    examples = load_longlamp('product_review_user', split='val')[:N]
    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
    references = [ex['target_output'] for ex in examples]
    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
    avg_ref_len = _mean_word_count(references)
    print(f"Avg reference length: {avg_ref_len:.0f} words")

    print("Loading model...")
    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=device)
    H = wrapper.hidden_size

    # Base model with temperature=0.0; no min_new_tokens floor is applied.
    print("\n=== Base ===")
    base_preds = [
        wrapper.generate_base(
            build_query_prompt(ex['query_input'], ex['task']),
            max_new_tokens=512, temperature=0.0,
        )
        for ex in examples
    ]
    base_r = evaluate_all(base_preds, references, support_texts)
    base_len = _mean_word_count(base_preds)
    print(f" R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}")
    results = {'Base': {**base_r, 'avg_len': base_len}}

    head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
    uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)

    # (name, head module, blend gamma, min_new_tokens)
    configs = [
        # Standard CVH (gamma=1.0 -> pure CVH logits) with higher min_new_tokens
        ('CVH min=200', head, 1.0, 200),
        # Blended CVH with different gammas
        ('CVH blend=0.3 min=64', head, 0.3, 64),
        ('CVH blend=0.5 min=64', head, 0.5, 64),
        ('CVH blend=0.7 min=64', head, 0.7, 64),
        ('CVH blend=0.5 min=128', head, 0.5, 128),
        # Uncond blend
        ('Uncond blend=0.5 min=64', uncond, 0.5, 64),
    ]

    for name, head_mod, gamma, min_tok in configs:
        print(f"\n=== {name} ===")
        t0 = time.time()
        preds, avg_norm, avg_len = run_with_blend(
            wrapper, examples, support_sets, head_mod,
            d=64, beta=0.05, steps=30, lr=0.05,
            blend_gamma=gamma, min_new_tokens=min_tok, max_new_tokens=512,
            device=device,
        )
        elapsed = time.time() - t0
        r = evaluate_all(preds, references, support_texts)
        results[name] = {**r, 'avg_len': avg_len}
        print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, "
              f"|theta|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s")

    # Summary
    print("\n" + "=" * 100)
    print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}")
    print("-" * 100)
    for name, r in results.items():
        print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
              f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}")


if __name__ == '__main__':
    main()