"""Fair audit experiment: all methods use the SAME decode policy. Decode policy: greedy (temperature=0), min_new_tokens=128, max_new_tokens=512. For vector methods: blended generation with gamma=0.5. For base/prompt methods: standard generation with min_new_tokens=128. Reports: ROUGE-L, METEOR, SFD_all, SFD_-len, avg_len, feature-level deltas. """ import sys import os import json import time import torch sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.longlamp import load_longlamp, select_k_profile_items from data.templates import build_query_prompt from data.style_features import ( extract_style_features, compute_sfd, compute_feature_deltas, FEATURE_NAMES ) from models.qwen_wrapper import QwenWrapper from models.cvh import CVHHead, UnconditionalHead from adapt.cache_hidden import cache_support_hidden_states from adapt.fit_theta import fit_theta from adapt.fit_theta_weighted import fit_theta_weighted from baselines.prompt_all_k import generate_prompt_all_k from baselines.bm25_top1 import generate_bm25_top1 from eval.metrics import compute_rouge, compute_meteor def generate_base_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128): """Base generation with min_new_tokens constraint (fair decode policy).""" chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": input_text}, ] prompt_text = wrapper.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) with torch.no_grad(): outputs = wrapper.model.generate( input_ids, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, temperature=None, top_p=None, do_sample=False, pad_token_id=wrapper.tokenizer.pad_token_id, ) generated_ids = outputs[0, input_ids.shape[1]:] return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) def generate_prompt_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128): """Prompt-based generation with min_new_tokens constraint.""" chat_messages = [ {"role": "system", "content": "You are a helpful writing assistant."}, {"role": "user", "content": input_text}, ] prompt_text = wrapper.tokenizer.apply_chat_template( chat_messages, tokenize=False, add_generation_prompt=True ) input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) with torch.no_grad(): outputs = wrapper.model.generate( input_ids, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, temperature=None, top_p=None, do_sample=False, pad_token_id=wrapper.tokenizer.pad_token_id, ) generated_ids = outputs[0, input_ids.shape[1]:] return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) def compute_detailed_metrics(predictions, references, support_texts_per_example): """Compute comprehensive metrics including SFD split and feature-level analysis.""" rouge = compute_rouge(predictions, references) meteor = compute_meteor(predictions, references) # SFD all features sfd_all_list = [] sfd_nolen_list = [] feature_deltas_all = {name: [] for name in FEATURE_NAMES} for pred, support_texts in zip(predictions, support_texts_per_example): if not pred.strip(): pred = "empty" sfd_all = compute_sfd(pred, support_texts, exclude_length=False) sfd_nolen = compute_sfd(pred, support_texts, exclude_length=True) sfd_all_list.append(sfd_all) sfd_nolen_list.append(sfd_nolen) deltas = compute_feature_deltas(pred, support_texts) for name in FEATURE_NAMES: if name in deltas: feature_deltas_all[name].append(deltas[name]['delta']) avg_sfd_all = sum(sfd_all_list) / max(len(sfd_all_list), 1) avg_sfd_nolen = sum(sfd_nolen_list) / max(len(sfd_nolen_list), 1) avg_feature_deltas = {} for name in FEATURE_NAMES: vals = feature_deltas_all[name] avg_feature_deltas[name] = sum(vals) / max(len(vals), 1) avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1) return { 'rouge1': rouge['rouge1'], 'rougeL': rouge['rougeL'], 'meteor': meteor, 'sfd_all': avg_sfd_all, 'sfd_nolen': avg_sfd_nolen, 'avg_len': avg_len, 'feature_deltas': avg_feature_deltas, } def main(): import argparse parser = argparse.ArgumentParser() parser.add_argument('--num_eval', type=int, default=200) parser.add_argument('--task', type=str, default='review', choices=['review', 'topic']) parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal']) parser.add_argument('--output_dir', type=str, default='outputs/fair_audit') args = parser.parse_args() N = args.num_eval task = args.task setting = args.setting config_map = { ('review', 'user'): 'product_review_user', ('review', 'temporal'): 'product_review_temporal', ('topic', 'user'): 'topic_writing_user', ('topic', 'temporal'): 'topic_writing_temporal', } config_name = config_map[(task, setting)] print(f"=== Fair Audit: {task}_{setting}, N={N} ===") print(f"Decode policy: greedy, min_new_tokens=128, max_new_tokens=512") print(f"Vector methods: blended gamma=0.5") print("\nLoading data...") examples = load_longlamp(config_name, split='val')[:N] K = 4 support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] references = [ex['target_output'] for ex in examples] support_texts = [[s['support_output'] for s in ss] for ss in support_sets] avg_ref_len = sum(len(r.split()) for r in references) / len(references) print(f"Examples: {len(examples)}, Avg reference len: {avg_ref_len:.0f}") print("\nLoading model...") wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') H = wrapper.hidden_size device = 'cuda:1' all_results = {} all_predictions = {} # === 1. Base (with min_new_tokens=128) === print("\n--- Base ---") preds = [] for i, ex in enumerate(examples): prompt = build_query_prompt(ex['query_input'], ex['task']) pred = generate_base_with_min(wrapper, prompt, min_new_tokens=128) preds.append(pred) if (i + 1) % 40 == 0: print(f" {i+1}/{N}") r = compute_detailed_metrics(preds, references, support_texts) all_results['Base'] = r all_predictions['Base'] = preds print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") # === 2. Prompt-All-K (with min_new_tokens=128) === print(f"\n--- Prompt-All-K (K=4) ---") preds = [] for i, (ex, support) in enumerate(zip(examples, support_sets)): from data.templates import build_prompt_with_examples prompt = build_prompt_with_examples(ex['query_input'], support, ex['task']) pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128) preds.append(pred) if (i + 1) % 40 == 0: print(f" {i+1}/{N}") r = compute_detailed_metrics(preds, references, support_texts) all_results['Prompt-All-K'] = r all_predictions['Prompt-All-K'] = preds print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") # === 3. BM25-Top1 (with min_new_tokens=128) === print(f"\n--- BM25-Top1 ---") preds = [] for i, (ex, support) in enumerate(zip(examples, support_sets)): from baselines.bm25_top1 import bm25_select_top1 from data.templates import build_prompt_with_examples selected = bm25_select_top1(ex['query_input'], support) prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task']) pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128) preds.append(pred) if (i + 1) % 40 == 0: print(f" {i+1}/{N}") r = compute_detailed_metrics(preds, references, support_texts) all_results['BM25-Top1'] = r all_predictions['BM25-Top1'] = preds print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") # Helper for vector head methods def run_vector_head(name, head_module, d=64, use_weighted=False): lm_head_bias = None if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: lm_head_bias = wrapper.model.lm_head.bias.data preds = [] adapt_times = [] for i, (ex, support) in enumerate(zip(examples, support_sets)): t0 = time.time() cached_h = cache_support_hidden_states(wrapper, support, ex['task']) if not cached_h: prompt = build_query_prompt(ex['query_input'], ex['task']) pred = generate_base_with_min(wrapper, prompt) preds.append(pred) adapt_times.append(0.0) continue if use_weighted: theta = fit_theta_weighted( cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=head_module, tokenizer=wrapper.tokenizer, d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, max_tokens_per_item=128, verbose=False, ) else: theta = fit_theta( cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, lm_head_bias=lm_head_bias, head_module=head_module, d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4, max_grad_norm=5.0, device=device, verbose=False, ) adapt_time = time.time() - t0 adapt_times.append(adapt_time) prompt = build_query_prompt(ex['query_input'], ex['task']) pred = wrapper.generate_with_head_blended( prompt, theta, head_module.forward_fn, blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0, ) preds.append(pred) del cached_h, theta torch.cuda.empty_cache() if (i + 1) % 40 == 0: avg_t = sum(adapt_times) / len(adapt_times) print(f" {i+1}/{N} (avg adapt: {avg_t:.1f}s)") avg_adapt = sum(adapt_times) / max(len(adapt_times), 1) r = compute_detailed_metrics(preds, references, support_texts) r['adapt_time'] = avg_adapt return preds, r # === 4. Uncond-Head === print(f"\n--- Uncond-Head ---") uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) preds, r = run_vector_head('Uncond', uncond, d=64) all_results['Uncond-Head'] = r all_predictions['Uncond-Head'] = preds print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") # === 5. CVH === print(f"\n--- CVH ---") cvh = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device) preds, r = run_vector_head('CVH', cvh, d=64) all_results['CVH'] = r all_predictions['CVH'] = preds print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") # === 6. Uncond-Head with style-weighted loss === print(f"\n--- Uncond-Head (style-weighted) ---") uncond_sw = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) preds, r = run_vector_head('Uncond-SW', uncond_sw, d=64, use_weighted=True) all_results['Uncond-SW'] = r all_predictions['Uncond-SW'] = preds print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") # === Print comprehensive results === print("\n" + "=" * 110) print("COMPREHENSIVE RESULTS (FAIR DECODE POLICY)") print("=" * 110) header = f"{'Method':<20} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD_all':<8} {'SFD_-len':<8} {'Len':<6}" print(header) print("-" * 110) for name, r in all_results.items(): print(f"{name:<20} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} {r['meteor']:<8.4f} " f"{r['sfd_all']:<8.4f} {r['sfd_nolen']:<8.4f} {r['avg_len']:<6.0f}") # Feature-level analysis print("\n" + "=" * 110) print("FEATURE-LEVEL DELTAS (gen - proto, closer to 0 = better)") print("=" * 110) header = f"{'Method':<20}" + "".join(f"{n:<14}" for n in FEATURE_NAMES) print(header) print("-" * 110) for name, r in all_results.items(): fd = r['feature_deltas'] row = f"{name:<20}" for feat_name in FEATURE_NAMES: row += f"{fd[feat_name]:<14.3f}" print(row) # Recovery analysis if 'BM25-Top1' in all_results and 'Base' in all_results: print("\n--- Recovery (vs BM25-Top1 baseline) ---") base_rl = all_results['Base']['rougeL'] bm25_rl = all_results['BM25-Top1']['rougeL'] base_m = all_results['Base']['meteor'] bm25_m = all_results['BM25-Top1']['meteor'] for name, r in all_results.items(): if name in ('Base', 'Prompt-All-K', 'BM25-Top1'): continue denom_rl = bm25_rl - base_rl denom_m = bm25_m - base_m rec_rl = (r['rougeL'] - base_rl) / denom_rl if abs(denom_rl) > 1e-8 else 0 rec_m = (r['meteor'] - base_m) / denom_m if abs(denom_m) > 1e-8 else 0 print(f" {name}: R-L Recovery={rec_rl:.3f}, METEOR Recovery={rec_m:.3f}") # Efficiency print("\n--- Efficiency ---") theta_bytes = 64 * 2 # d=64, bf16 print(f" Theta size: {theta_bytes} bytes") print(f" Personalization prompt tokens at inference: 0 (vector methods)") if support_texts: from eval.metrics import compute_compression comp = compute_compression(support_texts[0], theta_bytes) print(f" Compression ratio (example): {comp:.0f}x") # Save results os.makedirs(args.output_dir, exist_ok=True) exp_name = f"{task}_{setting}_K4_d64_N{N}_fair" output_path = os.path.join(args.output_dir, f"{exp_name}_results.json") save_data = { 'results': {k: {kk: vv for kk, vv in v.items()} for k, v in all_results.items()}, 'num_examples': len(examples), 'decode_policy': 'greedy, min_new_tokens=128, max_new_tokens=512, blend_gamma=0.5', } with open(output_path, 'w') as f: json.dump(save_data, f, indent=2, default=str) print(f"\nResults saved to {output_path}") if __name__ == '__main__': main()