From 8fe28101366dd32562b8c5534d7fe359b252bdf3 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Fri, 3 Apr 2026 15:12:34 -0500
Subject: Initial commit: UPH project codebase and experiment results

Includes model code, evaluation scripts, configs, analysis outputs, and
experiment results for the User Prior Head personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 scripts/run_fair_audit.py | 381 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 381 insertions(+)
 create mode 100644 scripts/run_fair_audit.py

diff --git a/scripts/run_fair_audit.py b/scripts/run_fair_audit.py
new file mode 100644
index 0000000..73cc744
--- /dev/null
+++ b/scripts/run_fair_audit.py
@@ -0,0 +1,381 @@
+"""Fair audit experiment: all methods use the SAME decode policy.
+
+Decode policy: greedy (temperature=0), min_new_tokens=128, max_new_tokens=512.
+For vector methods: blended generation with gamma=0.5.
+For base/prompt methods: standard generation with min_new_tokens=128.
+
+Reports: ROUGE-L, METEOR, SFD_all, SFD_-len, avg_len, feature-level deltas.
+"""
+
+import sys
+import os
+import json
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from data.style_features import (
+    extract_style_features, compute_sfd, compute_feature_deltas, FEATURE_NAMES
+)
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from adapt.fit_theta_weighted import fit_theta_weighted
+from baselines.prompt_all_k import generate_prompt_all_k
+from baselines.bm25_top1 import generate_bm25_top1
+from eval.metrics import compute_rouge, compute_meteor
+
+
+def generate_base_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128):
+    """Base generation with min_new_tokens constraint (fair decode policy)."""
+    chat_messages = [
+        {"role": "system", "content": "You are a helpful writing assistant."},
+        {"role": "user", "content": input_text},
+    ]
+    prompt_text = wrapper.tokenizer.apply_chat_template(
+        chat_messages, tokenize=False, add_generation_prompt=True
+    )
+    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+    with torch.no_grad():
+        outputs = wrapper.model.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            temperature=None,
+            top_p=None,
+            do_sample=False,
+            pad_token_id=wrapper.tokenizer.pad_token_id,
+        )
+
+    generated_ids = outputs[0, input_ids.shape[1]:]
+    return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
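+# NOTE: generate_prompt_with_min mirrors generate_base_with_min line-for-line
+# on purpose: the audit's premise is that base and prompt baselines decode
+# under exactly the same greedy policy.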
+def generate_prompt_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128):
+    """Prompt-based generation with min_new_tokens constraint."""
+    chat_messages = [
+        {"role": "system", "content": "You are a helpful writing assistant."},
+        {"role": "user", "content": input_text},
+    ]
+    prompt_text = wrapper.tokenizer.apply_chat_template(
+        chat_messages, tokenize=False, add_generation_prompt=True
+    )
+    input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+    with torch.no_grad():
+        outputs = wrapper.model.generate(
+            input_ids,
+            max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
+            temperature=None,
+            top_p=None,
+            do_sample=False,
+            pad_token_id=wrapper.tokenizer.pad_token_id,
+        )
+
+    generated_ids = outputs[0, input_ids.shape[1]:]
+    return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
+def compute_detailed_metrics(predictions, references, support_texts_per_example):
+    """Compute comprehensive metrics including SFD split and feature-level analysis."""
+    rouge = compute_rouge(predictions, references)
+    meteor = compute_meteor(predictions, references)
+
+    # SFD all features
+    sfd_all_list = []
+    sfd_nolen_list = []
+    feature_deltas_all = {name: [] for name in FEATURE_NAMES}
+
+    for pred, support_texts in zip(predictions, support_texts_per_example):
+        if not pred.strip():
+            pred = "empty"
+        sfd_all = compute_sfd(pred, support_texts, exclude_length=False)
+        sfd_nolen = compute_sfd(pred, support_texts, exclude_length=True)
+        sfd_all_list.append(sfd_all)
+        sfd_nolen_list.append(sfd_nolen)
+
+        deltas = compute_feature_deltas(pred, support_texts)
+        for name in FEATURE_NAMES:
+            if name in deltas:
+                feature_deltas_all[name].append(deltas[name]['delta'])
+
+    avg_sfd_all = sum(sfd_all_list) / max(len(sfd_all_list), 1)
+    avg_sfd_nolen = sum(sfd_nolen_list) / max(len(sfd_nolen_list), 1)
+
+    avg_feature_deltas = {}
+    for name in FEATURE_NAMES:
+        vals = feature_deltas_all[name]
+        avg_feature_deltas[name] = sum(vals) / max(len(vals), 1)
+
+    avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+
+    return {
+        'rouge1': rouge['rouge1'],
+        'rougeL': rouge['rougeL'],
+        'meteor': meteor,
+        'sfd_all': avg_sfd_all,
+        'sfd_nolen': avg_sfd_nolen,
+        'avg_len': avg_len,
+        'feature_deltas': avg_feature_deltas,
+    }
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_eval', type=int, default=200)
+    parser.add_argument('--task', type=str, default='review', choices=['review', 'topic'])
+    parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal'])
+    parser.add_argument('--output_dir', type=str, default='outputs/fair_audit')
+    args = parser.parse_args()
+
+    N = args.num_eval
+    task = args.task
+    setting = args.setting
+
+    config_map = {
+        ('review', 'user'): 'product_review_user',
+        ('review', 'temporal'): 'product_review_temporal',
+        ('topic', 'user'): 'topic_writing_user',
+        ('topic', 'temporal'): 'topic_writing_temporal',
+    }
+    config_name = config_map[(task, setting)]
+
+    print(f"=== Fair Audit: {task}_{setting}, N={N} ===")
+    print(f"Decode policy: greedy, min_new_tokens=128, max_new_tokens=512")
+    print(f"Vector methods: blended gamma=0.5")
+
+    print("\nLoading data...")
+    examples = load_longlamp(config_name, split='val')[:N]
+    K = 4
+    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+    references = [ex['target_output'] for ex in examples]
+    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+    avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+    print(f"Examples: {len(examples)}, Avg reference len: {avg_ref_len:.0f}")
+
+    print("\nLoading model...")
+    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+    H = wrapper.hidden_size
+    device = 'cuda:1'
+
+    all_results = {}
+    all_predictions = {}
+
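+    # Sections 1-3 decode with prompting only (Base, Prompt-All-K, BM25-Top1);
+    # sections 4-6 fit a per-user theta vector and decode with it blended in.
+    # All six share the same references and support sets, so the metrics are
+    # directly comparable under the single decode policy printed above.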
+    # === 1. Base (with min_new_tokens=128) ===
+    print("\n--- Base ---")
+    preds = []
+    for i, ex in enumerate(examples):
+        prompt = build_query_prompt(ex['query_input'], ex['task'])
+        pred = generate_base_with_min(wrapper, prompt, min_new_tokens=128)
+        preds.append(pred)
+        if (i + 1) % 40 == 0:
+            print(f"  {i+1}/{N}")
+    r = compute_detailed_metrics(preds, references, support_texts)
+    all_results['Base'] = r
+    all_predictions['Base'] = preds
+    print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+          f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+    # === 2. Prompt-All-K (with min_new_tokens=128) ===
+    print(f"\n--- Prompt-All-K (K=4) ---")
+    preds = []
+    for i, (ex, support) in enumerate(zip(examples, support_sets)):
+        from data.templates import build_prompt_with_examples
+        prompt = build_prompt_with_examples(ex['query_input'], support, ex['task'])
+        pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128)
+        preds.append(pred)
+        if (i + 1) % 40 == 0:
+            print(f"  {i+1}/{N}")
+    r = compute_detailed_metrics(preds, references, support_texts)
+    all_results['Prompt-All-K'] = r
+    all_predictions['Prompt-All-K'] = preds
+    print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+          f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+    # === 3. BM25-Top1 (with min_new_tokens=128) ===
+    print(f"\n--- BM25-Top1 ---")
+    preds = []
+    for i, (ex, support) in enumerate(zip(examples, support_sets)):
+        from baselines.bm25_top1 import bm25_select_top1
+        from data.templates import build_prompt_with_examples
+        selected = bm25_select_top1(ex['query_input'], support)
+        prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task'])
+        pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128)
+        preds.append(pred)
+        if (i + 1) % 40 == 0:
+            print(f"  {i+1}/{N}")
+    r = compute_detailed_metrics(preds, references, support_texts)
+    all_results['BM25-Top1'] = r
+    all_predictions['BM25-Top1'] = preds
+    print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+          f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+    # Helper for vector head methods
+    def run_vector_head(name, head_module, d=64, use_weighted=False):
+        lm_head_bias = None
+        if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+            lm_head_bias = wrapper.model.lm_head.bias.data
+
+        preds = []
+        adapt_times = []
+
+        for i, (ex, support) in enumerate(zip(examples, support_sets)):
+            t0 = time.time()
+            cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+            if not cached_h:
+                prompt = build_query_prompt(ex['query_input'], ex['task'])
+                pred = generate_base_with_min(wrapper, prompt)
+                preds.append(pred)
+                adapt_times.append(0.0)
+                continue
+
+            if use_weighted:
+                theta = fit_theta_weighted(
+                    cached_h=cached_h,
+                    lm_head_weight=wrapper.lm_head_weight,
+                    lm_head_bias=lm_head_bias,
+                    head_module=head_module,
+                    tokenizer=wrapper.tokenizer,
+                    d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+                    max_grad_norm=5.0, device=device,
+                    max_tokens_per_item=128,
+                    verbose=False,
+                )
+            else:
+                theta = fit_theta(
+                    cached_h=cached_h,
+                    lm_head_weight=wrapper.lm_head_weight,
+                    lm_head_bias=lm_head_bias,
+                    head_module=head_module,
+                    d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+                    max_grad_norm=5.0, device=device,
+                    verbose=False,
+                )
+
+            adapt_time = time.time() - t0
+            adapt_times.append(adapt_time)
+
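+            # Decode with the fitted theta blended in (gamma=0.5) under the
+            # same greedy policy as the baselines: min 128 / max 512 new tokens.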
+            prompt = build_query_prompt(ex['query_input'], ex['task'])
+            pred = wrapper.generate_with_head_blended(
+                prompt, theta, head_module.forward_fn,
+                blend_gamma=0.5, max_new_tokens=512,
+                min_new_tokens=128, temperature=0.0,
+            )
+            preds.append(pred)
+
+            del cached_h, theta
+            torch.cuda.empty_cache()
+
+            if (i + 1) % 40 == 0:
+                avg_t = sum(adapt_times) / len(adapt_times)
+                print(f"  {i+1}/{N} (avg adapt: {avg_t:.1f}s)")
+
+        avg_adapt = sum(adapt_times) / max(len(adapt_times), 1)
+        r = compute_detailed_metrics(preds, references, support_texts)
+        r['adapt_time'] = avg_adapt
+        return preds, r
+
+    # === 4. Uncond-Head ===
+    print(f"\n--- Uncond-Head ---")
+    uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+    preds, r = run_vector_head('Uncond', uncond, d=64)
+    all_results['Uncond-Head'] = r
+    all_predictions['Uncond-Head'] = preds
+    print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+          f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+    # === 5. CVH ===
+    print(f"\n--- CVH ---")
+    cvh = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+    preds, r = run_vector_head('CVH', cvh, d=64)
+    all_results['CVH'] = r
+    all_predictions['CVH'] = preds
+    print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+          f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+    # === 6. Uncond-Head with style-weighted loss ===
+    print(f"\n--- Uncond-Head (style-weighted) ---")
+    uncond_sw = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+    preds, r = run_vector_head('Uncond-SW', uncond_sw, d=64, use_weighted=True)
+    all_results['Uncond-SW'] = r
+    all_predictions['Uncond-SW'] = preds
+    print(f"  R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+          f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+    # === Print comprehensive results ===
+    print("\n" + "=" * 110)
+    print("COMPREHENSIVE RESULTS (FAIR DECODE POLICY)")
+    print("=" * 110)
+    header = f"{'Method':<20} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD_all':<8} {'SFD_-len':<8} {'Len':<6}"
+    print(header)
+    print("-" * 110)
+    for name, r in all_results.items():
+        print(f"{name:<20} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} {r['meteor']:<8.4f} "
+              f"{r['sfd_all']:<8.4f} {r['sfd_nolen']:<8.4f} {r['avg_len']:<6.0f}")
+
+    # Feature-level analysis
+    print("\n" + "=" * 110)
+    print("FEATURE-LEVEL DELTAS (gen - proto, closer to 0 = better)")
+    print("=" * 110)
+    header = f"{'Method':<20}" + "".join(f"{n:<14}" for n in FEATURE_NAMES)
+    print(header)
+    print("-" * 110)
+    for name, r in all_results.items():
+        fd = r['feature_deltas']
+        row = f"{name:<20}"
+        for feat_name in FEATURE_NAMES:
+            row += f"{fd[feat_name]:<14.3f}"
+        print(row)
+
+    # Recovery analysis
+    if 'BM25-Top1' in all_results and 'Base' in all_results:
+        print("\n--- Recovery (vs BM25-Top1 baseline) ---")
+        base_rl = all_results['Base']['rougeL']
+        bm25_rl = all_results['BM25-Top1']['rougeL']
+        base_m = all_results['Base']['meteor']
+        bm25_m = all_results['BM25-Top1']['meteor']
+
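+        # Recovery = (method - Base) / (BM25-Top1 - Base), per metric; 1.0
+        # means the method fully recovers the retrieval baseline's gain over
+        # the unpersonalized Base model.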
prompt tokens at inference: 0 (vector methods)") + if support_texts: + from eval.metrics import compute_compression + comp = compute_compression(support_texts[0], theta_bytes) + print(f" Compression ratio (example): {comp:.0f}x") + + # Save results + os.makedirs(args.output_dir, exist_ok=True) + exp_name = f"{task}_{setting}_K4_d64_N{N}_fair" + output_path = os.path.join(args.output_dir, f"{exp_name}_results.json") + + save_data = { + 'results': {k: {kk: vv for kk, vv in v.items()} for k, v in all_results.items()}, + 'num_examples': len(examples), + 'decode_policy': 'greedy, min_new_tokens=128, max_new_tokens=512, blend_gamma=0.5', + } + with open(output_path, 'w') as f: + json.dump(save_data, f, indent=2, default=str) + print(f"\nResults saved to {output_path}") + + +if __name__ == '__main__': + main() -- cgit v1.2.3