author     YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
committer  YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
commit     8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
tree       c92a92184fb2f46f265ab84c1f754c3d5d6597bc /scripts
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs, and
experiment results for the User Prior Head personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts')
-rw-r--r--  scripts/__init__.py              0
-rw-r--r--  scripts/run_dev.py             310
-rw-r--r--  scripts/run_fair_audit.py      381
-rw-r--r--  scripts/shift_analysis.py      178
-rw-r--r--  scripts/sweep_alpha.py         122
-rw-r--r--  scripts/sweep_d_and_multi.py   165
-rw-r--r--  scripts/test_length_fix.py     203
-rw-r--r--  scripts/test_normalized_cvh.py 126
-rw-r--r--  scripts/test_svd_cvh.py        141
-rw-r--r--  scripts/theta_analysis.py      281
10 files changed, 1907 insertions, 0 deletions
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/scripts/__init__.py
diff --git a/scripts/run_dev.py b/scripts/run_dev.py
new file mode 100644
index 0000000..d8ccee5
--- /dev/null
+++ b/scripts/run_dev.py
@@ -0,0 +1,310 @@
+"""Main experiment runner for CVH development.
+
+Runs all methods (Base, Prompt-All-K, BM25-Top1, Unconditional, CVH) on
+LongLaMP validation set and reports metrics.
+"""
+
+import sys
+import os
+import json
+import time
+import argparse
+import yaml
+import torch
+
+# Add project root to path
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from baselines.prompt_all_k import generate_prompt_all_k
+from baselines.bm25_top1 import generate_bm25_top1
+from eval.metrics import evaluate_all, print_results_table, compute_recovery
+
+
+def run_method_base(wrapper, examples, cfg):
+ """Run Base (no personalization) method."""
+ predictions = []
+ for i, ex in enumerate(examples):
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(
+ prompt,
+ max_new_tokens=cfg['max_new_tokens'],
+ temperature=cfg.get('temperature', 0.7),
+ top_p=cfg.get('top_p', 0.9),
+ )
+ predictions.append(pred)
+ if (i + 1) % 20 == 0:
+ print(f" Base: {i+1}/{len(examples)}")
+ return predictions
+
+
+def run_method_prompt_all_k(wrapper, examples, support_sets, cfg):
+ """Run Prompt-All-K baseline."""
+ predictions = []
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ pred = generate_prompt_all_k(
+ wrapper, ex['query_input'], support, ex['task'],
+ max_new_tokens=cfg['max_new_tokens'],
+ temperature=cfg.get('temperature', 0.7),
+ top_p=cfg.get('top_p', 0.9),
+ )
+ predictions.append(pred)
+ if (i + 1) % 20 == 0:
+ print(f" Prompt-All-K: {i+1}/{len(examples)}")
+ return predictions
+
+
+def run_method_bm25_top1(wrapper, examples, support_sets, cfg):
+ """Run BM25-Top1 baseline."""
+ predictions = []
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ pred = generate_bm25_top1(
+ wrapper, ex['query_input'], support, ex['task'],
+ max_new_tokens=cfg['max_new_tokens'],
+ temperature=cfg.get('temperature', 0.7),
+ top_p=cfg.get('top_p', 0.9),
+ )
+ predictions.append(pred)
+ if (i + 1) % 20 == 0:
+ print(f" BM25-Top1: {i+1}/{len(examples)}")
+ return predictions
+
+
+def run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh'):
+ """Run CVH or Unconditional head method."""
+ device = cfg['device']
+ d = cfg['d']
+ alpha = cfg['alpha']
+ basis_seed = cfg['basis_seed']
+ H = wrapper.hidden_size
+
+ # Create head
+ if head_type == 'cvh':
+ head = CVHHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device)
+ else:
+ head = UnconditionalHead(H, d=d, alpha=alpha, basis_seed=basis_seed).to(device)
+
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ adapt_times = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ # Step 1: Cache hidden states from support set
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+
+ if not cached_h:
+ # Fallback to base generation if caching fails
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=cfg['max_new_tokens'])
+ predictions.append(pred)
+ adapt_times.append(0.0)
+ continue
+
+ # Step 2: Fit theta_u
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head,
+ d=d,
+ lr=cfg['lr'],
+ steps=cfg['adapt_steps'],
+ beta=cfg['beta'],
+ lam=cfg['lam'],
+ max_grad_norm=cfg['max_grad_norm'],
+ device=device,
+ verbose=False,
+ )
+ adapt_time = time.time() - t0
+ adapt_times.append(adapt_time)
+
+ # Step 3: Generate with personalized head (blended)
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ blend_gamma = cfg.get('blend_gamma', 0.5)
+ pred = wrapper.generate_with_head_blended(
+ prompt, theta, head.forward_fn,
+ blend_gamma=blend_gamma,
+ max_new_tokens=cfg['max_new_tokens'],
+ min_new_tokens=cfg.get('min_new_tokens', 128),
+ temperature=cfg.get('temperature', 0.0),
+ )
+ predictions.append(pred)
+
+ # Cleanup GPU memory
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 10 == 0:
+ avg_adapt = sum(adapt_times) / len(adapt_times)
+ print(f" {head_type.upper()}: {i+1}/{len(examples)} "
+ f"(avg adapt: {avg_adapt:.2f}s)")
+
+ avg_adapt_time = sum(adapt_times) / max(len(adapt_times), 1)
+ print(f" Average adaptation time: {avg_adapt_time:.2f}s")
+
+ return predictions, avg_adapt_time
+
+
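+# models/qwen_wrapper.py is not part of this diff; as a mental model (an
+# illustrative sketch, not the wrapper's actual code), the blended decode
+# step inside generate_with_head_blended matches generate_with_blend in
+# scripts/test_length_fix.py. Per generated token:
+#
+#   base_logits = lm_head(h)                       # unmodified hidden state
+#   cvh_logits  = lm_head(head.forward_fn(h, theta))
+#   logits      = (1 - blend_gamma) * base_logits + blend_gamma * cvh_logits
+#
+# so blend_gamma=0.0 reduces to the base model and 1.0 to the pure head.
+
+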
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--config', type=str, required=True)
+ parser.add_argument('--num_eval', type=int, default=None,
+ help='Override number of examples to evaluate')
+ parser.add_argument('--methods', type=str, default='all',
+ help='Comma-separated methods: base,prompt_all_k,bm25_top1,uncond,cvh')
+ parser.add_argument('--output_dir', type=str, default='outputs')
+ parser.add_argument('--seed', type=int, default=0)
+ args = parser.parse_args()
+
+ # Load config
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ if args.num_eval is not None:
+ cfg['num_eval'] = args.num_eval
+
+ torch.manual_seed(args.seed)
+
+ print(f"=== CVH Dev Experiment ===")
+ print(f"Config: {args.config}")
+ print(f"Task: {cfg['task']}, Setting: {cfg['setting']}")
+ print(f"Model: {cfg['model_name']}, Device: {cfg['device']}")
+ print(f"d={cfg['d']}, K={cfg['K']}, alpha={cfg['alpha']}, steps={cfg['adapt_steps']}")
+
+ # Load data
+ print("\nLoading LongLaMP data...")
+ examples = load_longlamp(cfg['dataset_config'], split='val')
+ print(f"Loaded {len(examples)} validation examples")
+
+ # Limit examples if specified
+ num_eval = cfg.get('num_eval', -1)
+ if num_eval > 0:
+ examples = examples[:num_eval]
+ print(f"Evaluating on {len(examples)} examples")
+
+ # Prepare support sets
+ K = cfg['K']
+ support_sets = []
+ for ex in examples:
+ support = select_k_profile_items(ex['profile_items'], K, seed=args.seed)
+ support_sets.append(support)
+
+ # Gather references and support texts for metrics
+ references = [ex['target_output'] for ex in examples]
+ support_texts_per_example = [
+ [s['support_output'] for s in support]
+ for support in support_sets
+ ]
+
+ # Parse methods to run
+ if args.methods == 'all':
+ methods_to_run = ['base', 'prompt_all_k', 'bm25_top1', 'uncond', 'cvh']
+ else:
+ methods_to_run = args.methods.split(',')
+
+ # Load model
+ print(f"\nLoading model {cfg['model_name']}...")
+ wrapper = QwenWrapper(cfg['model_name'], device=cfg['device'])
+ print(f"Model loaded. Hidden size: {wrapper.hidden_size}")
+
+ results = {}
+ all_predictions = {}
+
+ # Run each method
+ if 'base' in methods_to_run:
+ print("\n--- Running Base ---")
+ preds = run_method_base(wrapper, examples, cfg)
+ all_predictions['Base'] = preds
+ results['Base'] = evaluate_all(preds, references, support_texts_per_example)
+ print(f" ROUGE-L: {results['Base']['rougeL']:.4f}, METEOR: {results['Base']['meteor']:.4f}, SFD: {results['Base']['sfd']:.4f}")
+
+ if 'prompt_all_k' in methods_to_run:
+ print(f"\n--- Running Prompt-All-K (K={K}) ---")
+ preds = run_method_prompt_all_k(wrapper, examples, support_sets, cfg)
+ all_predictions['Prompt-All-K'] = preds
+ results['Prompt-All-K'] = evaluate_all(preds, references, support_texts_per_example)
+ print(f" ROUGE-L: {results['Prompt-All-K']['rougeL']:.4f}, METEOR: {results['Prompt-All-K']['meteor']:.4f}, SFD: {results['Prompt-All-K']['sfd']:.4f}")
+
+ if 'bm25_top1' in methods_to_run:
+ print(f"\n--- Running BM25-Top1 ---")
+ preds = run_method_bm25_top1(wrapper, examples, support_sets, cfg)
+ all_predictions['BM25-Top1'] = preds
+ results['BM25-Top1'] = evaluate_all(preds, references, support_texts_per_example)
+ print(f" ROUGE-L: {results['BM25-Top1']['rougeL']:.4f}, METEOR: {results['BM25-Top1']['meteor']:.4f}, SFD: {results['BM25-Top1']['sfd']:.4f}")
+
+ if 'uncond' in methods_to_run:
+ print(f"\n--- Running Unconditional Head ---")
+ preds, adapt_time = run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='uncond')
+ all_predictions['Uncond-Head'] = preds
+ results['Uncond-Head'] = evaluate_all(preds, references, support_texts_per_example)
+ results['Uncond-Head']['adapt_time'] = adapt_time
+ print(f" ROUGE-L: {results['Uncond-Head']['rougeL']:.4f}, METEOR: {results['Uncond-Head']['meteor']:.4f}, SFD: {results['Uncond-Head']['sfd']:.4f}")
+
+ if 'cvh' in methods_to_run:
+ print(f"\n--- Running CVH ---")
+ preds, adapt_time = run_method_vector_head(wrapper, examples, support_sets, cfg, head_type='cvh')
+ all_predictions['CVH'] = preds
+ results['CVH'] = evaluate_all(preds, references, support_texts_per_example)
+ results['CVH']['adapt_time'] = adapt_time
+ print(f" ROUGE-L: {results['CVH']['rougeL']:.4f}, METEOR: {results['CVH']['meteor']:.4f}, SFD: {results['CVH']['sfd']:.4f}")
+
+ # Print comparison table
+ print("\n" + "=" * 70)
+ print("RESULTS SUMMARY")
+ print("=" * 70)
+ print_results_table(results)
+
+ # Compute compression for vector methods
+ print("\n--- Efficiency ---")
+ theta_bytes = cfg['d'] * 2 # bf16 = 2 bytes per dim
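+    # compute_compression (eval/metrics.py, not shown in this diff) is
+    # assumed to return roughly (bytes of the user's support texts) /
+    # theta_bytes: how many raw profile bytes the d-dim bf16 theta replaces.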
+    from eval.metrics import compute_compression
+    for ex_support in support_texts_per_example[:5]:
+        comp = compute_compression(ex_support, theta_bytes)
+        print(f"  Compression ratio (example): {comp:.0f}x")
+
+ print(f" Theta size: {theta_bytes} bytes")
+ print(f" Personalization prompt tokens at inference: 0")
+
+ # Save results
+ os.makedirs(args.output_dir, exist_ok=True)
+ exp_name = f"{cfg['task']}_{cfg['setting']}_K{K}_d{cfg['d']}"
+ output_path = os.path.join(args.output_dir, f"{exp_name}_results.json")
+
+ # Convert results to serializable format
+ save_data = {
+ 'config': cfg,
+ 'results': results,
+ 'num_examples': len(examples),
+ }
+ with open(output_path, 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print(f"\nResults saved to {output_path}")
+
+ # Save predictions
+ pred_path = os.path.join(args.output_dir, f"{exp_name}_predictions.json")
+ save_preds = {}
+ for method, preds in all_predictions.items():
+ save_preds[method] = [
+ {
+ 'example_id': examples[i]['example_id'],
+ 'prediction': preds[i],
+ 'reference': references[i],
+ }
+ for i in range(len(preds))
+ ]
+ with open(pred_path, 'w') as f:
+ json.dump(save_preds, f, indent=2)
+ print(f"Predictions saved to {pred_path}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/run_fair_audit.py b/scripts/run_fair_audit.py
new file mode 100644
index 0000000..73cc744
--- /dev/null
+++ b/scripts/run_fair_audit.py
@@ -0,0 +1,381 @@
+"""Fair audit experiment: all methods use the SAME decode policy.
+
+Decode policy: greedy (temperature=0), min_new_tokens=128, max_new_tokens=512.
+For vector methods: blended generation with gamma=0.5.
+For base/prompt methods: standard generation with min_new_tokens=128.
+
+Reports: ROUGE-L, METEOR, SFD_all, SFD_-len, avg_len, feature-level deltas.
+"""
+
+import sys
+import os
+import json
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from data.style_features import (
+ extract_style_features, compute_sfd, compute_feature_deltas, FEATURE_NAMES
+)
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from adapt.fit_theta_weighted import fit_theta_weighted
+from baselines.prompt_all_k import generate_prompt_all_k
+from baselines.bm25_top1 import generate_bm25_top1
+from eval.metrics import compute_rouge, compute_meteor
+
+
+def generate_base_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128):
+ """Base generation with min_new_tokens constraint (fair decode policy)."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+ with torch.no_grad():
+ outputs = wrapper.model.generate(
+ input_ids,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ temperature=None,
+ top_p=None,
+ do_sample=False,
+ pad_token_id=wrapper.tokenizer.pad_token_id,
+ )
+
+ generated_ids = outputs[0, input_ids.shape[1]:]
+ return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
+def generate_prompt_with_min(wrapper, input_text, max_new_tokens=512, min_new_tokens=128):
+ """Prompt-based generation with min_new_tokens constraint."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+ with torch.no_grad():
+ outputs = wrapper.model.generate(
+ input_ids,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ temperature=None,
+ top_p=None,
+ do_sample=False,
+ pad_token_id=wrapper.tokenizer.pad_token_id,
+ )
+
+ generated_ids = outputs[0, input_ids.shape[1]:]
+ return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
+def compute_detailed_metrics(predictions, references, support_texts_per_example):
+ """Compute comprehensive metrics including SFD split and feature-level analysis."""
+ rouge = compute_rouge(predictions, references)
+ meteor = compute_meteor(predictions, references)
+
+ # SFD all features
+ sfd_all_list = []
+ sfd_nolen_list = []
+ feature_deltas_all = {name: [] for name in FEATURE_NAMES}
+
+ for pred, support_texts in zip(predictions, support_texts_per_example):
+ if not pred.strip():
+ pred = "empty"
+ sfd_all = compute_sfd(pred, support_texts, exclude_length=False)
+ sfd_nolen = compute_sfd(pred, support_texts, exclude_length=True)
+ sfd_all_list.append(sfd_all)
+ sfd_nolen_list.append(sfd_nolen)
+
+ deltas = compute_feature_deltas(pred, support_texts)
+ for name in FEATURE_NAMES:
+ if name in deltas:
+ feature_deltas_all[name].append(deltas[name]['delta'])
+
+ avg_sfd_all = sum(sfd_all_list) / max(len(sfd_all_list), 1)
+ avg_sfd_nolen = sum(sfd_nolen_list) / max(len(sfd_nolen_list), 1)
+
+ avg_feature_deltas = {}
+ for name in FEATURE_NAMES:
+ vals = feature_deltas_all[name]
+ avg_feature_deltas[name] = sum(vals) / max(len(vals), 1)
+
+ avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+
+ return {
+ 'rouge1': rouge['rouge1'],
+ 'rougeL': rouge['rougeL'],
+ 'meteor': meteor,
+ 'sfd_all': avg_sfd_all,
+ 'sfd_nolen': avg_sfd_nolen,
+ 'avg_len': avg_len,
+ 'feature_deltas': avg_feature_deltas,
+ }
+
+
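+# data/style_features.py is not shown in this diff. As a rough mental model
+# (an assumption about compute_sfd, not its actual implementation), SFD reads
+# as the mean absolute gap between the generation's style-feature vector and
+# the support prototype:
+#
+#   def sfd_sketch(pred, support_texts, exclude_length=False):
+#       names = [n for n in FEATURE_NAMES
+#                if not (exclude_length and 'len' in n)]
+#       gen = extract_style_features(pred)
+#       protos = [extract_style_features(t) for t in support_texts]
+#       proto = {n: sum(p[n] for p in protos) / len(protos) for n in names}
+#       return sum(abs(gen[n] - proto[n]) for n in names) / len(names)
+
+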
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_eval', type=int, default=200)
+ parser.add_argument('--task', type=str, default='review', choices=['review', 'topic'])
+ parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal'])
+ parser.add_argument('--output_dir', type=str, default='outputs/fair_audit')
+ args = parser.parse_args()
+
+ N = args.num_eval
+ task = args.task
+ setting = args.setting
+
+ config_map = {
+ ('review', 'user'): 'product_review_user',
+ ('review', 'temporal'): 'product_review_temporal',
+ ('topic', 'user'): 'topic_writing_user',
+ ('topic', 'temporal'): 'topic_writing_temporal',
+ }
+ config_name = config_map[(task, setting)]
+
+ print(f"=== Fair Audit: {task}_{setting}, N={N} ===")
+ print(f"Decode policy: greedy, min_new_tokens=128, max_new_tokens=512")
+ print(f"Vector methods: blended gamma=0.5")
+
+ print("\nLoading data...")
+ examples = load_longlamp(config_name, split='val')[:N]
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+ print(f"Examples: {len(examples)}, Avg reference len: {avg_ref_len:.0f}")
+
+ print("\nLoading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+ all_results = {}
+ all_predictions = {}
+
+ # === 1. Base (with min_new_tokens=128) ===
+ print("\n--- Base ---")
+ preds = []
+ for i, ex in enumerate(examples):
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_base_with_min(wrapper, prompt, min_new_tokens=128)
+ preds.append(pred)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ r = compute_detailed_metrics(preds, references, support_texts)
+ all_results['Base'] = r
+ all_predictions['Base'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 2. Prompt-All-K (with min_new_tokens=128) ===
+ print(f"\n--- Prompt-All-K (K=4) ---")
+ preds = []
+    from data.templates import build_prompt_with_examples
+    for i, (ex, support) in enumerate(zip(examples, support_sets)):
+        prompt = build_prompt_with_examples(ex['query_input'], support, ex['task'])
+ pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128)
+ preds.append(pred)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ r = compute_detailed_metrics(preds, references, support_texts)
+ all_results['Prompt-All-K'] = r
+ all_predictions['Prompt-All-K'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 3. BM25-Top1 (with min_new_tokens=128) ===
+ print(f"\n--- BM25-Top1 ---")
+ preds = []
+    from baselines.bm25_top1 import bm25_select_top1
+    for i, (ex, support) in enumerate(zip(examples, support_sets)):
+        selected = bm25_select_top1(ex['query_input'], support)
+        prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task'])
+ pred = generate_prompt_with_min(wrapper, prompt, min_new_tokens=128)
+ preds.append(pred)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ r = compute_detailed_metrics(preds, references, support_texts)
+ all_results['BM25-Top1'] = r
+ all_predictions['BM25-Top1'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # Helper for vector head methods
+ def run_vector_head(name, head_module, d=64, use_weighted=False):
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ preds = []
+ adapt_times = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_base_with_min(wrapper, prompt)
+ preds.append(pred)
+ adapt_times.append(0.0)
+ continue
+
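+            # fit_theta_weighted (adapt/fit_theta_weighted.py, not shown in
+            # this diff) takes the tokenizer and a max_tokens_per_item cap;
+            # presumably it reweights the per-token fitting loss toward
+            # style-bearing tokens, but the exact scheme is not visible here.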
+ if use_weighted:
+ theta = fit_theta_weighted(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ tokenizer=wrapper.tokenizer,
+ d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device,
+ max_tokens_per_item=128,
+ verbose=False,
+ )
+ else:
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device,
+ verbose=False,
+ )
+
+ adapt_time = time.time() - t0
+ adapt_times.append(adapt_time)
+
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_with_head_blended(
+ prompt, theta, head_module.forward_fn,
+ blend_gamma=0.5, max_new_tokens=512,
+ min_new_tokens=128, temperature=0.0,
+ )
+ preds.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 40 == 0:
+ avg_t = sum(adapt_times) / len(adapt_times)
+ print(f" {i+1}/{N} (avg adapt: {avg_t:.1f}s)")
+
+ avg_adapt = sum(adapt_times) / max(len(adapt_times), 1)
+ r = compute_detailed_metrics(preds, references, support_texts)
+ r['adapt_time'] = avg_adapt
+ return preds, r
+
+ # === 4. Uncond-Head ===
+ print(f"\n--- Uncond-Head ---")
+ uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ preds, r = run_vector_head('Uncond', uncond, d=64)
+ all_results['Uncond-Head'] = r
+ all_predictions['Uncond-Head'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 5. CVH ===
+ print(f"\n--- CVH ---")
+ cvh = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ preds, r = run_vector_head('CVH', cvh, d=64)
+ all_results['CVH'] = r
+ all_predictions['CVH'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === 6. Uncond-Head with style-weighted loss ===
+ print(f"\n--- Uncond-Head (style-weighted) ---")
+ uncond_sw = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ preds, r = run_vector_head('Uncond-SW', uncond_sw, d=64, use_weighted=True)
+ all_results['Uncond-SW'] = r
+ all_predictions['Uncond-SW'] = preds
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD_all: {r['sfd_all']:.4f}, SFD_-len: {r['sfd_nolen']:.4f}, len: {r['avg_len']:.0f}")
+
+ # === Print comprehensive results ===
+ print("\n" + "=" * 110)
+ print("COMPREHENSIVE RESULTS (FAIR DECODE POLICY)")
+ print("=" * 110)
+ header = f"{'Method':<20} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD_all':<8} {'SFD_-len':<8} {'Len':<6}"
+ print(header)
+ print("-" * 110)
+ for name, r in all_results.items():
+ print(f"{name:<20} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} {r['meteor']:<8.4f} "
+ f"{r['sfd_all']:<8.4f} {r['sfd_nolen']:<8.4f} {r['avg_len']:<6.0f}")
+
+ # Feature-level analysis
+ print("\n" + "=" * 110)
+ print("FEATURE-LEVEL DELTAS (gen - proto, closer to 0 = better)")
+ print("=" * 110)
+ header = f"{'Method':<20}" + "".join(f"{n:<14}" for n in FEATURE_NAMES)
+ print(header)
+ print("-" * 110)
+ for name, r in all_results.items():
+ fd = r['feature_deltas']
+ row = f"{name:<20}"
+ for feat_name in FEATURE_NAMES:
+ row += f"{fd[feat_name]:<14.3f}"
+ print(row)
+
+ # Recovery analysis
+ if 'BM25-Top1' in all_results and 'Base' in all_results:
+ print("\n--- Recovery (vs BM25-Top1 baseline) ---")
+ base_rl = all_results['Base']['rougeL']
+ bm25_rl = all_results['BM25-Top1']['rougeL']
+ base_m = all_results['Base']['meteor']
+ bm25_m = all_results['BM25-Top1']['meteor']
+
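+        # Recovery = (method - Base) / (BM25-Top1 - Base): 1.0 means a vector
+        # method matches the full BM25 prompting gain with zero prompt tokens.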
+ for name, r in all_results.items():
+ if name in ('Base', 'Prompt-All-K', 'BM25-Top1'):
+ continue
+ denom_rl = bm25_rl - base_rl
+ denom_m = bm25_m - base_m
+ rec_rl = (r['rougeL'] - base_rl) / denom_rl if abs(denom_rl) > 1e-8 else 0
+ rec_m = (r['meteor'] - base_m) / denom_m if abs(denom_m) > 1e-8 else 0
+ print(f" {name}: R-L Recovery={rec_rl:.3f}, METEOR Recovery={rec_m:.3f}")
+
+ # Efficiency
+ print("\n--- Efficiency ---")
+ theta_bytes = 64 * 2 # d=64, bf16
+ print(f" Theta size: {theta_bytes} bytes")
+ print(f" Personalization prompt tokens at inference: 0 (vector methods)")
+ if support_texts:
+ from eval.metrics import compute_compression
+ comp = compute_compression(support_texts[0], theta_bytes)
+ print(f" Compression ratio (example): {comp:.0f}x")
+
+ # Save results
+ os.makedirs(args.output_dir, exist_ok=True)
+ exp_name = f"{task}_{setting}_K4_d64_N{N}_fair"
+ output_path = os.path.join(args.output_dir, f"{exp_name}_results.json")
+
+ save_data = {
+ 'results': {k: {kk: vv for kk, vv in v.items()} for k, v in all_results.items()},
+ 'num_examples': len(examples),
+ 'decode_policy': 'greedy, min_new_tokens=128, max_new_tokens=512, blend_gamma=0.5',
+ }
+ with open(output_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=str)
+ print(f"\nResults saved to {output_path}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/shift_analysis.py b/scripts/shift_analysis.py
new file mode 100644
index 0000000..99c7fd2
--- /dev/null
+++ b/scripts/shift_analysis.py
@@ -0,0 +1,178 @@
+"""Support-query distribution shift analysis.
+
+For each user, compute:
+ s_u = cos(mean_support_hidden, mean_query_hidden)
+Then correlate with CVH-UPH performance gap:
+ delta_u = ROUGE-L(CVH, u) - ROUGE-L(UPH, u)
+
+If correlation is positive: CVH benefits when support-query are aligned.
+"""
+
+import sys
+import os
+import json
+import numpy as np
+from scipy import stats
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt, build_support_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import compute_rouge
+
+
+def get_query_hidden_mean(wrapper, query_text, task):
+ """Get mean hidden state from the query prompt."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": build_query_prompt(query_text, task)},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+ with torch.no_grad():
+ outputs = wrapper.model(
+ input_ids=input_ids,
+ output_hidden_states=True,
+ return_dict=True,
+ )
+ last_hidden = outputs.hidden_states[-1][0] # (seq_len, H)
+ return last_hidden.mean(dim=0).cpu().float().numpy()
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_eval', type=int, default=100)
+ parser.add_argument('--config', type=str, default='product_review_user')
+ args = parser.parse_args()
+
+ N = args.num_eval
+ print(f"=== Shift Analysis: {args.config}, N={N} ===")
+
+ print("Loading data...")
+ examples = load_longlamp(args.config, split='val')[:N]
+
+ print("Loading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+ uph_head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ cvh_head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ K = 4
+ shift_cosines = []
+ uph_rouges = []
+ cvh_rouges = []
+
+ for i, ex in enumerate(examples):
+ support = select_k_profile_items(ex['profile_items'], K, seed=0)
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ continue
+
+ # Mean support hidden
+ all_h = torch.cat([h for h, _ in cached_h], dim=0)
+ support_mean = all_h.mean(dim=0).numpy()
+
+ # Mean query hidden
+ query_mean = get_query_hidden_mean(wrapper, ex['query_input'], ex['task'])
+
+ # Cosine similarity
+ cos = np.dot(support_mean, query_mean) / (
+ np.linalg.norm(support_mean) * np.linalg.norm(query_mean) + 1e-8)
+ shift_cosines.append(float(cos))
+
+ # Fit UPH theta
+ theta_uph = fit_theta(
+ cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias, head_module=uph_head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+
+ # Fit CVH theta
+ theta_cvh = fit_theta(
+ cached_h=cached_h, lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias, head_module=cvh_head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+
+ # Generate with both
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred_uph = wrapper.generate_with_head_blended(
+ prompt, theta_uph, uph_head.forward_fn,
+ blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
+ )
+ pred_cvh = wrapper.generate_with_head_blended(
+ prompt, theta_cvh, cvh_head.forward_fn,
+ blend_gamma=0.5, max_new_tokens=512, min_new_tokens=128, temperature=0.0,
+ )
+
+ # ROUGE-L for each
+ rouge_uph = compute_rouge([pred_uph], [ex['target_output']])['rougeL']
+ rouge_cvh = compute_rouge([pred_cvh], [ex['target_output']])['rougeL']
+ uph_rouges.append(rouge_uph)
+ cvh_rouges.append(rouge_cvh)
+
+ del cached_h, theta_uph, theta_cvh
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 20 == 0:
+ print(f" {i+1}/{N}")
+
+ # Compute correlation
+ shift_cosines = np.array(shift_cosines)
+ deltas = np.array(cvh_rouges) - np.array(uph_rouges) # positive = CVH better
+
+ rho, pval = stats.spearmanr(shift_cosines, deltas)
+
+ print(f"\n=== Results (N={len(shift_cosines)}) ===")
+ print(f" Mean shift cosine: {shift_cosines.mean():.4f} +/- {shift_cosines.std():.4f}")
+ print(f" Mean delta (CVH - UPH): {deltas.mean():.4f} +/- {deltas.std():.4f}")
+ print(f" Spearman(shift_cos, delta): rho={rho:.4f}, p={pval:.4f}")
+ print(f" Mean UPH ROUGE-L: {np.mean(uph_rouges):.4f}")
+ print(f" Mean CVH ROUGE-L: {np.mean(cvh_rouges):.4f}")
+
+ # Bin analysis: high vs low shift
+ median_cos = np.median(shift_cosines)
+ high_mask = shift_cosines >= median_cos
+ low_mask = shift_cosines < median_cos
+
+ print(f"\n High-alignment (cos >= {median_cos:.3f}, n={high_mask.sum()}):")
+ print(f" UPH R-L: {np.mean(np.array(uph_rouges)[high_mask]):.4f}")
+ print(f" CVH R-L: {np.mean(np.array(cvh_rouges)[high_mask]):.4f}")
+ print(f" Low-alignment (cos < {median_cos:.3f}, n={low_mask.sum()}):")
+ print(f" UPH R-L: {np.mean(np.array(uph_rouges)[low_mask]):.4f}")
+ print(f" CVH R-L: {np.mean(np.array(cvh_rouges)[low_mask]):.4f}")
+
+ # Save
+ os.makedirs('outputs/analysis', exist_ok=True)
+ save_data = {
+ 'shift_cosines': [float(x) for x in shift_cosines],
+ 'uph_rouges': [float(x) for x in uph_rouges],
+ 'cvh_rouges': [float(x) for x in cvh_rouges],
+ 'deltas': [float(x) for x in deltas],
+ 'spearman_rho': float(rho),
+ 'spearman_pval': float(pval),
+ }
+ with open('outputs/analysis/shift_analysis.json', 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print("\nSaved to outputs/analysis/shift_analysis.json")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/sweep_alpha.py b/scripts/sweep_alpha.py
new file mode 100644
index 0000000..bc35b0c
--- /dev/null
+++ b/scripts/sweep_alpha.py
@@ -0,0 +1,122 @@
+"""Quick sweep over alpha values to find the right perturbation scale."""
+
+import sys
+import os
+import json
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
+def run_cvh_with_params(wrapper, examples, support_sets, alpha, beta, steps, d=64, lr=0.05):
+ """Run CVH with specific hyperparameters."""
+ device = 'cuda:1'
+ H = wrapper.hidden_size
+ head = CVHHead(H, d=d, alpha=alpha, basis_seed=42).to(device)
+
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ theta_norms = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=256)
+ predictions.append(pred)
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head,
+ d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+ theta_norms.append(theta.norm().item())
+
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_with_head(
+ prompt, theta, head.forward_fn,
+ max_new_tokens=256, temperature=0.0,
+ )
+ predictions.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
+ return predictions, avg_norm
+
+
+def main():
+ print("Loading data...")
+ examples = load_longlamp('product_review_user', split='val')[:50]
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ print("Loading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+
+ # Run base
+ print("\n=== Base ===")
+ base_preds = []
+ for ex in examples:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=256, temperature=0.0)
+ base_preds.append(pred)
+ base_results = evaluate_all(base_preds, references, support_texts)
+ print(f" ROUGE-L: {base_results['rougeL']:.4f}, METEOR: {base_results['meteor']:.4f}, SFD: {base_results['sfd']:.4f}")
+
+ # Sweep
+ configs = [
+ {'alpha': 0.1, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
+ {'alpha': 0.3, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
+ {'alpha': 0.5, 'beta': 0.05, 'steps': 30, 'lr': 0.05},
+ {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.05},
+ {'alpha': 0.5, 'beta': 0.01, 'steps': 50, 'lr': 0.05},
+ {'alpha': 0.3, 'beta': 0.01, 'steps': 50, 'lr': 0.1},
+ ]
+
+ all_results = {'Base': base_results}
+
+ for cfg in configs:
+ name = f"a{cfg['alpha']}_b{cfg['beta']}_s{cfg['steps']}_lr{cfg['lr']}"
+ print(f"\n=== CVH {name} ===")
+ t0 = time.time()
+ preds, avg_norm = run_cvh_with_params(
+ wrapper, examples, support_sets,
+ alpha=cfg['alpha'], beta=cfg['beta'],
+ steps=cfg['steps'], lr=cfg['lr'],
+ )
+ elapsed = time.time() - t0
+ results = evaluate_all(preds, references, support_texts)
+ all_results[name] = results
+ print(f" ROUGE-L: {results['rougeL']:.4f}, METEOR: {results['meteor']:.4f}, "
+ f"SFD: {results['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {elapsed:.0f}s")
+
+ # Summary
+ print("\n" + "=" * 80)
+ print(f"{'Config':<40} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}")
+ print("-" * 80)
+ for name, r in all_results.items():
+ print(f"{name:<40} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/sweep_d_and_multi.py b/scripts/sweep_d_and_multi.py
new file mode 100644
index 0000000..004f6cf
--- /dev/null
+++ b/scripts/sweep_d_and_multi.py
@@ -0,0 +1,165 @@
+"""Sweep d values and test multi-basis CVH."""
+
+import sys
+import os
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
+class MultiBasisCVH(torch.nn.Module):
+ """Two-basis CVH: h'_t = h_t + a1*B1(theta⊙A1*h) + a2*B2(theta⊙A2*h)"""
+
+ def __init__(self, hidden_size, d=64, alpha=0.1, basis_seed=42):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.d = d
+ self.alpha = alpha
+
+ gen1 = torch.Generator()
+ gen1.manual_seed(basis_seed)
+ gen2 = torch.Generator()
+ gen2.manual_seed(basis_seed + 500)
+
+ scale_a = 1.0 / (hidden_size ** 0.5)
+ scale_b = 1.0 / (d ** 0.5)
+
+ self.register_buffer('A1', torch.randn(d, hidden_size, generator=gen1) * scale_a)
+ self.register_buffer('B1', torch.randn(hidden_size, d, generator=gen1) * scale_b)
+ self.register_buffer('A2', torch.randn(d, hidden_size, generator=gen2) * scale_a)
+ self.register_buffer('B2', torch.randn(hidden_size, d, generator=gen2) * scale_b)
+
+ def forward(self, h, theta):
+ proj1 = (self.A1.float() @ h.T).T
+ gated1 = theta.unsqueeze(0) * proj1
+ res1 = (self.B1.float() @ gated1.T).T
+
+ proj2 = (self.A2.float() @ h.T).T
+ gated2 = theta.unsqueeze(0) * proj2
+ res2 = (self.B2.float() @ gated2.T).T
+
+ return h + self.alpha * (res1 + res2)
+
+ def forward_fn(self, h, theta):
+ return self.forward(h, theta)
+
+
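+# Minimal sanity check for MultiBasisCVH (illustrative, not run by the sweep;
+# dims are toy-sized, the real runs use hidden_size=H from the wrapper):
+#
+#   head = MultiBasisCVH(hidden_size=8, d=4, alpha=0.1, basis_seed=0)
+#   h = torch.randn(3, 8)       # 3 token hidden states
+#   theta = torch.zeros(4)      # theta = 0 gates the residual off entirely
+#   assert torch.allclose(head(h, theta), h)
+
+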
+def run_head(wrapper, examples, support_sets, head_module, d=64, alpha=0.1,
+ beta=0.05, steps=30, lr=0.05, max_new_tokens=512):
+ device = 'cuda:1'
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ theta_norms = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens)
+ predictions.append(pred)
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+ theta_norms.append(theta.norm().item())
+
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_with_head(
+ prompt, theta, head_module.forward_fn,
+ max_new_tokens=max_new_tokens, temperature=0.0,
+ )
+ predictions.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 20 == 0:
+ print(f" {i+1}/{len(examples)}")
+
+ avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
+ return predictions, avg_norm
+
+
+def main():
+ print("Loading data...")
+ examples = load_longlamp('product_review_user', split='val')[:50]
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ print("Loading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+ # Base
+ print("\n=== Base ===")
+ base_preds = []
+ for ex in examples:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
+ base_preds.append(pred)
+ base_r = evaluate_all(base_preds, references, support_texts)
+ print(f" ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}")
+
+ results = {'Base': base_r}
+ configs = [
+ ('CVH d=64', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+ ('CVH d=128', CVHHead(H, d=128, alpha=0.1, basis_seed=42).to(device), 128),
+ ('CVH d=256', CVHHead(H, d=256, alpha=0.1, basis_seed=42).to(device), 256),
+ ('Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+ ('Uncond d=128', UnconditionalHead(H, d=128, alpha=0.1, basis_seed=42).to(device), 128),
+ ('MultiBasis d=64', MultiBasisCVH(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+ # Higher beta to preserve content
+ ('CVH d=64 b=0.1', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+ ('CVH d=64 b=0.2', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+ ]
+
+ betas = {
+ 'CVH d=64 b=0.1': 0.1,
+ 'CVH d=64 b=0.2': 0.2,
+ }
+
+ for name, head, d in configs:
+ beta = betas.get(name, 0.05)
+ print(f"\n=== {name} (beta={beta}) ===")
+ t0 = time.time()
+ preds, avg_norm = run_head(
+ wrapper, examples, support_sets, head, d=d,
+ alpha=0.1, beta=beta, steps=30, lr=0.05, max_new_tokens=512,
+ )
+ elapsed = time.time() - t0
+ r = evaluate_all(preds, references, support_texts)
+ results[name] = r
+ print(f" ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, time: {elapsed:.0f}s")
+
+ # Summary
+ print("\n" + "=" * 90)
+ print(f"{'Config':<25} {'ROUGE-1':<10} {'ROUGE-L':<10} {'METEOR':<10} {'SFD':<10}")
+ print("-" * 90)
+ for name, r in results.items():
+ print(f"{name:<25} {r['rouge1']:<10.4f} {r['rougeL']:<10.4f} {r['meteor']:<10.4f} {r['sfd']:<10.4f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/test_length_fix.py b/scripts/test_length_fix.py
new file mode 100644
index 0000000..ad3f358
--- /dev/null
+++ b/scripts/test_length_fix.py
@@ -0,0 +1,203 @@
+"""Test different strategies to fix the output length issue."""
+
+import sys
+import os
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
+def run_with_blend(wrapper, examples, support_sets, head_module, d=64,
+ beta=0.05, steps=30, lr=0.05, blend_gamma=0.5,
+ min_new_tokens=64, max_new_tokens=512):
+ """Run CVH with logit blending: logits = (1-gamma)*base + gamma*cvh"""
+ device = 'cuda:1'
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ theta_norms = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens)
+ predictions.append(pred)
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+ theta_norms.append(theta.norm().item())
+
+ # Generate with blending
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_with_blend(
+ wrapper, prompt, theta, head_module,
+ gamma=blend_gamma, max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ )
+ predictions.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 20 == 0:
+ print(f" {i+1}/{len(examples)}")
+
+ avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
+ avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+ return predictions, avg_norm, avg_len
+
+
+def generate_with_blend(wrapper, input_text, theta, head_module,
+ gamma=0.5, max_new_tokens=512, min_new_tokens=64):
+ """Generate with blended base + CVH logits."""
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": input_text},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+
+ generated_ids = []
+ past_key_values = None
+
+ for step in range(max_new_tokens):
+ if step == 0:
+ cur_input = input_ids
+ else:
+ cur_input = torch.tensor([[generated_ids[-1]]], device=wrapper.device)
+
+ with torch.no_grad():
+ outputs = wrapper.model(
+ input_ids=cur_input,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ return_dict=True,
+ )
+
+ past_key_values = outputs.past_key_values
+ last_hidden = outputs.hidden_states[-1][:, -1, :] # (1, H)
+
+            # lm_head bias, if the model has one (getattr is equivalent to
+            # the repeated hasattr/None conditional and keeps lines readable)
+            lm_bias = getattr(wrapper.model.lm_head, 'bias', None)
+
+            # Base logits
+            base_logits = torch.nn.functional.linear(
+                last_hidden.to(wrapper.lm_head_weight.dtype),
+                wrapper.lm_head_weight,
+                lm_bias,
+            ).float()
+
+            # CVH logits from the personalized hidden state
+            h_prime = head_module.forward_fn(last_hidden.float(), theta)
+            cvh_logits = torch.nn.functional.linear(
+                h_prime.to(wrapper.lm_head_weight.dtype),
+                wrapper.lm_head_weight,
+                lm_bias,
+            ).float()
+
+ # Blend logits
+ logits = (1 - gamma) * base_logits + gamma * cvh_logits
+
+ # Suppress EOS before min_new_tokens
+ if step < min_new_tokens and wrapper.tokenizer.eos_token_id is not None:
+ logits[0, wrapper.tokenizer.eos_token_id] = float('-inf')
+
+ next_token = logits.argmax(dim=-1).item()
+
+ if next_token == wrapper.tokenizer.eos_token_id:
+ break
+
+ generated_ids.append(next_token)
+
+ return wrapper.tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+
+def main():
+ N = 50
+ print(f"Loading data ({N} examples)...")
+ examples = load_longlamp('product_review_user', split='val')[:N]
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+ print(f"Avg reference length: {avg_ref_len:.0f} words")
+
+ print("Loading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+ # Base with min_new_tokens=200
+ print("\n=== Base ===")
+ base_preds = []
+ for ex in examples:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
+ base_preds.append(pred)
+ base_r = evaluate_all(base_preds, references, support_texts)
+ base_len = sum(len(p.split()) for p in base_preds) / len(base_preds)
+ print(f" R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}")
+
+ results = {'Base': {**base_r, 'avg_len': base_len}}
+
+ head = CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+ uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+ configs = [
+ # Standard CVH with higher min_new_tokens
+ ('CVH min=200', head, 1.0, 200),
+ # Blended CVH with different gammas
+ ('CVH blend=0.3 min=64', head, 0.3, 64),
+ ('CVH blend=0.5 min=64', head, 0.5, 64),
+ ('CVH blend=0.7 min=64', head, 0.7, 64),
+ ('CVH blend=0.5 min=128', head, 0.5, 128),
+ # Uncond blend
+ ('Uncond blend=0.5 min=64', uncond, 0.5, 64),
+ ]
+
+ for name, head_mod, gamma, min_tok in configs:
+ print(f"\n=== {name} ===")
+ t0 = time.time()
+ preds, avg_norm, avg_len = run_with_blend(
+ wrapper, examples, support_sets, head_mod, d=64,
+ beta=0.05, steps=30, lr=0.05, blend_gamma=gamma,
+ min_new_tokens=min_tok, max_new_tokens=512,
+ )
+ elapsed = time.time() - t0
+ r = evaluate_all(preds, references, support_texts)
+ results[name] = {**r, 'avg_len': avg_len}
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, "
+ f"|theta|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s")
+
+ # Summary
+ print("\n" + "=" * 100)
+ print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}")
+ print("-" * 100)
+ for name, r in results.items():
+ print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
+ f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/test_normalized_cvh.py b/scripts/test_normalized_cvh.py
new file mode 100644
index 0000000..0057341
--- /dev/null
+++ b/scripts/test_normalized_cvh.py
@@ -0,0 +1,126 @@
+"""Quick test: normalized CVH vs Uncond with blending."""
+
+import sys
+import os
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
+def run_blended(wrapper, examples, support_sets, head_module, d=64,
+ beta=0.05, steps=30, lr=0.05, blend_gamma=0.5):
+ device = 'cuda:1'
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ theta_norms = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=512)
+ predictions.append(pred)
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=(i == 0),
+ )
+ theta_norms.append(theta.norm().item())
+
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_with_head_blended(
+ prompt, theta, head_module.forward_fn,
+ blend_gamma=blend_gamma,
+ max_new_tokens=512, min_new_tokens=128,
+ temperature=0.0,
+ )
+ predictions.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 20 == 0:
+ print(f" {i+1}/{len(examples)}")
+
+ avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
+ avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+ return predictions, avg_norm, avg_len
+
+
+def main():
+ N = 100
+ print(f"Loading data ({N} examples)...")
+ examples = load_longlamp('product_review_user', split='val')[:N]
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ print("Loading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+ # Base
+ print("\n=== Base ===")
+ base_preds = []
+ for i, ex in enumerate(examples):
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
+ base_preds.append(pred)
+ if (i + 1) % 20 == 0:
+ print(f" {i+1}/{N}")
+ base_r = evaluate_all(base_preds, references, support_texts)
+ base_len = sum(len(p.split()) for p in base_preds) / len(base_preds)
+ print(f" R-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, SFD: {base_r['sfd']:.4f}, len: {base_len:.0f}")
+
+ results = {'Base': base_r}
+
+ configs = [
+ ('Uncond d=64 g=0.5', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5),
+ ('CVH-norm d=64 g=0.5', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.5),
+ ('CVH-norm d=64 g=0.3', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.3),
+ ('CVH-norm d=64 g=0.7', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64, 0.7),
+ ]
+
+ for name, head, d, gamma in configs:
+ print(f"\n=== {name} ===")
+ t0 = time.time()
+ preds, avg_norm, avg_len = run_blended(
+ wrapper, examples, support_sets, head, d=d,
+ beta=0.05, steps=30, lr=0.05, blend_gamma=gamma,
+ )
+ elapsed = time.time() - t0
+ r = evaluate_all(preds, references, support_texts)
+ results[name] = r
+ print(f" R-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, SFD: {r['sfd']:.4f}, "
+ f"|θ|: {avg_norm:.3f}, len: {avg_len:.0f}, time: {elapsed:.0f}s")
+
+ # Summary
+ print("\n" + "=" * 90)
+ print(f"{'Config':<30} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8}")
+ print("-" * 90)
+ for name, r in results.items():
+ print(f"{name:<30} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
+ f"{r['meteor']:<8.4f} {r['sfd']:<8.4f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/test_svd_cvh.py b/scripts/test_svd_cvh.py
new file mode 100644
index 0000000..1f93bd1
--- /dev/null
+++ b/scripts/test_svd_cvh.py
@@ -0,0 +1,141 @@
+"""Test SVD-based CVH vs random basis CVH vs Unconditional."""
+
+import sys
+import os
+import time
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, UnconditionalHead
+from models.svd_cvh import SVDCVHHead, SVDUncondHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from eval.metrics import evaluate_all
+
+
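+# models/svd_cvh.py is not shown in this diff. A plausible construction (an
+# assumption, not the actual code): replace CVHHead's random projection pair
+# with bases built from the SVD of the lm_head weight, e.g.
+#
+#   U, S, Vh = torch.linalg.svd(lm_head_weight.float(), full_matrices=False)
+#   B = Vh[:d].T      # (H, d): top-d right singular vectors as output basis
+#
+# so the theta-gated residual moves h along the directions the output head is
+# most sensitive to.
+
+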
+def run_head(wrapper, examples, support_sets, head_module, d=64,
+ beta=0.05, steps=30, lr=0.05, max_new_tokens=512,
+ min_new_tokens=64):
+ device = 'cuda:1'
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ predictions = []
+ theta_norms = []
+
+ for i, (ex, support) in enumerate(zip(examples, support_sets)):
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=max_new_tokens)
+ predictions.append(pred)
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head_module,
+ d=d, lr=lr, steps=steps, beta=beta, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=(i == 0),
+ )
+ theta_norms.append(theta.norm().item())
+
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_with_head(
+ prompt, theta, head_module.forward_fn,
+ max_new_tokens=max_new_tokens, temperature=0.0,
+ min_new_tokens=min_new_tokens,
+ )
+ predictions.append(pred)
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 20 == 0:
+ print(f" {i+1}/{len(examples)}")
+
+ avg_norm = sum(theta_norms) / max(len(theta_norms), 1)
+ # Check avg output length
+ avg_len = sum(len(p.split()) for p in predictions) / max(len(predictions), 1)
+ return predictions, avg_norm, avg_len
+
+
+def main():
+ N = 50
+ print(f"Loading data ({N} examples)...")
+ examples = load_longlamp('product_review_user', split='val')[:N]
+ K = 4
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ avg_ref_len = sum(len(r.split()) for r in references) / len(references)
+ print(f"Avg reference length: {avg_ref_len:.0f} words")
+
+ print("Loading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+ H = wrapper.hidden_size
+ device = 'cuda:1'
+
+ # Base
+ print("\n=== Base ===")
+ base_preds = []
+ for ex in examples:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = wrapper.generate_base(prompt, max_new_tokens=512, temperature=0.0)
+ base_preds.append(pred)
+ base_r = evaluate_all(base_preds, references, support_texts)
+ base_len = sum(len(p.split()) for p in base_preds) / len(base_preds)
+ print(f" ROUGE-L: {base_r['rougeL']:.4f}, METEOR: {base_r['meteor']:.4f}, "
+ f"SFD: {base_r['sfd']:.4f}, avg_len: {base_len:.0f}")
+
+ results = {}
+ results['Base'] = {**base_r, 'avg_len': base_len}
+
+ # SVD-based heads
+ print("\nComputing SVD of lm_head...")
+ svd_cvh = SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.1).to(device)
+ svd_uncond = SVDUncondHead(wrapper.lm_head_weight, d=64, alpha=0.1).to(device)
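+    # Both heads derive their basis from an SVD of the LM head weight matrix
+    # (see models/svd_cvh), unlike the seeded random bases used below.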
+
+ configs = [
+ ('Random CVH d=64', CVHHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+ ('Random Uncond d=64', UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device), 64),
+ ('SVD CVH d=64', svd_cvh, 64),
+ ('SVD Uncond d=64', svd_uncond, 64),
+ # Try different alpha with SVD
+ ('SVD CVH d=64 a=0.05', SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.05).to(device), 64),
+ ('SVD CVH d=64 a=0.2', SVDCVHHead(wrapper.lm_head_weight, d=64, alpha=0.2).to(device), 64),
+ ]
+
+ for name, head, d in configs:
+ print(f"\n=== {name} ===")
+ t0 = time.time()
+ preds, avg_norm, avg_len = run_head(
+ wrapper, examples, support_sets, head, d=d,
+ beta=0.05, steps=30, lr=0.05, max_new_tokens=512,
+ min_new_tokens=64,
+ )
+ elapsed = time.time() - t0
+ r = evaluate_all(preds, references, support_texts)
+ results[name] = {**r, 'avg_len': avg_len}
+ print(f" ROUGE-L: {r['rougeL']:.4f}, METEOR: {r['meteor']:.4f}, "
+ f"SFD: {r['sfd']:.4f}, avg|theta|: {avg_norm:.3f}, "
+ f"avg_len: {avg_len:.0f}, time: {elapsed:.0f}s")
+
+ # Summary
+ print("\n" + "=" * 100)
+ print(f"{'Config':<25} {'R-1':<8} {'R-L':<8} {'METEOR':<8} {'SFD':<8} {'Len':<6}")
+ print("-" * 100)
+ for name, r in results.items():
+ print(f"{name:<25} {r['rouge1']:<8.4f} {r['rougeL']:<8.4f} "
+ f"{r['meteor']:<8.4f} {r['sfd']:<8.4f} {r.get('avg_len', 0):<6.0f}")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/theta_analysis.py b/scripts/theta_analysis.py
new file mode 100644
index 0000000..94d4010
--- /dev/null
+++ b/scripts/theta_analysis.py
@@ -0,0 +1,281 @@
+"""User-state geometry / representational alignment analysis.
+
+Computes:
+1. RSA: Spearman(cos(theta_u, theta_v), cos(phi_u, phi_v)) for all-style and -len/newline
+2. Self-consistency: Delta_self = E_u[cos(theta_a, theta_b)] - E_{u!=v}[cos(theta_a, theta_v)]
+3. Ridge probe: R^2 for predicting style features from theta
+4. PCA visualization
+"""
+
+import sys
+import os
+import json
+import numpy as np
+from scipy import stats
+from sklearn.linear_model import Ridge
+from sklearn.model_selection import cross_val_score
+import torch
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.style_features import extract_style_features, FEATURE_NAMES
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+
+
+def collect_thetas_and_styles(wrapper, examples, K=4, seed=0):
+ """Collect theta_u and style prototypes for all users."""
+ device = 'cuda:1'
+ H = wrapper.hidden_size
+ head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ thetas = []
+ style_protos = []
+ user_ids = []
+
+ for i, ex in enumerate(examples):
+ support = select_k_profile_items(ex['profile_items'], K, seed=seed)
+ cached_h = cache_support_hidden_states(wrapper, support, ex['task'])
+ if not cached_h:
+ continue
+
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+
+ thetas.append(theta.cpu().numpy())
+
+        # Style prototype: mean style-feature vector over the support texts
+ support_texts = [s['support_output'] for s in support]
+ features_list = [extract_style_features(t) for t in support_texts]
+ proto = np.mean(features_list, axis=0)
+ style_protos.append(proto)
+ user_ids.append(ex['user_id'])
+
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 40 == 0:
+ print(f" Collected {i+1}/{len(examples)}")
+
+ return np.array(thetas), np.array(style_protos), user_ids
+
+
+def compute_rsa(thetas, style_protos, exclude_indices=None):
+    """RSA: Spearman correlation between pairwise theta cosine similarity and
+    pairwise style-prototype cosine similarity. `exclude_indices` drops the
+    given style-feature columns before the style similarities are computed."""
+ N = len(thetas)
+
+ # Theta cosine similarity matrix
+ theta_norms = np.linalg.norm(thetas, axis=1, keepdims=True)
+ theta_norms = np.maximum(theta_norms, 1e-8)
+ theta_normed = thetas / theta_norms
+ theta_sim = theta_normed @ theta_normed.T
+
+ # Style cosine similarity matrix
+ if exclude_indices is not None:
+ style = np.delete(style_protos, exclude_indices, axis=1)
+ else:
+ style = style_protos.copy()
+
+ style_norms = np.linalg.norm(style, axis=1, keepdims=True)
+ style_norms = np.maximum(style_norms, 1e-8)
+ style_normed = style / style_norms
+ style_sim = style_normed @ style_normed.T
+
+ # Extract upper triangle
+ idx = np.triu_indices(N, k=1)
+ theta_upper = theta_sim[idx]
+ style_upper = style_sim[idx]
+
+ rho, pval = stats.spearmanr(theta_upper, style_upper)
+ return rho, pval
+
+
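+# Added sanity sketch (hypothetical helper, not called by main()): if theta
+# space mirrored style space exactly, the two similarity matrices would match
+# pair for pair and RSA would be ~1.
+def _rsa_sanity_check(n_users=10, dim=6, seed=0):
+    rng = np.random.default_rng(seed)
+    styles = rng.normal(size=(n_users, dim))
+    rho, _ = compute_rsa(styles, styles)
+    return rho  # expected: 1.0
+
+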
+def compute_self_consistency(wrapper, examples, K=4):
+ """Compute self-consistency by fitting theta with different support subsets."""
+ device = 'cuda:1'
+ H = wrapper.hidden_size
+ head = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(device)
+
+ lm_head_bias = None
+ if hasattr(wrapper.model.lm_head, 'bias') and wrapper.model.lm_head.bias is not None:
+ lm_head_bias = wrapper.model.lm_head.bias.data
+
+ thetas_a = []
+ thetas_b = []
+ valid_indices = []
+
+ for i, ex in enumerate(examples):
+ profile = ex['profile_items']
+ if len(profile) < 2 * K:
+ continue
+
+        # Fit theta twice, from support subsets drawn with different seeds
+ support_a = select_k_profile_items(profile, K, seed=100)
+ support_b = select_k_profile_items(profile, K, seed=200)
+
+ cached_a = cache_support_hidden_states(wrapper, support_a, ex['task'])
+ cached_b = cache_support_hidden_states(wrapper, support_b, ex['task'])
+
+ if not cached_a or not cached_b:
+ continue
+
+ theta_a = fit_theta(
+ cached_h=cached_a, lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias, head_module=head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+ theta_b = fit_theta(
+ cached_h=cached_b, lm_head_weight=wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias, head_module=head,
+ d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=device, verbose=False,
+ )
+
+ thetas_a.append(theta_a.cpu().numpy())
+ thetas_b.append(theta_b.cpu().numpy())
+ valid_indices.append(i)
+
+ del cached_a, cached_b, theta_a, theta_b
+ torch.cuda.empty_cache()
+
+ if (i + 1) % 40 == 0:
+ print(f" Self-consistency: {i+1}/{len(examples)} ({len(valid_indices)} valid)")
+
+ thetas_a = np.array(thetas_a)
+ thetas_b = np.array(thetas_b)
+ N = len(thetas_a)
+
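+    # Fewer than five users yielded two valid support subsets: too few for a
+    # stable estimate, so report zeros.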
+ if N < 5:
+ return 0.0, 0.0, 0.0
+
+ # Self similarity: cos(theta_a_u, theta_b_u)
+ self_cos = []
+ for u in range(N):
+ cos = np.dot(thetas_a[u], thetas_b[u]) / (
+ np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[u]) + 1e-8)
+ self_cos.append(cos)
+ avg_self = np.mean(self_cos)
+
+ # Cross similarity: cos(theta_a_u, theta_b_v) for u != v
+ cross_cos = []
+ for u in range(N):
+ for v in range(N):
+ if u == v:
+ continue
+ cos = np.dot(thetas_a[u], thetas_b[v]) / (
+ np.linalg.norm(thetas_a[u]) * np.linalg.norm(thetas_b[v]) + 1e-8)
+ cross_cos.append(cos)
+ avg_cross = np.mean(cross_cos)
+
+ delta_self = avg_self - avg_cross
+ return avg_self, avg_cross, delta_self
+
+
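+# Vectorized reference for the loop above (illustrative sketch only; it
+# normalizes each vector separately, so epsilon handling differs slightly):
+# Delta_self = E_u[cos(a_u, b_u)] - E_{u!=v}[cos(a_u, b_v)].
+def _delta_self_vectorized(thetas_a, thetas_b):
+    a = thetas_a / (np.linalg.norm(thetas_a, axis=1, keepdims=True) + 1e-8)
+    b = thetas_b / (np.linalg.norm(thetas_b, axis=1, keepdims=True) + 1e-8)
+    sim = a @ b.T  # sim[u, v] = cos(theta_a_u, theta_b_v)
+    avg_self = float(np.diag(sim).mean())
+    n = sim.shape[0]
+    avg_cross = float((sim.sum() - np.trace(sim)) / (n * (n - 1)))
+    return avg_self, avg_cross, avg_self - avg_cross
+
+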
+def compute_ridge_probe(thetas, style_protos):
+ """Probe: predict each style feature from theta using Ridge regression."""
+ results = {}
+ N = len(thetas)
+
+ for i, feat_name in enumerate(FEATURE_NAMES):
+ y = style_protos[:, i]
+
+ # Check if target has variance
+ if np.std(y) < 1e-8:
+ results[feat_name] = 0.0
+ continue
+
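+        # 5-fold CV when enough users are available; with very few users this
+        # degrades toward leave-one-out, where per-fold R^2 is ill-defined.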
+ ridge = Ridge(alpha=1.0)
+ scores = cross_val_score(ridge, thetas, y, cv=min(5, N), scoring='r2')
+ results[feat_name] = max(np.mean(scores), 0.0) # Clip negative R2 to 0
+
+ return results
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_eval', type=int, default=200)
+ parser.add_argument('--config', type=str, default='product_review_user')
+ args = parser.parse_args()
+
+ N = args.num_eval
+ print(f"=== Theta Analysis: {args.config}, N={N} ===")
+
+ print("\nLoading data...")
+ examples = load_longlamp(args.config, split='val')[:N]
+ print(f"Loaded {len(examples)} examples")
+
+ print("\nLoading model...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device='cuda:1')
+
+ # 1. Collect thetas and style prototypes
+ print("\n--- Collecting thetas and style prototypes ---")
+ thetas, style_protos, user_ids = collect_thetas_and_styles(wrapper, examples, K=4, seed=0)
+ print(f"Collected {len(thetas)} vectors")
+
+ # 2. RSA
+ print("\n--- RSA (Representational Similarity Analysis) ---")
+ rho_all, pval_all = compute_rsa(thetas, style_protos)
+ # Exclude length (index 0) and newline_rate (index 3)
+ rho_nolen, pval_nolen = compute_rsa(thetas, style_protos, exclude_indices=[0, 3])
+ print(f" rho_all: {rho_all:.4f} (p={pval_all:.2e})")
+ print(f" rho_-len/newline: {rho_nolen:.4f} (p={pval_nolen:.2e})")
+
+ # 3. Self-consistency
+ print("\n--- Self-Consistency ---")
+ avg_self, avg_cross, delta_self = compute_self_consistency(wrapper, examples, K=4)
+ print(f" avg_self_cos: {avg_self:.4f}")
+ print(f" avg_cross_cos: {avg_cross:.4f}")
+ print(f" Delta_self: {delta_self:.4f}")
+
+ # 4. Ridge probe
+ print("\n--- Ridge Probe (R^2) ---")
+ probe_results = compute_ridge_probe(thetas, style_protos)
+ for feat_name in FEATURE_NAMES:
+ r2 = probe_results[feat_name]
+ print(f" {feat_name:<20}: R^2 = {r2:.4f}")
+
+ # Summary: the 6 key numbers
+ print("\n" + "=" * 60)
+ print("KEY NUMBERS FOR PAPER DECISION")
+ print("=" * 60)
+ print(f" rho_all: {rho_all:.4f}")
+ print(f" rho_-len/newline: {rho_nolen:.4f}")
+ print(f" Delta_self: {delta_self:.4f}")
+ print(f" R^2_TTR: {probe_results.get('TTR', 0.0):.4f}")
+ print(f" R^2_first_person: {probe_results.get('first_person_rate', 0.0):.4f}")
+ print(f" R^2_newline: {probe_results.get('newline_rate', 0.0):.4f}")
+
+ # Save results
+ os.makedirs('outputs/analysis', exist_ok=True)
+ save_data = {
+ 'rsa_all': {'rho': float(rho_all), 'pval': float(pval_all)},
+ 'rsa_nolen': {'rho': float(rho_nolen), 'pval': float(pval_nolen)},
+ 'self_consistency': {'avg_self': float(avg_self), 'avg_cross': float(avg_cross), 'delta_self': float(delta_self)},
+ 'probe_r2': {k: float(v) for k, v in probe_results.items()},
+ 'num_users': len(thetas),
+ 'thetas': [[float(x) for x in row] for row in thetas],
+ 'style_protos': [[float(x) for x in row] for row in style_protos],
+ 'user_ids': user_ids,
+ }
+ with open('outputs/analysis/theta_analysis.json', 'w') as f:
+ json.dump(save_data, f, indent=2)
+ print("\nSaved to outputs/analysis/theta_analysis.json")
+
+
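+# Offline visualization sketch (hypothetical helper, not called by main()):
+# the saved JSON keeps the raw theta matrix, so the PCA view mentioned in the
+# module docstring can be produced separately, assuming scikit-learn.
+def pca_2d_from_saved(json_path='outputs/analysis/theta_analysis.json'):
+    from sklearn.decomposition import PCA
+    with open(json_path) as f:
+        data = json.load(f)
+    return PCA(n_components=2).fit_transform(np.array(data['thetas']))
+
+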
+if __name__ == '__main__':
+ main()