| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 15:12:34 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 15:12:34 -0500 |
| commit | 8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch) | |
| tree | c92a92184fb2f46f265ab84c1f754c3d5d6597bc /scripts | |
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs,
and experiment results for the User Prior Head personalization method.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/__init__.py | 0 |
| -rw-r--r-- | scripts/run_dev.py | 310 |
| -rw-r--r-- | scripts/run_fair_audit.py | 381 |
| -rw-r--r-- | scripts/shift_analysis.py | 178 |
| -rw-r--r-- | scripts/sweep_alpha.py | 122 |
| -rw-r--r-- | scripts/sweep_d_and_multi.py | 165 |
| -rw-r--r-- | scripts/test_length_fix.py | 203 |
| -rw-r--r-- | scripts/test_normalized_cvh.py | 126 |
| -rw-r--r-- | scripts/test_svd_cvh.py | 141 |
| -rw-r--r-- | scripts/theta_analysis.py | 281 |
10 files changed, 1907 insertions, 0 deletions
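Before the full diff, a brief orientation: every vector-head script below personalizes generation by fitting a small per-user vector theta that gates a frozen low-rank perturbation of the model's final hidden state, then decodes with blended logits. The sketch below is reconstructed from the `MultiBasisCVH` class in `scripts/sweep_d_and_multi.py` and the blending loop in `scripts/test_length_fix.py`; the actual `CVHHead`/`UnconditionalHead` implementations live in `models/cvh.py`, which is outside this diffstat, so treat the single-basis form here as an assumption rather than the committed implementation.

```python
# Assumed single-basis form of the gated low-rank head (reconstructed from
# MultiBasisCVH in scripts/sweep_d_and_multi.py; models/cvh.py is not shown).
import torch


class GatedLowRankHead(torch.nn.Module):
    """h'_t = h_t + alpha * B(theta * A h_t) with frozen random bases A, B."""

    def __init__(self, hidden_size: int, d: int = 64, alpha: float = 0.1,
                 basis_seed: int = 42):
        super().__init__()
        gen = torch.Generator()
        gen.manual_seed(basis_seed)
        # Frozen random projections; only the d-dim theta is fit per user.
        self.register_buffer('A', torch.randn(d, hidden_size, generator=gen)
                             / hidden_size ** 0.5)
        self.register_buffer('B', torch.randn(hidden_size, d, generator=gen)
                             / d ** 0.5)
        self.alpha = alpha

    def forward(self, h: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
        # h: (T, H) hidden states, theta: (d,) per-user vector.
        proj = h @ self.A.T                 # (T, d) down-projection
        gated = theta.unsqueeze(0) * proj   # elementwise user gating
        return h + self.alpha * (gated @ self.B.T)

    def forward_fn(self, h, theta):
        # Alias matching the head.forward_fn callback used by the runners.
        return self.forward(h, theta)


# Blended decoding as in generate_with_blend / generate_with_head_blended:
#   logits = (1 - gamma) * lm_head(h) + gamma * lm_head(head(h, theta))
# with gamma = blend_gamma (0.5 in the fair audit) and greedy argmax.
```

With d=64 stored in bf16, theta is 64 * 2 = 128 bytes per user, which is the "Theta size" the efficiency sections below report.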
diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/scripts/__init__.py diff --git a/scripts/run_dev.py b/scripts/run_dev.py new file mode 100644 index 0000000..d8ccee5 --- /dev/null +++ b/scripts/run_dev.py @@ -0,0 +1,310 @@ +"""Main experiment runner for CVH development. + +Runs all methods (Base, Prompt-All-K, BM25-Top1, Unconditional, CVH) on +LongLaMP validation set and reports metrics. +""" + +import sys +import os +import json +import time +import argparse +import yaml +import torch + +# Add project root to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from baselines.prompt_all_k import generate_prompt_all_k +from baselines.bm25_top1 import generate_bm25_top1 +from eval.metrics import evaluate_all, print_results_table, compute_recovery + + +def run_method_base(wrapper, examples, cfg): + """Run Base (no personalization) method.""" + predictions = [] + for i, ex in enumerate(examples): + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base( + prompt, + max_new_tokens=cfg['max_new_tokens'], + temperature=cfg.get('temperature', 0.7), + top_p=cfg.get('top_p', 0.9), + ) + predictions.append(pred) + if (i + 1) % 20 == 0: + print(f" Base: {i+1}/{len(examples)}") + return predictions + + +def run_method_prompt_all_k(wrapper, examples, support_sets, cfg): + """Run Prompt-All-K baseline.""" + predictions = [] + for i, (ex, support) in enumerate(zip(examples, support_sets)): + pred = generate_prompt_all_k( + wrapper, ex['query_input'], support, ex['task'], + max_new_tokens=cfg['max_new_tokens'], + temperature=cfg.get('temperature', 0.7), + top_p=cfg.get('top_p', 0.9), + ) + predictions.append(pred) + if (i + 1) % 20 == 0: + print(f" Prompt-All-K: {i+1}/{len(examples)}") + return predictions + + +def run_method_bm25_top1(wrapper, examples, support_sets, cfg): + """Run BM25-Top1 baseline.""" + predictions = [] + for i, (ex, support) in enumerate(zip(examples, support_sets)): + pred = generate_bm25_top1( + wrapper, ex['query_input'], support, ex['task'], + max_new_tokens=cfg['max_new_tokens'], + temperature=cfg.get('temperature', 0.7), + top_p=cfg.get('top_p', 0.9), + ) + predictions.append(pred) + if (i + 1) % 20 == 0: + print(f" BM25-Top1: {i+1}/{len(examples)}") + return predictions + + +def run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh'): + """Run CVH or Unconditional head method.""" + device = cfg['device'] + d = cfg['d'] + alpha = cfg['alpha'] + basis_seed = cfg['basis_seed'] + H = wrapper.hidden_size + + # Create head + if head_type == 'cvh': + head = CVHHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device) + else: + head = UnconditionalHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device) + + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + predictions = [] + adapt_times = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + # Step 1: Cache hidden states from support set + t0 = time.time() + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + + if not 
cached_h: + # Fallback to base generation if caching fails + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=cfg['max_new_tokens']) + predictions.append(pred) + adapt_times.append(0.0) + continue + + # Step 2: Fit theta_u + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head, + d=d, + lr=cfg['lr'], + steps=cfg['adapt_steps'], + beta=cfg['beta'], + lam=cfg['lam'], + max_grad_norm=cfg['max_grad_norm'], + device=device, + verbose=False, + ) + adapt_time = time.time() - t0 + adapt_times.append(adapt_time) + + # Step 3: Generate with personalized head (blended) + prompt = build_query_prompt(ex['query_input'], ex['task']) + blend_gamma = cfg.get('blend_gamma', 0.5) + pred = wrapper.generate_with_head_blended( + prompt, theta, head.forward_fn, + blend_gamma=blend_gamma, + max_new_tokens=cfg['max_new_tokens'], + min_new_tokens=cfg.get('min_new_tokens', 128), + temperature=cfg.get('temperature', 0.0), + ) + predictions.append(pred) + + # Cleanup GPU memory + del cached_h, theta + torch.cuda.empty_cache() + + if (i + 1) % 10 == 0: + avg_adapt = sum(adapt_times) / len(adapt_times) + print(f" {head_type.upper()}: {i+1}/{len(examples)} " + f"(avg adapt: {avg_adapt:.2f}s)") + + avg_adapt_time = sum(adapt_times) / max(len(adapt_times), 1) + print(f" Average adaptation time: {avg_adapt_time:.2f}s") + + return predictions, avg_adapt_time + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', type=str, required=True) + parser.add_argument('--num_eval', type=int, default=None, + help='Override number of examples to evaluate') + parser.add_argument('--methods', type=str, default='all', + help='Comma-separated methods: base,prompt_all_k,bm25_top1,uncond,cvh') + parser.add_argument('--output_dir', type=str, default='outputs') + parser.add_argument('--seed', type=int, default=0) + args = parser.parse_args() + + # Load config + with open(args.config) as f: + cfg = yaml.safe_load(f) + + if args.num_eval is not None: + cfg['num_eval'] = args.num_eval + + torch.manual_seed(args.seed) + + print(f"=== CVH Dev Experiment ===") + print(f"Config: {args.config}") + print(f"Task: {cfg['task']}, Setting: {cfg['setting']}") + print(f"Model: {cfg['model_name']}, Device: {cfg['device']}") + print(f"d={cfg['d']}, K={cfg['K']}, alpha={cfg['alpha']}, steps={cfg['adapt_steps']}") + + # Load data + print("\nLoading LongLaMP data...") + examples = load_longlamp(cfg['dataset_config'], split='val') + print(f"Loaded {len(examples)} validation examples") + + # Limit examples if specified + num_eval = cfg.get('num_eval', -1) + if num_eval > 0: + examples = examples[:num_eval] + print(f"Evaluating on {len(examples)} examples") + + # Prepare support sets + K = cfg['K'] + support_sets = [] + for ex in examples: + support = select_k_profile_items(ex['profile_items'], K, seed=args.seed) + support_sets.append(support) + + # Gather references and support texts for metrics + references = [ex['target_output'] for ex in examples] + support_texts_per_example = [ + [s['support_output'] for s in support] + for support in support_sets + ] + + # Parse methods to run + if args.methods == 'all': + methods_to_run = ['base', 'prompt_all_k', 'bm25_top1', 'uncond', 'cvh'] + else: + methods_to_run = args.methods.split(',') + + # Load model + print(f"\nLoading model {cfg['model_name']}...") + wrapper = QwenWrapper(cfg['model_name'], device=cfg['device']) + print(f"Model loaded. 
Hidden size: {wrapper.hidden_size}") + + results = {} + all_predictions = {} + + # Run each method + if 'base' in methods_to_run: + print("\n--- Running Base ---") + preds = run_method_base(wrapper, examples, cfg) + all_predictions['Base'] = preds + results['Base'] = evaluate_all(preds, references, support_texts_per_example) + print(f" ROUGE-L: {results['Base']['rougeL']:.4f}, METEOR: {results['Base']['meteor']:.4f}, SFD: {results['Base']['sfd']:.4f}") + + if 'prompt_all_k' in methods_to_run: + print(f"\n--- Running Prompt-All-K (K={K}) ---") + preds = run_method_prompt_all_k(wrapper, examples, support_sets, cfg) + all_predictions['Prompt-All-K'] = preds + results['Prompt-All-K'] = evaluate_all(preds, references, support_texts_per_example) + print(f" ROUGE-L: {results['Prompt-All-K']['rougeL']:.4f}, METEOR: {results['Prompt-All-K']['meteor']:.4f}, SFD: {results['Prompt-All-K']['sfd']:.4f}") + + if 'bm25_top1' in methods_to_run: + print(f"\n--- Running BM25-Top1 ---") + preds = run_method_bm25_top1(wrapper, examples, support_sets, cfg) + all_predictions['BM25-Top1'] = preds + results['BM25-Top1'] = evaluate_all(preds, references, support_texts_per_example) + print(f" ROUGE-L: {results['BM25-Top1']['rougeL']:.4f}, METEOR: {results['BM25-Top1']['meteor']:.4f}, SFD: {results['BM25-Top1']['sfd']:.4f}") + + if 'uncond' in methods_to_run: + print(f"\n--- Running Unconditional Head ---") + preds, adapt_time = run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='uncond') + all_predictions['Uncond-Head'] = preds + results['Uncond-Head'] = evaluate_all(preds, references, support_texts_per_example) + results['Uncond-Head']['adapt_time'] = adapt_time + print(f" ROUGE-L: {results['Uncond-Head']['rougeL']:.4f}, METEOR: {results['Uncond-Head']['meteor']:.4f}, SFD: {results['Uncond-Head']['sfd']:.4f}") + + if 'cvh' in methods_to_run: + print(f"\n--- Running CVH ---") + preds, adapt_time = run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh') + all_predictions['CVH'] = preds + results['CVH'] = evaluate_all(preds, references, support_texts_per_example) + results['CVH']['adapt_time'] = adapt_time + print(f" ROUGE-L: {results['CVH']['rougeL']:.4f}, METEOR: {results['CVH']['meteor']:.4f}, SFD: {results['CVH']['sfd']:.4f}") + + # Print comparison table + print("\n" + "=" * 70) + print("RESULTS SUMMARY") + print("=" * 70) + print_results_table(results) + + # Compute compression for vector methods + print("\n--- Efficiency ---") + theta_bytes = cfg['d'] * 2 # bf16 = 2 bytes per dim + for ex_support in support_texts_per_example[:5]: + from eval.metrics import compute_compression + comp = compute_compression(ex_support, theta_bytes) + print(f" Compression ratio (example): {comp:.0f}x") + + print(f" Theta size: {theta_bytes} bytes") + print(f" Personalization prompt tokens at inference: 0") + + # Save results + os.makedirs(args.output_dir, exist_ok=True) + exp_name = f"{cfg['task']}_{cfg['setting']}_K{K}_d{cfg['d']}" + output_path = os.path.join(args.output_dir, f"{exp_name}_results.json") + + # Convert results to serializable format + save_data = { + 'config': cfg, + 'results': results, + 'num_examples': len(examples), + } + with open(output_path, 'w') as f: + json.dump(save_data, f, indent=2) + print(f"\nResults saved to {output_path}") + + # Save predictions + pred_path = os.path.join(args.output_dir, f"{exp_name}_predictions.json") + save_preds = {} + for method, preds in all_predictions.items(): + save_preds[method] = [ + { + 'example_id': examples[i]['example_id'], + 
'prediction': preds[i], + 'reference': references[i], + } + for i in range(len(preds)) + ] + with open(pred_path, 'w') as f: + json.dump(save_preds, f, indent=2) + print(f"Predictions saved to {pred_path}") + + +if __name__ == '__main__': + main() diff --git a/scripts/run_fair_audit.py b/scripts/run_fair_audit.py new file mode 100644 index 0000000..73cc744 --- /dev/null +++ b/scripts/run_fair_audit.py @@ -0,0 +1,381 @@ +"""Fair audit experiment: all methods use the SAME decode policy. + +Decode policy: greedy (temperature=0), min_new_tokens=128, max_new_tokens=512. +For vector methods: blended generation with gamma=0.5. +For base/prompt methods: standard generation with min_new_tokens=128. + +Reports: ROUGE-L, METEOR, SFD_all, SFD_-len, avg_len, feature-level deltas. +""" + +import sys +import os +import json +import time +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt +from data.style_features import ( + extract_style_features, compute_sfd, compute_feature_deltas, FEATURE_NAMES +) +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from adapt.fit_theta_weighted import fit_theta_weighted +from baselines.prompt_all_k import generate_prompt_all_k +from baselines.bm25_top1 import generate_bm25_top1 +from eval.metrics import compute_rouge, compute_meteor + + +def generate_base_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128): + """Base generation with min_new_tokens constraint (fair decode policy).""" + chat_messages = [ + {"role": "system", "content": "You are a helpful writing assistant."}, + {"role": "user", "content": input_text}, + ] + prompt_text = wrapper.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) + + with torch.no_grad(): + outputs = wrapper.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + temperature=None, + top_p=None, + do_sample=False, + pad_token_id=wrapper.tokenizer.pad_token_id, + ) + + generated_ids = outputs[0, input_ids.shape[1]:] + return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) + + +def generate_prompt_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128): + """Prompt-based generation with min_new_tokens constraint.""" + chat_messages = [ + {"role": "system", "content": "You are a helpful writing assistant."}, + {"role": "user", "content": input_text}, + ] + prompt_text = wrapper.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) + + with torch.no_grad(): + outputs = wrapper.model.generate( + input_ids, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + temperature=None, + top_p=None, + do_sample=False, + pad_token_id=wrapper.tokenizer.pad_token_id, + ) + + generated_ids = outputs[0, input_ids.shape[1]:] + return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) + + +def compute_detailed_metrics(predictions, references, support_texts_per_example): + """Compute comprehensive metrics including SFD split and feature-level analysis.""" + rouge = 
compute_rouge(predictions, references) + meteor = compute_meteor(predictions, references) + + # SFD all features + sfd_all_list = [] + sfd_nolen_list = [] + feature_deltas_all = {name: [] for name in FEATURE_NAMES} + + for pred, support_texts in zip(predictions, support_texts_per_example): + if not pred.strip(): + pred = "empty" + sfd_all = compute_sfd(pred, support_texts, exclude_length=False) + sfd_nolen = compute_sfd(pred, support_texts, exclude_length=True) + sfd_all_list.append(sfd_all) + sfd_nolen_list.append(sfd_nolen) + + deltas = compute_feature_deltas(pred, support_texts) + for name in FEATURE_NAMES: + if name in deltas: + feature_deltas_all[name].append(deltas[name]['delta']) + + avg_sfd_all = sum(sfd_all_list) / max(len(sfd_all_list), 1) + avg_sfd_nolen = sum(sfd_nolen_list) / max(len(sfd_nolen_list), 1) + + avg_feature_deltas = {} + for name in FEATURE_NAMES: + vals = feature_deltas_all[name] + avg_feature_deltas[name] = sum(vals) / max(len(vals), 1) + + avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1) + + return { + 'rouge1': rouge['rouge1'], + 'rougeL': rouge['rougeL'], + 'meteor': meteor, + 'sfd_all': avg_sfd_all, + 'sfd_nolen': avg_sfd_nolen, + 'avg_len': avg_len, + 'feature_deltas': avg_feature_deltas, + } + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--num_eval', type=int, default=200) + parser.add_argument('--task', type=str, default='review', choices=['review', 'topic']) + parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal']) + parser.add_argument('--output_dir', type=str, default='outputs/fair_audit') + args = parser.parse_args() + + N = args.num_eval + task = args.task + setting = args.setting + + config_map = { + ('review', 'user'): 'product_review_user', + ('review', 'temporal'): 'product_review_temporal', + ('topic', 'user'): 'topic_writing_user', + ('topic', 'temporal'): 'topic_writing_temporal', + } + config_name = config_map[(task, setting)] + + print(f"=== Fair Audit: {task}_{setting}, N={N} ===") + print(f"Decode policy: greedy, min_new_tokens=128, max_new_tokens=512") + print(f"Vector methods: blended gamma=0.5") + + print("\nLoading data...") + examples = load_longlamp(config_name, split='val')[:N] + K = 4 + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + avg_ref_len = sum(len(r.split()) for r in references) / len(references) + print(f"Examples: {len(examples)}, Avg reference len: {avg_ref_len:.0f}") + + print("\nLoading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + H = wrapper.hidden_size + device = 'cuda:1' + + all_results = {} + all_predictions = {} + + # === 1. Base (with min_new_tokens=128) === + print("\n--- Base ---") + preds = [] + for i, ex in enumerate(examples): + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_base_with_min(wrapper, prompt, min_new_tokens=128) + preds.append(pred) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + r = compute_detailed_metrics(preds, references, support_texts) + all_results['Base'] = r + all_predictions['Base'] = preds + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") + + # === 2. 
Prompt-All-K (with min_new_tokens=128) === + print(f"\n--- Prompt-All-K (K=4) ---") + preds = [] + for i, (ex, support) in enumerate(zip(examples, support_sets)): + from data.templates import build_prompt_with_examples + prompt = build_prompt_with_examples(ex['query_input'], support, ex['task']) + pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128) + preds.append(pred) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + r = compute_detailed_metrics(preds, references, support_texts) + all_results['Prompt-All-K'] = r + all_predictions['Prompt-All-K'] = preds + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") + + # === 3. BM25-Top1 (with min_new_tokens=128) === + print(f"\n--- BM25-Top1 ---") + preds = [] + for i, (ex, support) in enumerate(zip(examples, support_sets)): + from baselines.bm25_top1 import bm25_select_top1 + from data.templates import build_prompt_with_examples + selected = bm25_select_top1(ex['query_input'], support) + prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task']) + pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128) + preds.append(pred) + if (i + 1) % 40 == 0: + print(f" {i+1}/{N}") + r = compute_detailed_metrics(preds, references, support_texts) + all_results['BM25-Top1'] = r + all_predictions['BM25-Top1'] = preds + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") + + # Helper for vector head methods + def run_vector_head(name, head_module, d=64, use_weighted=False): + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + preds = [] + adapt_times = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + t0 = time.time() + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_base_with_min(wrapper, prompt) + preds.append(pred) + adapt_times.append(0.0) + continue + + if use_weighted: + theta = fit_theta_weighted( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head_module, + tokenizer=wrapper.tokenizer, + d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4, + max_grad_norm=5.0, device=device, + max_tokens_per_item=128, + verbose=False, + ) + else: + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head_module, + d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4, + max_grad_norm=5.0, device=device, + verbose=False, + ) + + adapt_time = time.time() - t0 + adapt_times.append(adapt_time) + + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_with_head_blended( + prompt, theta, head_module.forward_fn, + blend_gamma=0.5, max_new_tokens=512, + min_new_tokens=128, temperature=0.0, + ) + preds.append(pred) + + del cached_h, theta + torch.cuda.empty_cache() + + if (i + 1) % 40 == 0: + avg_t = sum(adapt_times) / len(adapt_times) + print(f" {i+1}/{N} (avg adapt: {avg_t:.1f}s)") + + avg_adapt = sum(adapt_times) / max(len(adapt_times), 1) + r = compute_detailed_metrics(preds, references, support_texts) + r['adapt_time'] = avg_adapt + return preds, r + + # === 4. 
Uncond-Head === + print(f"\n--- Uncond-Head ---") + uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) + preds, r = run_vector_head('Uncond', uncond, d=64) + all_results['Uncond-Head'] = r + all_predictions['Uncond-Head'] = preds + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") + + # === 5. CVH === + print(f"\n--- CVH ---") + cvh = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device) + preds, r = run_vector_head('CVH', cvh, d=64) + all_results['CVH'] = r + all_predictions['CVH'] = preds + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") + + # === 6. Uncond-Head with style-weighted loss === + print(f"\n--- Uncond-Head (style-weighted) ---") + uncond_sw = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) + preds, r = run_vector_head('Uncond-SW', uncond_sw, d=64, use_weighted=True) + all_results['Uncond-SW'] = r + all_predictions['Uncond-SW'] = preds + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}") + + # === Print comprehensive results === + print("\n" + "=" * 110) + print("COMPREHENSIVE RESULTS (FAIR DECODE POLICY)") + print("=" * 110) + header = f"{'Method':<20} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD_all':<8} {'SFD_-len':<8} {'Len':<6}" + print(header) + print("-" * 110) + for name, r in all_results.items(): + print(f"{name:<20} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} {r['meteor']:<8.4f} " + f"{r['sfd_all']:<8.4f} {r['sfd_nolen']:<8.4f} {r['avg_len']:<6.0f}") + + # Feature-level analysis + print("\n" + "=" * 110) + print("FEATURE-LEVEL DELTAS (gen - proto, closer to 0 = better)") + print("=" * 110) + header = f"{'Method':<20}" + "".join(f"{n:<14}" for n in FEATURE_NAMES) + print(header) + print("-" * 110) + for name, r in all_results.items(): + fd = r['feature_deltas'] + row = f"{name:<20}" + for feat_name in FEATURE_NAMES: + row += f"{fd[feat_name]:<14.3f}" + print(row) + + # Recovery analysis + if 'BM25-Top1' in all_results and 'Base' in all_results: + print("\n--- Recovery (vs BM25-Top1 baseline) ---") + base_rl = all_results['Base']['rougeL'] + bm25_rl = all_results['BM25-Top1']['rougeL'] + base_m = all_results['Base']['meteor'] + bm25_m = all_results['BM25-Top1']['meteor'] + + for name, r in all_results.items(): + if name in ('Base', 'Prompt-All-K', 'BM25-Top1'): + continue + denom_rl = bm25_rl - base_rl + denom_m = bm25_m - base_m + rec_rl = (r['rougeL'] - base_rl) / denom_rl if abs(denom_rl) > 1e-8 else 0 + rec_m = (r['meteor'] - base_m) / denom_m if abs(denom_m) > 1e-8 else 0 + print(f" {name}: R-L Recovery={rec_rl:.3f}, METEOR Recovery={rec_m:.3f}") + + # Efficiency + print("\n--- Efficiency ---") + theta_bytes = 64 * 2 # d=64, bf16 + print(f" Theta size: {theta_bytes} bytes") + print(f" Personalization prompt tokens at inference: 0 (vector methods)") + if support_texts: + from eval.metrics import compute_compression + comp = compute_compression(support_texts[0], theta_bytes) + print(f" Compression ratio (example): {comp:.0f}x") + + # Save results + os.makedirs(args.output_dir, exist_ok=True) + exp_name = f"{task}_{setting}_K4_d64_N{N}_fair" + output_path = os.path.join(args.output_dir, f"{exp_name}_results.json") + + save_data = { + 'results': {k: {kk: vv for kk, vv in v.items()} for k, v in all_results.items()}, + 'num_examples': 
len(examples), + 'decode_policy': 'greedy, min_new_tokens=128, max_new_tokens=512, blend_gamma=0.5', + } + with open(output_path, 'w') as f: + json.dump(save_data, f, indent=2, default=str) + print(f"\nResults saved to {output_path}") + + +if __name__ == '__main__': + main() diff --git a/scripts/shift_analysis.py b/scripts/shift_analysis.py new file mode 100644 index 0000000..99c7fd2 --- /dev/null +++ b/scripts/shift_analysis.py @@ -0,0 +1,178 @@ +"""Support-query distribution shift analysis. + +For each user, compute: + s_u = cos(mean_support_hidden, mean_query_hidden) +Then correlate with CVH-UPH performance gap: + delta_u = ROUGE-L(CVH, u) - ROUGE-L(UPH, u) + +If correlation is positive: CVH benefits when support-query are aligned. +""" + +import sys +import os +import json +import numpy as np +from scipy import stats +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt, build_support_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from eval.metrics import compute_rouge + + +def get_query_hidden_mean(wrapper, query_text, task): + """Get mean hidden state from the query prompt.""" + chat_messages = [ + {"role": "system", "content": "You are a helpful writing assistant."}, + {"role": "user", "content": build_query_prompt(query_text, task)}, + ] + prompt_text = wrapper.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) + + with torch.no_grad(): + outputs = wrapper.model( + input_ids=input_ids, + output_hidden_states=True, + return_dict=True, + ) + last_hidden = outputs.hidden_states[-1][0] # (seq_len, H) + return last_hidden.mean(dim=0).cpu().float().numpy() + + +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('--num_eval', type=int, default=100) + parser.add_argument('--config', type=str, default='product_review_user') + args = parser.parse_args() + + N = args.num_eval + print(f"=== Shift Analysis: {args.config}, N={N} ===") + + print("Loading data...") + examples = load_longlamp(args.config, split='val')[:N] + + print("Loading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + H = wrapper.hidden_size + device = 'cuda:1' + + uph_head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) + cvh_head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device) + + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + K = 4 + shift_cosines = [] + uph_rouges = [] + cvh_rouges = [] + + for i, ex in enumerate(examples): + support = select_k_profile_items(ex['profile_items'], K, seed=0) + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + continue + + # Mean support hidden + all_h = torch.cat([h for h, _ in cached_h], dim=0) + support_mean = all_h.mean(dim=0).numpy() + + # Mean query hidden + query_mean = get_query_hidden_mean(wrapper, ex['query_input'], ex['task']) + + # Cosine similarity + cos = np.dot(support_mean, query_mean) / ( + np.linalg.norm(support_mean) * np.linalg.norm(query_mean) + 1e-8) + 
shift_cosines.append(float(cos)) + + # Fit UPH theta + theta_uph = fit_theta( + cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, head_module=uph_head, + d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=False, + ) + + # Fit CVH theta + theta_cvh = fit_theta( + cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, head_module=cvh_head, + d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=False, + ) + + # Generate with both + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred_uph = wrapper.generate_with_head_blended( + prompt, theta_uph, uph_head.forward_fn, + blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0, + ) + pred_cvh = wrapper.generate_with_head_blended( + prompt, theta_cvh, cvh_head.forward_fn, + blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0, + ) + + # ROUGE-L for each + rouge_uph = compute_rouge([pred_uph], [ex['target_output']])['rougeL'] + rouge_cvh = compute_rouge([pred_cvh], [ex['target_output']])['rougeL'] + uph_rouges.append(rouge_uph) + cvh_rouges.append(rouge_cvh) + + del cached_h, theta_uph, theta_cvh + torch.cuda.empty_cache() + + if (i + 1) % 20 == 0: + print(f" {i+1}/{N}") + + # Compute correlation + shift_cosines = np.array(shift_cosines) + deltas = np.array(cvh_rouges) - np.array(uph_rouges) # positive = CVH better + + rho, pval = stats.spearmanr(shift_cosines, deltas) + + print(f"\n=== Results (N={len(shift_cosines)}) ===") + print(f" Mean shift cosine: {shift_cosines.mean():.4f} +/- {shift_cosines.std():.4f}") + print(f" Mean delta (CVH - UPH): {deltas.mean():.4f} +/- {deltas.std():.4f}") + print(f" Spearman(shift_cos, delta): rho={rho:.4f}, p={pval:.4f}") + print(f" Mean UPH ROUGE-L: {np.mean(uph_rouges):.4f}") + print(f" Mean CVH ROUGE-L: {np.mean(cvh_rouges):.4f}") + + # Bin analysis: high vs low shift + median_cos = np.median(shift_cosines) + high_mask = shift_cosines >= median_cos + low_mask = shift_cosines < median_cos + + print(f"\n High-alignment (cos >= {median_cos:.3f}, n={high_mask.sum()}):") + print(f" UPH R-L: {np.mean(np.array(uph_rouges)[high_mask]):.4f}") + print(f" CVH R-L: {np.mean(np.array(cvh_rouges)[high_mask]):.4f}") + print(f" Low-alignment (cos < {median_cos:.3f}, n={low_mask.sum()}):") + print(f" UPH R-L: {np.mean(np.array(uph_rouges)[low_mask]):.4f}") + print(f" CVH R-L: {np.mean(np.array(cvh_rouges)[low_mask]):.4f}") + + # Save + os.makedirs('outputs/analysis', exist_ok=True) + save_data = { + 'shift_cosines': [float(x) for x in shift_cosines], + 'uph_rouges': [float(x) for x in uph_rouges], + 'cvh_rouges': [float(x) for x in cvh_rouges], + 'deltas': [float(x) for x in deltas], + 'spearman_rho': float(rho), + 'spearman_pval': float(pval), + } + with open('outputs/analysis/shift_analysis.json', 'w') as f: + json.dump(save_data, f, indent=2) + print("\nSaved to outputs/analysis/shift_analysis.json") + + +if __name__ == '__main__': + main() diff --git a/scripts/sweep_alpha.py b/scripts/sweep_alpha.py new file mode 100644 index 0000000..bc35b0c --- /dev/null +++ b/scripts/sweep_alpha.py @@ -0,0 +1,122 @@ +"""Quick sweep over alpha values to find the right perturbation scale.""" + +import sys +import os +import json +import time +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import 
build_query_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from eval.metrics import evaluate_all + + +def run_cvh_with_params(wrapper, examples, support_sets, alpha, beta, steps, d=64, lr=0.05): + """Run CVH with specific hyperparameters.""" + device = 'cuda:1' + H = wrapper.hidden_size + head = CVHHead(H, d=d, alpha=alpha, basis_seed=42).to(device) + + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + predictions = [] + theta_norms = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=256) + predictions.append(pred) + continue + + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head, + d=d, lr=lr, steps=steps, beta=beta, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=False, + ) + theta_norms.append(theta.norm().item()) + + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_with_head( + prompt, theta, head.forward_fn, + max_new_tokens=256, temperature=0.0, + ) + predictions.append(pred) + + del cached_h, theta + torch.cuda.empty_cache() + + avg_norm = sum(theta_norms) / max(len(theta_norms), 1) + return predictions, avg_norm + + +def main(): + print("Loading data...") + examples = load_longlamp('product_review_user', split='val')[:50] + K = 4 + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + print("Loading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + + # Run base + print("\n=== Base ===") + base_preds = [] + for ex in examples: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=256, temperature=0.0) + base_preds.append(pred) + base_results = evaluate_all(base_preds, references, support_texts) + print(f" ROUGE-L: {base_results['rougeL']:.4f}, METEOR: {base_results['meteor']:.4f}, SFD: {base_results['sfd']:.4f}") + + # Sweep + configs = [ + {'alpha': 0.1, 'beta': 0.05, 'steps': 30, 'lr': 0.05}, + {'alpha': 0.3, 'beta': 0.05, 'steps': 30, 'lr': 0.05}, + {'alpha': 0.5, 'beta': 0.05, 'steps': 30, 'lr': 0.05}, + {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.05}, + {'alpha': 0.5, 'beta': 0.01, 'steps': 50, 'lr': 0.05}, + {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.1}, + ] + + all_results = {'Base': base_results} + + for cfg in configs: + name = f"a{cfg['alpha']}_b{cfg['beta']}_s{cfg['steps']}_lr{cfg['lr']}" + print(f"\n=== CVH {name} ===") + t0 = time.time() + preds, avg_norm = run_cvh_with_params( + wrapper, examples, support_sets, + alpha=cfg['alpha'], beta=cfg['beta'], + steps=cfg['steps'], lr=cfg['lr'], + ) + elapsed = time.time() - t0 + results = evaluate_all(preds, references, support_texts) + all_results[name] = results + print(f" ROUGE-L: {results['rougeL']:.4f}, METEOR: {results['meteor']:.4f}, " + f"SFD: {results['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {elapsed:.0f}s") + + # Summary + print("\n" + "=" * 80) + 
print(f"{'Config':<40} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}") + print("-" * 80) + for name, r in all_results.items(): + print(f"{name:<40} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}") + + +if __name__ == '__main__': + main() diff --git a/scripts/sweep_d_and_multi.py b/scripts/sweep_d_and_multi.py new file mode 100644 index 0000000..004f6cf --- /dev/null +++ b/scripts/sweep_d_and_multi.py @@ -0,0 +1,165 @@ +"""Sweep d values and test multi-basis CVH.""" + +import sys +import os +import time +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from eval.metrics import evaluate_all + + +class MultiBasisCVH(torch.nn.Module): + """Two-basis CVH: h'_t = h_t + a1*B1(theta⊙A1*h) + a2*B2(theta⊙A2*h)""" + + def __init__(self, hidden_size, d=64, alpha=0.1, basis_seed=42): + super().__init__() + self.hidden_size = hidden_size + self.d = d + self.alpha = alpha + + gen1 = torch.Generator() + gen1.manual_seed(basis_seed) + gen2 = torch.Generator() + gen2.manual_seed(basis_seed + 500) + + scale_a = 1.0 / (hidden_size ** 0.5) + scale_b = 1.0 / (d ** 0.5) + + self.register_buffer('A1', torch.randn(d, hidden_size, generator=gen1) * scale_a) + self.register_buffer('B1', torch.randn(hidden_size, d, generator=gen1) * scale_b) + self.register_buffer('A2', torch.randn(d, hidden_size, generator=gen2) * scale_a) + self.register_buffer('B2', torch.randn(hidden_size, d, generator=gen2) * scale_b) + + def forward(self, h, theta): + proj1 = (self.A1.float() @ h.T).T + gated1 = theta.unsqueeze(0) * proj1 + res1 = (self.B1.float() @ gated1.T).T + + proj2 = (self.A2.float() @ h.T).T + gated2 = theta.unsqueeze(0) * proj2 + res2 = (self.B2.float() @ gated2.T).T + + return h + self.alpha * (res1 + res2) + + def forward_fn(self, h, theta): + return self.forward(h, theta) + + +def run_head(wrapper, examples, support_sets, head_module, d=64, alpha=0.1, + beta=0.05, steps=30, lr=0.05, max_new_tokens=512): + device = 'cuda:1' + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + predictions = [] + theta_norms = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens) + predictions.append(pred) + continue + + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head_module, + d=d, lr=lr, steps=steps, beta=beta, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=False, + ) + theta_norms.append(theta.norm().item()) + + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_with_head( + prompt, theta, head_module.forward_fn, + max_new_tokens=max_new_tokens, temperature=0.0, + ) + predictions.append(pred) + + del cached_h, theta + torch.cuda.empty_cache() + + if (i + 1) % 20 == 0: + print(f" {i+1}/{len(examples)}") + + avg_norm = sum(theta_norms) / max(len(theta_norms), 1) + return predictions, avg_norm + + +def main(): + 
print("Loading data...") + examples = load_longlamp('product_review_user', split='val')[:50] + K = 4 + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + print("Loading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + H = wrapper.hidden_size + device = 'cuda:1' + + # Base + print("\n=== Base ===") + base_preds = [] + for ex in examples: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0) + base_preds.append(pred) + base_r = evaluate_all(base_preds, references, support_texts) + print(f" ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}") + + results = {'Base': base_r} + configs = [ + ('CVH d=64', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ('CVH d=128', CVHHead(H, d=128, alpha=0.1, basis_seed=42).to(device), 128), + ('CVH d=256', CVHHead(H, d=256, alpha=0.1, basis_seed=42).to(device), 256), + ('Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ('Uncond d=128', UnconditionalHead(H, d=128, alpha=0.1, basis_seed=42).to(device), 128), + ('MultiBasis d=64', MultiBasisCVH(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + # Higher beta to preserve content + ('CVH d=64 b=0.1', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ('CVH d=64 b=0.2', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64), + ] + + betas = { + 'CVH d=64 b=0.1': 0.1, + 'CVH d=64 b=0.2': 0.2, + } + + for name, head, d in configs: + beta = betas.get(name, 0.05) + print(f"\n=== {name} (beta={beta}) ===") + t0 = time.time() + preds, avg_norm = run_head( + wrapper, examples, support_sets, head, d=d, + alpha=0.1, beta=beta, steps=30, lr=0.05, max_new_tokens=512, + ) + elapsed = time.time() - t0 + r = evaluate_all(preds, references, support_texts) + results[name] = r + print(f" ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, " + f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {elapsed:.0f}s") + + # Summary + print("\n" + "=" * 90) + print(f"{'Config':<25} {'ROUGE-1':<10} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}") + print("-" * 90) + for name, r in results.items(): + print(f"{name:<25} {r['rouge1']:<10.4f} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}") + + +if __name__ == '__main__': + main() diff --git a/scripts/test_length_fix.py b/scripts/test_length_fix.py new file mode 100644 index 0000000..ad3f358 --- /dev/null +++ b/scripts/test_length_fix.py @@ -0,0 +1,203 @@ +"""Test different strategies to fix the output length issue.""" + +import sys +import os +import time +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from eval.metrics import evaluate_all + + +def run_with_blend(wrapper, examples, support_sets, head_module, d=64, + beta=0.05, steps=30, lr=0.05, blend_gamma=0.5, + min_new_tokens=64, max_new_tokens=512): + """Run CVH with logit blending: logits = (1-gamma)*base + gamma*cvh""" + device = 'cuda:1' + lm_head_bias = None + if 
hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + predictions = [] + theta_norms = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens) + predictions.append(pred) + continue + + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head_module, + d=d, lr=lr, steps=steps, beta=beta, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=False, + ) + theta_norms.append(theta.norm().item()) + + # Generate with blending + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = generate_with_blend( + wrapper, prompt, theta, head_module, + gamma=blend_gamma, max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + ) + predictions.append(pred) + + del cached_h, theta + torch.cuda.empty_cache() + + if (i + 1) % 20 == 0: + print(f" {i+1}/{len(examples)}") + + avg_norm = sum(theta_norms) / max(len(theta_norms), 1) + avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1) + return predictions, avg_norm, avg_len + + +def generate_with_blend(wrapper, input_text, theta, head_module, + gamma=0.5, max_new_tokens=512, min_new_tokens=64): + """Generate with blended base + CVH logits.""" + chat_messages = [ + {"role": "system", "content": "You are a helpful writing assistant."}, + {"role": "user", "content": input_text}, + ] + prompt_text = wrapper.tokenizer.apply_chat_template( + chat_messages, tokenize=False, add_generation_prompt=True + ) + input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device) + + generated_ids = [] + past_key_values = None + + for step in range(max_new_tokens): + if step == 0: + cur_input = input_ids + else: + cur_input = torch.tensor([[generated_ids[-1]]], device=wrapper.device) + + with torch.no_grad(): + outputs = wrapper.model( + input_ids=cur_input, + past_key_values=past_key_values, + output_hidden_states=True, + use_cache=True, + return_dict=True, + ) + + past_key_values = outputs.past_key_values + last_hidden = outputs.hidden_states[-1][:, -1, :] # (1, H) + + # Base logits + base_logits = torch.nn.functional.linear( + last_hidden.to(wrapper.lm_head_weight.dtype), + wrapper.lm_head_weight, + wrapper.model.lm_head.bias if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None else None, + ).float() + + # CVH logits + h_prime = head_module.forward_fn(last_hidden.float(), theta) + cvh_logits = torch.nn.functional.linear( + h_prime.to(wrapper.lm_head_weight.dtype), + wrapper.lm_head_weight, + wrapper.model.lm_head.bias if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None else None, + ).float() + + # Blend logits + logits = (1 - gamma) * base_logits + gamma * cvh_logits + + # Suppress EOS before min_new_tokens + if step < min_new_tokens and wrapper.tokenizer.eos_token_id is not None: + logits[0, wrapper.tokenizer.eos_token_id] = float('-inf') + + next_token = logits.argmax(dim=-1).item() + + if next_token == wrapper.tokenizer.eos_token_id: + break + + generated_ids.append(next_token) + + return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True) + + +def main(): + N = 50 + print(f"Loading data ({N} examples)...") + examples = 
load_longlamp('product_review_user', split='val')[:N] + K = 4 + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + avg_ref_len = sum(len(r.split()) for r in references) / len(references) + print(f"Avg reference length: {avg_ref_len:.0f} words") + + print("Loading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + H = wrapper.hidden_size + device = 'cuda:1' + + # Base baseline (greedy, max_new_tokens=512; no min_new_tokens constraint) + print("\n=== Base ===") + base_preds = [] + for ex in examples: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0) + base_preds.append(pred) + base_r = evaluate_all(base_preds, references, support_texts) + base_len = sum(len(p.split()) for p in base_preds) / len(base_preds) + print(f" R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}") + + results = {'Base': {**base_r, 'avg_len': base_len}} + + head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device) + uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device) + + configs = [ + # Standard CVH with higher min_new_tokens + ('CVH min=200', head, 1.0, 200), + # Blended CVH with different gammas + ('CVH blend=0.3 min=64', head, 0.3, 64), + ('CVH blend=0.5 min=64', head, 0.5, 64), + ('CVH blend=0.7 min=64', head, 0.7, 64), + ('CVH blend=0.5 min=128', head, 0.5, 128), + # Uncond blend + ('Uncond blend=0.5 min=64', uncond, 0.5, 64), + ] + + for name, head_mod, gamma, min_tok in configs: + print(f"\n=== {name} ===") + t0 = time.time() + preds, avg_norm, avg_len = run_with_blend( + wrapper, examples, support_sets, head_mod, d=64, + beta=0.05, steps=30, lr=0.05, blend_gamma=gamma, + min_new_tokens=min_tok, max_new_tokens=512, + ) + elapsed = time.time() - t0 + r = evaluate_all(preds, references, support_texts) + results[name] = {**r, 'avg_len': avg_len} + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, " + f"|theta|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s") + + # Summary + print("\n" + "=" * 100) + print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}") + print("-" * 100) + for name, r in results.items(): + print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} " + f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}") + + +if __name__ == '__main__': + main() diff --git a/scripts/test_normalized_cvh.py b/scripts/test_normalized_cvh.py new file mode 100644 index 0000000..0057341 --- /dev/null +++ b/scripts/test_normalized_cvh.py @@ -0,0 +1,126 @@ +"""Quick test: normalized CVH vs Uncond with blending.""" + +import sys +import os +import time +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from eval.metrics import evaluate_all + + +def run_blended(wrapper, examples, support_sets, head_module, d=64, + beta=0.05, steps=30, lr=0.05, blend_gamma=0.5): + device = 'cuda:1' + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and
wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + predictions = [] + theta_norms = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=512) + predictions.append(pred) + continue + + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head_module, + d=d, lr=lr, steps=steps, beta=beta, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=(i == 0), + ) + theta_norms.append(theta.norm().item()) + + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_with_head_blended( + prompt, theta, head_module.forward_fn, + blend_gamma=blend_gamma, + max_new_tokens=512, min_new_tokens=128, + temperature=0.0, + ) + predictions.append(pred) + + del cached_h, theta + torch.cuda.empty_cache() + + if (i + 1) % 20 == 0: + print(f" {i+1}/{len(examples)}") + + avg_norm = sum(theta_norms) / max(len(theta_norms), 1) + avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1) + return predictions, avg_norm, avg_len + + +def main(): + N = 100 + print(f"Loading data ({N} examples)...") + examples = load_longlamp('product_review_user', split='val')[:N] + K = 4 + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + print("Loading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + H = wrapper.hidden_size + device = 'cuda:1' + + # Base + print("\n=== Base ===") + base_preds = [] + for i, ex in enumerate(examples): + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0) + base_preds.append(pred) + if (i + 1) % 20 == 0: + print(f" {i+1}/{N}") + base_r = evaluate_all(base_preds, references, support_texts) + base_len = sum(len(p.split()) for p in base_preds) / len(base_preds) + print(f" R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}") + + results = {'Base': base_r} + + configs = [ + ('Uncond d=64 g=0.5', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5), + ('CVH-norm d=64 g=0.5', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5), + ('CVH-norm d=64 g=0.3', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.3), + ('CVH-norm d=64 g=0.7', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.7), + ] + + for name, head, d, gamma in configs: + print(f"\n=== {name} ===") + t0 = time.time() + preds, avg_norm, avg_len = run_blended( + wrapper, examples, support_sets, head, d=d, + beta=0.05, steps=30, lr=0.05, blend_gamma=gamma, + ) + elapsed = time.time() - t0 + r = evaluate_all(preds, references, support_texts) + results[name] = r + print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, " + f"|θ|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s") + + # Summary + print("\n" + "=" * 90) + print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8}") + print("-" * 90) + for name, r in results.items(): + print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} " + f"{r['meteor']:<8.4f} {r['sfd']:<8.4f}") + + +if 
__name__ == '__main__': + main() diff --git a/scripts/test_svd_cvh.py b/scripts/test_svd_cvh.py new file mode 100644 index 0000000..1f93bd1 --- /dev/null +++ b/scripts/test_svd_cvh.py @@ -0,0 +1,141 @@ +"""Test SVD-based CVH vs random basis CVH vs Unconditional.""" + +import sys +import os +import time +import torch + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from data.longlamp import load_longlamp, select_k_profile_items +from data.templates import build_query_prompt +from models.qwen_wrapper import QwenWrapper +from models.cvh import CVHHead, UnconditionalHead +from models.svd_cvh import SVDCVHHead, SVDUncondHead +from adapt.cache_hidden import cache_support_hidden_states +from adapt.fit_theta import fit_theta +from eval.metrics import evaluate_all + + +def run_head(wrapper, examples, support_sets, head_module, d=64, + beta=0.05, steps=30, lr=0.05, max_new_tokens=512, + min_new_tokens=64): + device = 'cuda:1' + lm_head_bias = None + if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None: + lm_head_bias = wrapper.model.lm_head.bias.data + + predictions = [] + theta_norms = [] + + for i, (ex, support) in enumerate(zip(examples, support_sets)): + cached_h = cache_support_hidden_states(wrapper, support, ex['task']) + if not cached_h: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens) + predictions.append(pred) + continue + + theta = fit_theta( + cached_h=cached_h, + lm_head_weight=wrapper.lm_head_weight, + lm_head_bias=lm_head_bias, + head_module=head_module, + d=d, lr=lr, steps=steps, beta=beta, lam=1e-4, + max_grad_norm=5.0, device=device, verbose=(i == 0), + ) + theta_norms.append(theta.norm().item()) + + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_with_head( + prompt, theta, head_module.forward_fn, + max_new_tokens=max_new_tokens, temperature=0.0, + min_new_tokens=min_new_tokens, + ) + predictions.append(pred) + + del cached_h, theta + torch.cuda.empty_cache() + + if (i + 1) % 20 == 0: + print(f" {i+1}/{len(examples)}") + + avg_norm = sum(theta_norms) / max(len(theta_norms), 1) + # Check avg output length + avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1) + return predictions, avg_norm, avg_len + + +def main(): + N = 50 + print(f"Loading data ({N} examples)...") + examples = load_longlamp('product_review_user', split='val')[:N] + K = 4 + support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples] + references = [ex['target_output'] for ex in examples] + support_texts = [[s['support_output'] for s in ss] for ss in support_sets] + + avg_ref_len = sum(len(r.split()) for r in references) / len(references) + print(f"Avg reference length: {avg_ref_len:.0f} words") + + print("Loading model...") + wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1') + H = wrapper.hidden_size + device = 'cuda:1' + + # Base + print("\n=== Base ===") + base_preds = [] + for ex in examples: + prompt = build_query_prompt(ex['query_input'], ex['task']) + pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0) + base_preds.append(pred) + base_r = evaluate_all(base_preds, references, support_texts) + base_len = sum(len(p.split()) for p in base_preds) / len(base_preds) + print(f" ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, " + f"SFD: {base_r['sfd']:.4f}, avg_len: {base_len:.0f}") + + results = {} + results['Base'] = 
+
+
+def main():
+    N = 50
+    print(f"Loading data ({N} examples)...")
+    examples = load_longlamp('product_review_user', split='val')[:N]
+    K = 4
+    support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+    references = [ex['target_output'] for ex in examples]
+    support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+    avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+    print(f"Avg reference length: {avg_ref_len:.0f} words")
+
+    print("Loading model...")
+    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+    H = wrapper.hidden_size
+    device = 'cuda:1'
+
+    # Base
+    print("\n=== Base ===")
+    base_preds = []
+    for ex in examples:
+        prompt = build_query_prompt(ex['query_input'], ex['task'])
+        pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
+        base_preds.append(pred)
+    base_r = evaluate_all(base_preds, references, support_texts)
+    base_len = sum(len(p.split()) for p in base_preds) / len(base_preds)
+    print(f" ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, "
+          f"SFD: {base_r['sfd']:.4f}, avg_len: {base_len:.0f}")
+
+    results = {}
+    results['Base'] = {**base_r, 'avg_len': base_len}
+
+    # SVD-based heads
+    print("\nComputing SVD of lm_head...")
+    svd_cvh = SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.1).to(device)
+    svd_uncond = SVDUncondHead(wrapper.lm_head_weight, d=64, alpha=0.1).to(device)
+
+    configs = [
+        ('Random CVH d=64', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+        ('Random Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+        ('SVD CVH d=64', svd_cvh, 64),
+        ('SVD Uncond d=64', svd_uncond, 64),
+        # Try different alpha with SVD
+        ('SVD CVH d=64 a=0.05', SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.05).to(device), 64),
+        ('SVD CVH d=64 a=0.2', SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.2).to(device), 64),
+    ]
+
+    for name, head, d in configs:
+        print(f"\n=== {name} ===")
+        t0 = time.time()
+        preds, avg_norm, avg_len = run_head(
+            wrapper, examples, support_sets, head, d=d,
+            beta=0.05, steps=30, lr=0.05, max_new_tokens=512,
+            min_new_tokens=64,
+        )
+        elapsed = time.time() - t0
+        r = evaluate_all(preds, references, support_texts)
+        results[name] = {**r, 'avg_len': avg_len}
+        print(f" ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+              f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, "
+              f"avg_len: {avg_len:.0f}, time: {elapsed:.0f}s")
+
+    # Summary
+    print("\n" + "=" * 100)
+    print(f"{'Config':<25} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}")
+    print("-" * 100)
+    for name, r in results.items():
+        print(f"{name:<25} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
+              f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}")
+
+
+if __name__ == '__main__':
+    main()
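The next file, scripts/theta_analysis.py, tests whether the fitted theta vectors carry user style at all. Its central statistic is RSA: the Spearman correlation between pairwise cosine similarities of the theta vectors and of hand-crafted style prototypes. A self-contained toy run of that statistic on synthetic data (real numpy/scipy calls, mirroring the compute_rsa function below):

    import numpy as np
    from scipy import stats

    rng = np.random.default_rng(0)
    thetas = rng.normal(size=(10, 64))   # stand-ins for fitted theta_u
    styles = rng.normal(size=(10, 12))   # stand-ins for style prototypes

    def cos_matrix(X):
        Xn = X / np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-8)
        return Xn @ Xn.T

    iu = np.triu_indices(len(thetas), k=1)  # unordered user pairs only
    rho, pval = stats.spearmanr(cos_matrix(thetas)[iu], cos_matrix(styles)[iu])
    print(f"RSA rho={rho:.3f} (p={pval:.2e})")  # near 0 for unrelated data

Independent random vectors give rho near zero; a clearly positive rho on real thetas is the signal the analysis is looking for.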
diff --git a/scripts/theta_analysis.py b/scripts/theta_analysis.py
new file mode 100644
index 0000000..94d4010
--- /dev/null
+++ b/scripts/theta_analysis.py
@@ -0,0 +1,281 @@
+"""User-state geometry / representational alignment analysis.
+
+Computes:
+1. RSA: Spearman(cos(theta_u, theta_v), cos(phi_u, phi_v)), both over all
+   style features and with length/newline features excluded
+2. Self-consistency: Delta_self = E_u[cos(theta_a^u, theta_b^u)] - E_{u!=v}[cos(theta_a^u, theta_b^v)]
+3. Ridge probe: R^2 for predicting style features from theta
+4. PCA visualization (thetas are saved to JSON for offline plotting)
+"""
+
+import sys
+import os
+import json
+import numpy as np
+from scipy import stats
+from sklearn.linear_model import Ridge
+from sklearn.model_selection import cross_val_score
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.style_features import extract_style_features, FEATURE_NAMES
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+
+
+def collect_thetas_and_styles(wrapper, examples, K=4, seed=0):
+    """Collect theta_u and style prototypes for all users."""
+    device = 'cuda:1'
+    H = wrapper.hidden_size
+    head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+    lm_head_bias = None
+    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+        lm_head_bias = wrapper.model.lm_head.bias.data
+
+    thetas = []
+    style_protos = []
+    user_ids = []
+
+    for i, ex in enumerate(examples):
+        support = select_k_profile_items(ex['profile_items'], K, seed=seed)
+        cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+        if not cached_h:
+            continue
+
+        theta = fit_theta(
+            cached_h=cached_h,
+            lm_head_weight=wrapper.lm_head_weight,
+            lm_head_bias=lm_head_bias,
+            head_module=head,
+            d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+            max_grad_norm=5.0, device=device, verbose=False,
+        )
+
+        thetas.append(theta.cpu().numpy())
+
+        # Compute style prototype
+        support_texts = [s['support_output'] for s in support]
+        features_list = [extract_style_features(t) for t in support_texts]
+        proto = np.mean(features_list, axis=0)
+        style_protos.append(proto)
+        user_ids.append(ex['user_id'])
+
+        del cached_h, theta
+        torch.cuda.empty_cache()
+
+        if (i + 1) % 40 == 0:
+            print(f" Collected {i+1}/{len(examples)}")
+
+    return np.array(thetas), np.array(style_protos), user_ids
+
+
+def compute_rsa(thetas, style_protos, exclude_indices=None):
+    """Compute RSA: Spearman correlation between theta similarity and style similarity."""
+    N = len(thetas)
+
+    # Theta cosine similarity matrix
+    theta_norms = np.linalg.norm(thetas, axis=1, keepdims=True)
+    theta_norms = np.maximum(theta_norms, 1e-8)
+    theta_normed = thetas / theta_norms
+    theta_sim = theta_normed @ theta_normed.T
+
+    # Style cosine similarity matrix
+    if exclude_indices is not None:
+        style = np.delete(style_protos, exclude_indices, axis=1)
+    else:
+        style = style_protos.copy()
+
+    style_norms = np.linalg.norm(style, axis=1, keepdims=True)
+    style_norms = np.maximum(style_norms, 1e-8)
+    style_normed = style / style_norms
+    style_sim = style_normed @ style_normed.T
+
+    # Extract upper triangle
+    idx = np.triu_indices(N, k=1)
+    theta_upper = theta_sim[idx]
+    style_upper = style_sim[idx]
+
+    rho, pval = stats.spearmanr(theta_upper, style_upper)
+    return rho, pval
+
+
+def compute_self_consistency(wrapper, examples, K=4):
+    """Compute self-consistency by fitting theta with different support subsets."""
+    device = 'cuda:1'
+    H = wrapper.hidden_size
+    head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+    lm_head_bias = None
+    if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+        lm_head_bias = wrapper.model.lm_head.bias.data
+
+    thetas_a = []
+    thetas_b = []
+    valid_indices = []
+
+    for i, ex in enumerate(examples):
+        profile = ex['profile_items']
+        if len(profile) < 2 * K:
+            continue
+
+        # Two different subsets
+        support_a = select_k_profile_items(profile, K, seed=100)
+        support_b = select_k_profile_items(profile, K, seed=200)
+
+        cached_a = cache_support_hidden_states(wrapper, support_a, ex['task'])
+        cached_b = cache_support_hidden_states(wrapper, support_b, ex['task'])
+
+        if not cached_a or not cached_b:
+            continue
+
+        theta_a = fit_theta(
+            cached_h=cached_a, lm_head_weight=wrapper.lm_head_weight,
+            lm_head_bias=lm_head_bias, head_module=head,
+            d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+            max_grad_norm=5.0, device=device, verbose=False,
+        )
+        theta_b = fit_theta(
+            cached_h=cached_b, lm_head_weight=wrapper.lm_head_weight,
+            lm_head_bias=lm_head_bias, head_module=head,
+            d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+            max_grad_norm=5.0, device=device, verbose=False,
+        )
+
+        thetas_a.append(theta_a.cpu().numpy())
+        thetas_b.append(theta_b.cpu().numpy())
+        valid_indices.append(i)
+
+        del cached_a, cached_b, theta_a, theta_b
+        torch.cuda.empty_cache()
+
+        if (i + 1) % 40 == 0:
+            print(f" Self-consistency: {i+1}/{len(examples)} ({len(valid_indices)} valid)")
+
+    thetas_a = np.array(thetas_a)
+    thetas_b = np.array(thetas_b)
+    N = len(thetas_a)
+
+    if N < 5:
+        return 0.0, 0.0, 0.0
+
+    # Self similarity: cos(theta_a_u, theta_b_u)
+    self_cos = []
+    for u in range(N):
+        cos = np.dot(thetas_a[u], thetas_b[u]) / (
+            np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[u]) + 1e-8)
+        self_cos.append(cos)
+    avg_self = np.mean(self_cos)
+
+    # Cross similarity: cos(theta_a_u, theta_b_v) for u != v
+    cross_cos = []
+    for u in range(N):
+        for v in range(N):
+            if u == v:
+                continue
+            cos = np.dot(thetas_a[u], thetas_b[v]) / (
+                np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[v]) + 1e-8)
+            cross_cos.append(cos)
+    avg_cross = np.mean(cross_cos)
+
+    delta_self = avg_self - avg_cross
+    return avg_self, avg_cross, delta_self
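The self/cross loops above spell the definition out directly but cost O(N^2) Python-level dot products. The same three numbers fall out of a single normalized matrix product; a vectorized equivalent (numpy only, same Delta_self definition, with the epsilon moved from the denominator sum to a norm clamp):

    import numpy as np

    def delta_self(thetas_a: np.ndarray, thetas_b: np.ndarray):
        # Assumes n >= 2; the caller above already guards N < 5.
        An = thetas_a / np.maximum(np.linalg.norm(thetas_a, axis=1, keepdims=True), 1e-8)
        Bn = thetas_b / np.maximum(np.linalg.norm(thetas_b, axis=1, keepdims=True), 1e-8)
        C = An @ Bn.T                    # C[u, v] = cos(theta_a_u, theta_b_v)
        n = len(C)
        avg_self = np.trace(C) / n       # diagonal: same user, two subsets
        avg_cross = (C.sum() - np.trace(C)) / (n * (n - 1))
        return avg_self, avg_cross, avg_self - avg_cross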
+
+
+def compute_ridge_probe(thetas, style_protos):
+    """Probe: predict each style feature from theta using Ridge regression."""
+    results = {}
+    N = len(thetas)
+
+    for i, feat_name in enumerate(FEATURE_NAMES):
+        y = style_protos[:, i]
+
+        # Check if target has variance
+        if np.std(y) < 1e-8:
+            results[feat_name] = 0.0
+            continue
+
+        ridge = Ridge(alpha=1.0)
+        scores = cross_val_score(ridge, thetas, y, cv=min(5, N), scoring='r2')
+        results[feat_name] = max(np.mean(scores), 0.0)  # Clip negative R2 to 0
+
+    return results
+
+
+def main():
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_eval', type=int, default=200)
+    parser.add_argument('--config', type=str, default='product_review_user')
+    args = parser.parse_args()
+
+    N = args.num_eval
+    print(f"=== Theta Analysis: {args.config}, N={N} ===")
+
+    print("\nLoading data...")
+    examples = load_longlamp(args.config, split='val')[:N]
+    print(f"Loaded {len(examples)} examples")
+
+    print("\nLoading model...")
+    wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+
+    # 1. Collect thetas and style prototypes
+    print("\n--- Collecting thetas and style prototypes ---")
+    thetas, style_protos, user_ids = collect_thetas_and_styles(wrapper, examples, K=4, seed=0)
+    print(f"Collected {len(thetas)} vectors")
+
+    # 2. RSA
+    print("\n--- RSA (Representational Similarity Analysis) ---")
+    rho_all, pval_all = compute_rsa(thetas, style_protos)
+    # Exclude length (index 0) and newline_rate (index 3)
+    rho_nolen, pval_nolen = compute_rsa(thetas, style_protos, exclude_indices=[0, 3])
+    print(f" rho_all: {rho_all:.4f} (p={pval_all:.2e})")
+    print(f" rho_-len/newline: {rho_nolen:.4f} (p={pval_nolen:.2e})")
+
+    # 3. Self-consistency
+    print("\n--- Self-Consistency ---")
+    avg_self, avg_cross, delta_self = compute_self_consistency(wrapper, examples, K=4)
+    print(f" avg_self_cos: {avg_self:.4f}")
+    print(f" avg_cross_cos: {avg_cross:.4f}")
+    print(f" Delta_self: {delta_self:.4f}")
+
+    # 4. Ridge probe
+    print("\n--- Ridge Probe (R^2) ---")
+    probe_results = compute_ridge_probe(thetas, style_protos)
+    for feat_name in FEATURE_NAMES:
+        r2 = probe_results[feat_name]
+        print(f" {feat_name:<20}: R^2 = {r2:.4f}")
+
+    # Summary: the 6 key numbers
+    print("\n" + "=" * 60)
+    print("KEY NUMBERS FOR PAPER DECISION")
+    print("=" * 60)
+    print(f" rho_all: {rho_all:.4f}")
+    print(f" rho_-len/newline: {rho_nolen:.4f}")
+    print(f" Delta_self: {delta_self:.4f}")
+    print(f" R^2_TTR: {probe_results.get('TTR', 0.0):.4f}")
+    print(f" R^2_first_person: {probe_results.get('first_person_rate', 0.0):.4f}")
+    print(f" R^2_newline: {probe_results.get('newline_rate', 0.0):.4f}")
+
+    # Save results
+    os.makedirs('outputs/analysis', exist_ok=True)
+    save_data = {
+        'rsa_all': {'rho': float(rho_all), 'pval': float(pval_all)},
+        'rsa_nolen': {'rho': float(rho_nolen), 'pval': float(pval_nolen)},
+        'self_consistency': {'avg_self': float(avg_self), 'avg_cross': float(avg_cross), 'delta_self': float(delta_self)},
+        'probe_r2': {k: float(v) for k, v in probe_results.items()},
+        'num_users': len(thetas),
+        'thetas': [[float(x) for x in row] for row in thetas],
+        'style_protos': [[float(x) for x in row] for row in style_protos],
+        'user_ids': user_ids,
+    }
+    with open('outputs/analysis/theta_analysis.json', 'w') as f:
+        json.dump(save_data, f, indent=2)
+    print("\nSaved to outputs/analysis/theta_analysis.json")
+
+
+if __name__ == '__main__':
+    main()
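Item 4 of the module docstring promises a PCA visualization, but main() stops after the ridge probe; the raw thetas are kept in the saved JSON precisely so the plot can be produced offline. A minimal sketch that reads outputs/analysis/theta_analysis.json, assuming matplotlib and scikit-learn are available; theta_pca.png is an illustrative output path, not one the repo defines:

    import json
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    with open('outputs/analysis/theta_analysis.json') as f:
        data = json.load(f)

    thetas = np.array(data['thetas'])    # (num_users, d)
    xy = PCA(n_components=2).fit_transform(thetas)

    plt.figure(figsize=(5, 5))
    plt.scatter(xy[:, 0], xy[:, 1], s=12)
    plt.title('PCA of fitted user vectors theta_u')
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.savefig('outputs/analysis/theta_pca.png', dpi=150)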
