author     YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
committer  YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
commit     8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
tree       c92a92184fb2f46f265ab84c1f754c3d5d6597bc /scripts/run_fair_audit.py
Initial commit: UPH project codebase and experiment results

Includes model code, evaluation scripts, configs, analysis outputs, and experiment results for the User Prior Head personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts/run_fair_audit.py')
-rw-r--r--  scripts/run_fair_audit.py | 381
1 file changed, 381 insertions, 0 deletions
diff --git a/scripts/run_fair_audit.py b/scripts/run_fair_audit.py
new file mode 100644
index 0000000..73cc744
--- /dev/null
+++ b/scripts/run_fair_audit.py
@@ -0,0 +1,381 @@
+"""Fair audit experiment: all methods use the SAME decode policy.
+
+Decode policy: greedy (temperature=0), min_new_tokens=128, max_new_tokens=512.
+For vector methods: blended generation with gamma=0.5.
+For base/prompt methods: standard generation with min_new_tokens=128.
+
+Reports: ROUGE-L, METEOR, SFD_all, SFD_-len, avg_len, feature-level deltas.
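+
+Usage (a sketch; assumes the LongLaMP data and the Qwen checkpoint are available):
+
+    python scripts/run_fair_audit.py --task review --setting user --num_eval 200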
+"""
+
+import argparse
+import sys
+import os
+import json
+import time
+import torch
+
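+# Make the repo root importable (data/, models/, adapt/, baselines/, eval/)
+# when this file is run directly as a script.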
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt, build_prompt_with_examples
+from data.style_features import (
+ extract_style_features, compute_sfd, compute_feature_deltas, FEATURE_NAMES
+)
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from adapt.fit_theta_weighted import fit_theta_weighted
+from baselines.prompt_all_k import generate_prompt_all_k
+from baselines.bm25_top1 import generate_bm25_top1, bm25_select_top1
+from eval.metrics import compute_rouge, compute_meteor
+
+
+def generate_base_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128):
+ """Base generation with min_new_tokens constraint (fair decode policy)."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
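+    # Greedy decoding: do_sample=False; temperature/top_p are passed as None so
+    # sampling defaults from the generation config don't trigger warnings.
+    # min_new_tokens enforces the length floor shared by all methods.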
+ with torch.no_grad():
+ outputs = wrapper.model.generate(
+ input_ids,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ temperature=None,
+ top_p=None,
+ do_sample=False,
+ pad_token_id=wrapper.tokenizer.pad_token_id,
+ )
+
+ generated_ids = outputs[0, input_ids.shape[1]:]
+ return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
+def generate_prompt_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128):
+    """Prompt-based generation with min_new_tokens constraint.
+
+    The decode path is identical to generate_base_with_min; the separate name
+    keeps the baseline call sites easy to read.
+    """
+    return generate_base_with_min(wrapper, input_text, max_new_tokens, min_new_tokens)
+
+
+def compute_detailed_metrics(predictions, references, support_texts_per_example):
+ """Compute comprehensive metrics including SFD split and feature-level analysis."""
+ rouge = compute_rouge(predictions, references)
+ meteor = compute_meteor(predictions, references)
+
+ # SFD all features
+ sfd_all_list = []
+ sfd_nolen_list = []
+ feature_deltas_all = {name: [] for name in FEATURE_NAMES}
+
+ for pred, support_texts in zip(predictions, support_texts_per_example):
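+        # Replace empty generations with a placeholder so the style-feature
+        # extractors have text to operate on.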
+ if not pred.strip():
+ pred = "empty"
+ sfd_all = compute_sfd(pred, support_texts, exclude_length=False)
+ sfd_nolen = compute_sfd(pred, support_texts, exclude_length=True)
+ sfd_all_list.append(sfd_all)
+ sfd_nolen_list.append(sfd_nolen)
+
+ deltas = compute_feature_deltas(pred, support_texts)
+ for name in FEATURE_NAMES:
+ if name in deltas:
+ feature_deltas_all[name].append(deltas[name]['delta'])
+
+ avg_sfd_all = sum(sfd_all_list) / max(len(sfd_all_list), 1)
+ avg_sfd_nolen = sum(sfd_nolen_list) / max(len(sfd_nolen_list), 1)
+
+ avg_feature_deltas = {}
+ for name in FEATURE_NAMES:
+ vals = feature_deltas_all[name]
+ avg_feature_deltas[name] = sum(vals) / max(len(vals), 1)
+
+ avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+
+ return {
+ 'rouge1': rouge['rouge1'],
+ 'rougeL': rouge['rougeL'],
+ 'meteor': meteor,
+ 'sfd_all': avg_sfd_all,
+ 'sfd_nolen': avg_sfd_nolen,
+ 'avg_len': avg_len,
+ 'feature_deltas': avg_feature_deltas,
+ }
+
+
+def main():
+    parser = argparse.ArgumentParser()
+ parser.add_argument('--num_eval', type=int, default=200)
+ parser.add_argument('--task', type=str, default='review', choices=['review', 'topic'])
+ parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal'])
+ parser.add_argument('--output_dir', type=str, default='outputs/fair_audit')
+ args = parser.parse_args()
+
+ N = args.num_eval
+ task = args.task
+ setting = args.setting
+
+ config_map = {
+ ('review', 'user'): 'product_review_user',
+ ('review', 'temporal'): 'product_review_temporal',
+ ('topic', 'user'): 'topic_writing_user',
+ ('topic', 'temporal'): 'topic_writing_temporal',
+ }
+ config_name = config_map[(task, setting)]
+
+ print(f"=== Fair Audit: {task}_{setting}, N={N} ===")
+ print(f"Decode policy: greedy, min_new_tokens=128, max_new_tokens=512")
+ print(f"Vector methods: blended gamma=0.5")
+
+ print("\nLoading data...")
+ examples = load_longlamp(config_name, split='val')[:N]
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+ print(f"Examples: {len(examples)}, Avg reference len: {avg_ref_len:.0f}")
+
+ print("\nLoading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+ all_results = {}
+ all_predictions = {}
+
+ # === 1. Base (with min_new_tokens=128) ===
+ print("\n--- Base ---")
+ preds = []
+ for i, ex in enumerate(examples):
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_base_with_min(wrapper, prompt, min_new_tokens=128)
+ preds.append(pred)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ r = compute_detailed_metrics(preds, references, support_texts)
+ all_results['Base'] = r
+ all_predictions['Base'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 2. Prompt-All-K (with min_new_tokens=128) ===
+ print(f"\n--- Prompt-All-K (K=4) ---")
+ preds = []
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ prompt = build_prompt_with_examples(ex['query_input'], support, ex['task'])
+ pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128)
+ preds.append(pred)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ r = compute_detailed_metrics(preds, references, support_texts)
+ all_results['Prompt-All-K'] = r
+ all_predictions['Prompt-All-K'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 3. BM25-Top1 (with min_new_tokens=128) ===
+ print(f"\n--- BM25-Top1 ---")
+ preds = []
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ selected = bm25_select_top1(ex['query_input'], support)
+ prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task'])
+ pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128)
+ preds.append(pred)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ r = compute_detailed_metrics(preds, references, support_texts)
+ all_results['BM25-Top1'] = r
+ all_predictions['BM25-Top1'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # Helper for vector head methods
+ def run_vector_head(name, head_module, d=64, use_weighted=False):
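+        """Adapt a per-user theta with head_module on the support set, then
+        generate with the shared blended decode policy.
+
+        Falls back to plain base generation when no support hidden states can
+        be cached. Returns (predictions, detailed-metrics dict).
+        """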
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ preds = []
+ adapt_times = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_base_with_min(wrapper, prompt)
+ preds.append(pred)
+ adapt_times.append(0.0)
+ continue
+
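+            # Fit the low-dimensional user vector theta (d=64) against the
+            # cached support hidden states; the weighted variant reweights the
+            # fitting loss by style features (an assumption from its name).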
+ if use_weighted:
+ theta = fit_theta_weighted(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ tokenizer=wrapper.tokenizer,
+ d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device,
+ max_tokens_per_item=128,
+ verbose=False,
+ )
+ else:
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device,
+ verbose=False,
+ )
+
+ adapt_time = time.time() - t0
+ adapt_times.append(adapt_time)
+
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
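+            # Blended decoding: blend_gamma presumably interpolates the
+            # head-personalized logits with the base logits (0.5 = equal
+            # weight); the decode policy otherwise matches the baselines.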
+ pred = wrapper.generate_with_head_blended(
+ prompt, theta, head_module.forward_fn,
+ blend_gamma=0.5, max_new_tokens=512,
+ min_new_tokens=128, temperature=0.0,
+ )
+ preds.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 40 == 0:
+ avg_t = sum(adapt_times) / len(adapt_times)
+ print(f" {i+1}/{N} (avg adapt: {avg_t:.1f}s)")
+
+ avg_adapt = sum(adapt_times) / max(len(adapt_times), 1)
+ r = compute_detailed_metrics(preds, references, support_texts)
+ r['adapt_time'] = avg_adapt
+ return preds, r
+
+ # === 4. Uncond-Head ===
+ print(f"\n--- Uncond-Head ---")
+ uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ preds, r = run_vector_head('Uncond', uncond, d=64)
+ all_results['Uncond-Head'] = r
+ all_predictions['Uncond-Head'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 5. CVH ===
+ print(f"\n--- CVH ---")
+ cvh = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ preds, r = run_vector_head('CVH', cvh, d=64)
+ all_results['CVH'] = r
+ all_predictions['CVH'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 6. Uncond-Head with style-weighted loss ===
+ print(f"\n--- Uncond-Head (style-weighted) ---")
+ uncond_sw = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ preds, r = run_vector_head('Uncond-SW', uncond_sw, d=64, use_weighted=True)
+ all_results['Uncond-SW'] = r
+ all_predictions['Uncond-SW'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === Print comprehensive results ===
+ print("\n" + "=" * 110)
+ print("COMPREHENSIVE RESULTS (FAIR DECODE POLICY)")
+ print("=" * 110)
+ header = f"{'Method':<20} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD_all':<8} {'SFD_-len':<8} {'Len':<6}"
+ print(header)
+ print("-" * 110)
+ for name, r in all_results.items():
+ print(f"{name:<20} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} {r['meteor']:<8.4f} "
+ f"{r['sfd_all']:<8.4f} {r['sfd_nolen']:<8.4f} {r['avg_len']:<6.0f}")
+
+ # Feature-level analysis
+ print("\n" + "=" * 110)
+ print("FEATURE-LEVEL DELTAS (gen - proto, closer to 0 = better)")
+ print("=" * 110)
+ header = f"{'Method':<20}" + "".join(f"{n:<14}" for n in FEATURE_NAMES)
+ print(header)
+ print("-" * 110)
+ for name, r in all_results.items():
+ fd = r['feature_deltas']
+ row = f"{name:<20}"
+ for feat_name in FEATURE_NAMES:
+ row += f"{fd[feat_name]:<14.3f}"
+ print(row)
+
+ # Recovery analysis
+ if 'BM25-Top1' in all_results and 'Base' in all_results:
+ print("\n--- Recovery (vs BM25-Top1 baseline) ---")
+ base_rl = all_results['Base']['rougeL']
+ bm25_rl = all_results['BM25-Top1']['rougeL']
+ base_m = all_results['Base']['meteor']
+ bm25_m = all_results['BM25-Top1']['meteor']
+
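+        # Recovery normalizes a method's gain over Base by BM25-Top1's gain:
+        #     recovery = (metric_method - metric_Base) / (metric_BM25 - metric_Base)
+        # e.g. Base R-L 0.20, BM25 0.30, method 0.25 -> recovery = 0.500.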
+        denom_rl = bm25_rl - base_rl
+        denom_m = bm25_m - base_m
+        for name, r in all_results.items():
+            if name in ('Base', 'Prompt-All-K', 'BM25-Top1'):
+                continue
+            rec_rl = (r['rougeL'] - base_rl) / denom_rl if abs(denom_rl) > 1e-8 else 0.0
+            rec_m = (r['meteor'] - base_m) / denom_m if abs(denom_m) > 1e-8 else 0.0
+            print(f"  {name}: R-L Recovery={rec_rl:.3f}, METEOR Recovery={rec_m:.3f}")
+
+ # Efficiency
+ print("\n--- Efficiency ---")
+ theta_bytes = 64 * 2 # d=64, bf16
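+    # 64 bf16 scalars * 2 bytes each = 128 bytes per user; compute_compression
+    # below appears to divide the byte size of one user's support texts by
+    # this footprint.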
+ print(f" Theta size: {theta_bytes} bytes")
+ print(f" Personalization prompt tokens at inference: 0 (vector methods)")
+ if support_texts:
+ from eval.metrics import compute_compression
+ comp = compute_compression(support_texts[0], theta_bytes)
+ print(f" Compression ratio (example): {comp:.0f}x")
+
+ # Save results
+ os.makedirs(args.output_dir, exist_ok=True)
+ exp_name = f"{task}_{setting}_K4_d64_N{N}_fair"
+ output_path = os.path.join(args.output_dir, f"{exp_name}_results.json")
+
+ save_data = {
+        'results': {k: dict(v) for k, v in all_results.items()},
+ 'num_examples': len(examples),
+ 'decode_policy': 'greedy, min_new_tokens=128, max_new_tokens=512, blend_gamma=0.5',
+ }
+ with open(output_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=str)
+ print(f"\nResults saved to {output_path}")
+
+
+if __name__ == '__main__':
+ main()