summaryrefslogtreecommitdiff
path: root/resulets/scripts/run_all_methods.py
diff options
context:
space:
mode:
Diffstat (limited to 'resulets/scripts/run_all_methods.py')
-rw-r--r--resulets/scripts/run_all_methods.py908
1 files changed, 908 insertions, 0 deletions
diff --git a/resulets/scripts/run_all_methods.py b/resulets/scripts/run_all_methods.py
new file mode 100644
index 0000000..3809333
--- /dev/null
+++ b/resulets/scripts/run_all_methods.py
@@ -0,0 +1,908 @@
+"""Unified evaluation pipeline: all methods, all per-user data saved.
+
+CRASH-SAFE: Each example is appended to a JSONL file immediately after
+computation. If the process is killed, all completed examples are preserved.
+Already-complete methods are automatically skipped on re-run.
+
+Usage:
+ python scripts/run_all_methods.py --task review --setting user --device cuda:0
+ python scripts/run_all_methods.py --task review --setting user --methods base,uph,lora
+"""
+
+import sys
+import os
+import json
+import time
+import numpy as np
+import torch
+from scipy import stats
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from data.longlamp import load_longlamp, select_k_profile_items
+from data.templates import build_query_prompt, build_prompt_with_examples
+from data.style_features import compute_sfd, compute_feature_deltas
+from transformers import AutoModelForCausalLM
+from models.qwen_wrapper import QwenWrapper
+from models.cvh import CVHHead, LMHeadUpdate, UnconditionalHead
+from adapt.cache_hidden import cache_support_hidden_states
+from adapt.fit_theta import fit_theta
+from adapt.fit_theta_lm_head_update import fit_theta_lm_head_update
+from baselines.peft_baseline import (
+ PEFTBaseline, get_lora_config, get_tiny_lora_config, get_vera_config,
+ get_prompt_tuning_config, get_prefix_tuning_config,
+)
+from baselines.bm25_top1 import bm25_select_top1
+from baselines.dense_retrieval import (
+ DENSE_RETRIEVER_CONFIGS,
+ DenseRetriever,
+ get_dense_retriever_config,
+)
+from baselines.logit_bias import (
+ build_global_log_probs,
+ build_user_unigram_bias,
+ fit_sparse_logit_bias,
+ generate_with_logit_bias,
+)
+from baselines.profile_based import generate_profile, build_profile_conditioned_prompt
+from eval.metrics import compute_rouge, compute_meteor
+
+
+ALL_METHODS = [
+ 'base', 'uph', 'cvh', 'lm_head_update',
+ 'user_unigram_bias', 'learned_sparse_logit_bias',
+ 'prompt_all_k', 'bm25_top1', 'dense_top1',
+ 'dense_minilm_top1', 'dense_mpnet_top1', 'dense_e5_top1', 'dense_bge_top1',
+ 'profile_based',
+ 'lora', 'tiny_lora', 'vera',
+ 'prompt_tuning_5', 'prompt_tuning_10', 'prompt_tuning_20',
+ 'prefix_tuning_5', 'prefix_tuning_10',
+]
+
+
+def compute_per_user_metrics(pred, ref, support_texts):
+ r = compute_rouge([pred], [ref])
+ m = compute_meteor([pred], [ref])
+ p = pred if pred.strip() else "empty"
+ sfd_all = compute_sfd(p, support_texts, exclude_length=False)
+ sfd_nolen = compute_sfd(p, support_texts, exclude_length=True)
+ deltas = compute_feature_deltas(p, support_texts)
+ return {
+ 'rouge1': r['rouge1'],
+ 'rougeL': r['rougeL'],
+ 'meteor': m,
+ 'sfd_all': sfd_all,
+ 'sfd_nolen': sfd_nolen,
+ 'length': len(pred.split()),
+ 'feature_deltas': {k: v['delta'] for k, v in deltas.items()},
+ }
+
+
+def generate_greedy(wrapper, prompt, max_new_tokens=512, min_new_tokens=128):
+ chat_messages = [
+ {"role": "system", "content": "You are a helpful writing assistant."},
+ {"role": "user", "content": prompt},
+ ]
+ prompt_text = wrapper.tokenizer.apply_chat_template(
+ chat_messages, tokenize=False, add_generation_prompt=True
+ )
+ input_ids = wrapper.tokenizer.encode(prompt_text, return_tensors="pt").to(wrapper.device)
+ with torch.no_grad():
+ outputs = wrapper.model.generate(
+ input_ids,
+ max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens,
+ temperature=None, top_p=None, do_sample=False,
+ pad_token_id=wrapper.tokenizer.pad_token_id,
+ )
+ return wrapper.tokenizer.decode(outputs[0, input_ids.shape[1]:], skip_special_tokens=True)
+
+
+# ─── Incremental saving ──────────────────────────────────────────────
+
+def get_method_dir(output_dir, task, setting, K, method_name, d=64):
+ """Get the output directory for a method."""
+ exp_dir = os.path.join(output_dir, f"{task}_{setting}_K{K}")
+ method_label = f"uph_d{d}" if method_name == 'uph' and d != 64 else method_name
+ return os.path.join(exp_dir, method_label), method_label
+
+
+def is_method_complete(method_dir, N):
+ """Check if a method already has a complete per_user.json."""
+ path = os.path.join(method_dir, 'per_user.json')
+ if not os.path.exists(path):
+ return False
+ try:
+ with open(path) as f:
+ data = json.load(f)
+ return len(data.get('per_user', [])) >= N
+ except:
+ return False
+
+
+def append_jsonl(path, entry):
+ """Append one JSON entry to a JSONL file (crash-safe)."""
+ with open(path, 'a') as f:
+ f.write(json.dumps(entry, default=str) + '\n')
+
+
+def read_jsonl(path):
+ """Read all entries from a JSONL file."""
+ entries = []
+ if os.path.exists(path):
+ with open(path) as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ entries.append(json.loads(line))
+ return entries
+
+
+def finalize_method(method_dir, method_label, per_user, task, setting, K, d=64):
+ """Write final per_user.json from completed per-user list."""
+ agg = {
+ 'rougeL': float(np.mean([u['metrics']['rougeL'] for u in per_user])),
+ 'meteor': float(np.mean([u['metrics']['meteor'] for u in per_user])),
+ 'sfd_nolen': float(np.mean([u['metrics']['sfd_nolen'] for u in per_user])),
+ 'avg_len': float(np.mean([u['metrics']['length'] for u in per_user])),
+ }
+ save_data = {
+ 'per_user': per_user,
+ 'aggregate': agg,
+ 'num_examples': len(per_user),
+ 'task': task, 'setting': setting, 'K': K,
+ 'method': method_label,
+ 'decode_policy': 'greedy, min=128, max=512',
+ }
+ if 'uph' in method_label:
+ save_data['d'] = d
+ path = os.path.join(method_dir, 'per_user.json')
+ with open(path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=str)
+ print(f" Saved: {path} ({len(per_user)} examples)")
+
+
+# ─── Method runners ──────────────────────────────────────────────────
+
+class MethodRunner:
+ def __init__(
+ self,
+ wrapper,
+ device,
+ dense_retriever=None,
+ uph_d=64,
+ bias_top_m=512,
+ unigram_scale=0.5,
+ sparse_bias_lr=0.05,
+ sparse_bias_steps=30,
+ ):
+ self.wrapper = wrapper
+ self.device = device
+ self.dense_retriever = dense_retriever
+ self.dense_retrievers = {}
+ self.uph_d = uph_d
+ self.bias_top_m = bias_top_m
+ self.unigram_scale = unigram_scale
+ self.sparse_bias_lr = sparse_bias_lr
+ self.sparse_bias_steps = sparse_bias_steps
+
+ def _make_entry(self, ex, ref, stexts, K, pred, timing, extra=None):
+ metrics = compute_per_user_metrics(pred, ref, stexts)
+ entry = {
+ 'example_id': ex['example_id'],
+ 'user_id': ex['user_id'],
+ 'prediction': pred,
+ 'reference': ref,
+ 'support_texts': stexts,
+ 'K': K,
+ 'metrics': metrics,
+ **timing,
+ }
+ if extra:
+ entry.update(extra)
+ return entry
+
+ def run(self, method_name, examples, support_sets, references, support_texts,
+ N, method_dir, method_label, task, setting, K, d=64):
+ """Run a method with incremental JSONL saving. Returns per_user list."""
+
+ dispatch = {
+ 'base': self._run_base,
+ 'uph': self._run_uph,
+ 'cvh': self._run_cvh,
+ 'lm_head_update': self._run_lm_head_update,
+ 'user_unigram_bias': self._run_user_unigram_bias,
+ 'learned_sparse_logit_bias': self._run_learned_sparse_logit_bias,
+ 'prompt_all_k': self._run_prompt_all_k,
+ 'bm25_top1': self._run_bm25_top1,
+ 'dense_top1': self._run_dense_top1,
+ 'profile_based': self._run_profile_based,
+ 'lora': lambda *a, **kw: self._run_peft(*a, config=get_lora_config(rank=8), lr=1e-4, desc='LoRA r=8', **kw),
+ 'tiny_lora': lambda *a, **kw: self._run_peft(*a, config=get_tiny_lora_config(rank=1), lr=1e-4, desc='Tiny LoRA r=1', **kw),
+ 'vera': lambda *a, **kw: self._run_peft(*a, config=get_vera_config(rank=256), lr=1e-3, desc='VeRA r=256', **kw),
+ 'prompt_tuning_5': lambda *a, **kw: self._run_peft(*a, config=get_prompt_tuning_config(5), lr=1e-3, desc='PromptTuning L=5', steps=100, **kw),
+ 'prompt_tuning_10': lambda *a, **kw: self._run_peft(*a, config=get_prompt_tuning_config(10), lr=1e-3, desc='PromptTuning L=10', steps=100, **kw),
+ 'prompt_tuning_20': lambda *a, **kw: self._run_peft(*a, config=get_prompt_tuning_config(20), lr=1e-3, desc='PromptTuning L=20', steps=100, **kw),
+ 'prefix_tuning_5': lambda *a, **kw: self._run_peft(*a, config=get_prefix_tuning_config(5), lr=5e-4, desc='PrefixTuning L=5', steps=100, **kw),
+ 'prefix_tuning_10': lambda *a, **kw: self._run_peft(*a, config=get_prefix_tuning_config(10), lr=5e-4, desc='PrefixTuning L=10', steps=100, **kw),
+ }
+
+ if method_name not in dispatch:
+ if method_name in DENSE_RETRIEVER_CONFIGS:
+ run_fn = lambda *a, **kw: self._run_dense_configured(method_name, *a, **kw)
+ else:
+ print(f"Unknown method: {method_name}")
+ return []
+ else:
+ run_fn = dispatch[method_name]
+
+ os.makedirs(method_dir, exist_ok=True)
+ jsonl_path = os.path.join(method_dir, 'progress.jsonl')
+
+ # Resume: check how many examples already done
+ existing = read_jsonl(jsonl_path)
+ start_idx = len(existing)
+
+ if start_idx >= N:
+ print(f"\n--- {method_name} --- SKIPPED (already {start_idx}/{N} done)")
+ per_user = existing[:N]
+ else:
+ if start_idx > 0:
+ print(f"\n--- {method_name} --- RESUMING from {start_idx}/{N}")
+ else:
+ print(f"\n--- {method_name} ---")
+
+ per_user = run_fn(
+ examples, support_sets, references, support_texts, N,
+ jsonl_path=jsonl_path, start_idx=start_idx, existing=existing,
+ )
+
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ avg_sfd = np.mean([u['metrics']['sfd_nolen'] for u in per_user])
+ print(f" Mean R-L: {avg_rl:.4f}, SFD_-len: {avg_sfd:.4f}")
+
+ # Write final per_user.json
+ finalize_method(method_dir, method_label, per_user, task, setting, K, d)
+ return per_user
+
+ # --- Individual method runners ---
+ # All accept jsonl_path, start_idx, existing for resume support
+
+ def _run_base(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex = examples[i]
+ t0 = time.time()
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support_sets[i]),
+ pred, {'gen_time': time.time() - t0}
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ return per_user
+
+ def _run_prompt_all_k(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ prompt = build_prompt_with_examples(ex['query_input'], support, ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'gen_time': time.time() - t0}
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ return per_user
+
+ def _run_bm25_top1(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ selected = bm25_select_top1(ex['query_input'], support)
+ prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'gen_time': time.time() - t0}
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ return per_user
+
+ def _run_dense_top1(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ if self.dense_retriever is None:
+ self.dense_retriever = DenseRetriever(
+ model_name='sentence-transformers/all-MiniLM-L6-v2',
+ device='cpu',
+ text_mode='input',
+ normalize_embeddings=True,
+ )
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ selected, retrieval = self.dense_retriever.retrieve_top_k(
+ ex['query_input'], support, k=1, return_metadata=True
+ )
+ prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'gen_time': time.time() - t0},
+ extra={
+ 'retriever_model': self.dense_retriever.model_name,
+ 'retrieval_text_mode': self.dense_retriever.text_mode,
+ 'retrieval': retrieval,
+ },
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ return per_user
+
+ def _get_dense_retriever(self, config):
+ key = (
+ config.model_name,
+ config.text_mode,
+ config.query_prefix,
+ config.passage_prefix,
+ config.normalize_embeddings,
+ )
+ if key not in self.dense_retrievers:
+ self.dense_retrievers[key] = DenseRetriever(
+ model_name=config.model_name,
+ device='cpu',
+ text_mode=config.text_mode,
+ query_prefix=config.query_prefix,
+ passage_prefix=config.passage_prefix,
+ normalize_embeddings=config.normalize_embeddings,
+ )
+ return self.dense_retrievers[key]
+
+ def _run_dense_configured(self, method_name, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ config = get_dense_retriever_config(method_name)
+ retriever = self._get_dense_retriever(config)
+ print(
+ f" Dense retriever: {config.model_name}, "
+ f"text_mode={config.text_mode}, year={config.citation_year}"
+ )
+
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ selected, retrieval = retriever.retrieve_top_k(
+ ex['query_input'], support, k=1, return_metadata=True
+ )
+ prompt = build_prompt_with_examples(ex['query_input'], selected, ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'gen_time': time.time() - t0},
+ extra={
+ 'retriever_model': config.model_name,
+ 'retrieval_text_mode': config.text_mode,
+ 'retriever_year': config.citation_year,
+ 'retriever_description': config.description,
+ 'retrieval': retrieval,
+ },
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+ return per_user
+
+ def _run_profile_based(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ profile = generate_profile(self.wrapper, support, ex['task'])
+ prompt = build_profile_conditioned_prompt(ex['query_input'], profile, ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'gen_time': time.time() - t0},
+ extra={'profile_summary': profile},
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ print(f" {i+1}/{N}")
+ return per_user
+
+ def _run_uph(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ d = self.uph_d
+ H = self.wrapper.hidden_size
+ uncond = UnconditionalHead(H, d=d, alpha=0.1, basis_seed=42).to(self.device)
+ print(f" UPH d={d}, params={d}, bytes={d*2}")
+ lm_head_bias = None
+ if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None:
+ lm_head_bias = self.wrapper.model.lm_head.bias.data
+
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(self.wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ else:
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=self.wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=uncond,
+ d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=self.device,
+ )
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ delta_h = uncond.alpha * (uncond.U.float() @ theta.to(self.device).float())
+ logit_bias = 0.5 * torch.mv(self.wrapper.lm_head_weight.float(), delta_h)
+ pred = generate_with_logit_bias(
+ self.wrapper,
+ prompt,
+ logit_bias.detach().cpu(),
+ max_new_tokens=512,
+ min_new_tokens=128,
+ temperature=0.0,
+ )
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'adapt_time': time.time() - t0}
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+ return per_user
+
+ def _run_cvh(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ d = self.uph_d
+ H = self.wrapper.hidden_size
+ cvh = CVHHead(H, d=d, alpha=0.1, basis_seed=42).to(self.device)
+ print(f" CVH d={d}, params={d}, bytes={d*2}")
+ lm_head_bias = None
+ if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None:
+ lm_head_bias = self.wrapper.model.lm_head.bias.data
+
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(self.wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ else:
+ theta = fit_theta(
+ cached_h=cached_h,
+ lm_head_weight=self.wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_module=cvh,
+ d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ max_grad_norm=5.0, device=self.device,
+ )
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = self.wrapper.generate_with_head_blended(
+ prompt, theta, cvh.forward_fn,
+ blend_gamma=0.5, max_new_tokens=512,
+ min_new_tokens=128, temperature=0.0,
+ )
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'adapt_time': time.time() - t0}
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+ return per_user
+
+ def _run_lm_head_update(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ d = self.uph_d
+ H = self.wrapper.hidden_size
+ vocab_size = self.wrapper.lm_head_weight.shape[0]
+ head_update = LMHeadUpdate(H, vocab_size, d=d, alpha=0.1, basis_seed=42).to(self.device)
+ print(
+ f" LM-head update d={d}, user params={d}, "
+ f"fixed basis params={H*d + vocab_size*d}, bytes={d*2}"
+ )
+ lm_head_bias = None
+ if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None:
+ lm_head_bias = self.wrapper.model.lm_head.bias.data
+
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ cached_h = cache_support_hidden_states(self.wrapper, support, ex['task'])
+ if not cached_h:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ else:
+ theta = fit_theta_lm_head_update(
+ cached_h=cached_h,
+ lm_head_weight=self.wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ head_update=head_update,
+ d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+ blend_gamma=0.5, max_grad_norm=5.0, device=self.device,
+ )
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = self.wrapper.generate_with_lm_head_update(
+ prompt, theta, head_update,
+ blend_gamma=0.5, max_new_tokens=512,
+ min_new_tokens=128, temperature=0.0,
+ )
+ del cached_h, theta
+ torch.cuda.empty_cache()
+
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'adapt_time': time.time() - t0},
+ extra={
+ 'update_form': 'W + gamma * alpha * C diag(theta) A',
+ 'blend_gamma': 0.5,
+ },
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+ return per_user
+
+ def _run_user_unigram_bias(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ print(f" User-Unigram Bias top_m={self.bias_top_m}, scale={self.unigram_scale}")
+ vocab_size = self.wrapper.lm_head_weight.shape[0]
+ global_log_probs = build_global_log_probs(
+ self.wrapper.tokenizer, support_sets[:N], smoothing=0.1, vocab_size=vocab_size
+ )
+
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ bias, token_ids = build_user_unigram_bias(
+ self.wrapper.tokenizer,
+ support,
+ global_log_probs,
+ vocab_size=vocab_size,
+ top_m=self.bias_top_m,
+ scale=self.unigram_scale,
+ smoothing=0.1,
+ )
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_with_logit_bias(
+ self.wrapper, prompt, bias,
+ max_new_tokens=512, min_new_tokens=128, temperature=0.0,
+ )
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'gen_time': time.time() - t0},
+ extra={'bias_top_m': self.bias_top_m, 'bias_tokens': len(token_ids),
+ 'unigram_scale': self.unigram_scale},
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+ return per_user
+
+ def _run_learned_sparse_logit_bias(self, examples, support_sets, references, support_texts, N,
+ jsonl_path, start_idx, existing):
+ print(
+ f" Learned Sparse Logit Bias top_m={self.bias_top_m}, "
+ f"steps={self.sparse_bias_steps}, lr={self.sparse_bias_lr}"
+ )
+ vocab_size = self.wrapper.lm_head_weight.shape[0]
+ global_log_probs = build_global_log_probs(
+ self.wrapper.tokenizer, support_sets[:N], smoothing=0.1, vocab_size=vocab_size
+ )
+ lm_head_bias = None
+ if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None:
+ lm_head_bias = self.wrapper.model.lm_head.bias.data
+
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ init_bias, token_ids = build_user_unigram_bias(
+ self.wrapper.tokenizer,
+ support,
+ global_log_probs,
+ vocab_size=vocab_size,
+ top_m=self.bias_top_m,
+ scale=0.0,
+ smoothing=0.1,
+ )
+ cached_h = cache_support_hidden_states(self.wrapper, support, ex['task'])
+ if not cached_h or not token_ids:
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_greedy(self.wrapper, prompt)
+ n_bias = 0
+ else:
+ learned_bias, n_bias = fit_sparse_logit_bias(
+ cached_h=cached_h,
+ lm_head_weight=self.wrapper.lm_head_weight,
+ lm_head_bias=lm_head_bias,
+ token_ids=token_ids,
+ vocab_size=vocab_size,
+ init_values=None,
+ lr=self.sparse_bias_lr,
+ steps=self.sparse_bias_steps,
+ beta=0.05,
+ lam=1e-4,
+ max_grad_norm=5.0,
+ device=self.device,
+ )
+ prompt = build_query_prompt(ex['query_input'], ex['task'])
+ pred = generate_with_logit_bias(
+ self.wrapper, prompt, learned_bias,
+ max_new_tokens=512, min_new_tokens=128, temperature=0.0,
+ )
+ del cached_h, learned_bias
+ torch.cuda.empty_cache()
+
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'adapt_time': time.time() - t0},
+ extra={'bias_top_m': self.bias_top_m, 'bias_tokens': n_bias,
+ 'sparse_bias_steps': self.sparse_bias_steps,
+ 'sparse_bias_lr': self.sparse_bias_lr},
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 40 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f})")
+ return per_user
+
+ def _run_peft(self, examples, support_sets, references, support_texts, N,
+ config, lr, desc, steps=30, jsonl_path=None, start_idx=0, existing=None):
+ if existing is None:
+ existing = []
+
+ # Reload model fresh to avoid contamination from previous PEFT methods
+ print(f" Reloading model for {desc}...")
+ self.wrapper.model = AutoModelForCausalLM.from_pretrained(
+ 'Qwen/Qwen2.5-1.5B-Instruct',
+ torch_dtype=torch.bfloat16,
+ trust_remote_code=True,
+ ).to(self.device)
+ self.wrapper.model.eval()
+ self.wrapper.lm_head_weight = self.wrapper.model.lm_head.weight.data
+ torch.cuda.empty_cache()
+
+ baseline = PEFTBaseline(self.wrapper, config)
+ print(f" {desc}: {baseline.n_params:,} params ({baseline.n_bytes:,} bytes), steps={steps}, lr={lr}")
+
+ per_user = list(existing)
+ for i in range(start_idx, N):
+ ex, support = examples[i], support_sets[i]
+ t0 = time.time()
+ pred = baseline.adapt_and_generate(
+ support_items=support,
+ query_input=ex['query_input'],
+ task=ex['task'],
+ lr=lr, steps=steps,
+ max_new_tokens=512, min_new_tokens=128,
+ )
+ entry = self._make_entry(
+ ex, references[i], support_texts[i], len(support),
+ pred, {'adapt_time': time.time() - t0},
+ extra={'n_params': baseline.n_params, 'n_bytes': baseline.n_bytes},
+ )
+ per_user.append(entry)
+ append_jsonl(jsonl_path, entry)
+ if (i + 1) % 20 == 0:
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in per_user])
+ avg_t = np.mean([u['adapt_time'] for u in per_user])
+ print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f}, avg time: {avg_t:.1f}s)")
+
+ # No cleanup needed — model will be reloaded fresh for next PEFT method
+ del baseline
+ torch.cuda.empty_cache()
+ return per_user
+
+
+# ─── Main ────────────────────────────────────────────────────────────
+
+def paired_test(scores_a, scores_b, name_a, name_b, metric_name):
+ a, b = np.array(scores_a), np.array(scores_b)
+ diff = a - b
+ mean_diff = np.mean(diff)
+ t_stat, t_pval = stats.ttest_rel(a, b)
+ try:
+ w_stat, w_pval = stats.wilcoxon(a, b)
+ except ValueError:
+ w_stat, w_pval = float('nan'), float('nan')
+ se = stats.sem(diff)
+ ci_low, ci_high = mean_diff - 1.96 * se, mean_diff + 1.96 * se
+ return {
+ 'mean_a': float(np.mean(a)), 'mean_b': float(np.mean(b)),
+ 'mean_diff': float(mean_diff),
+ 'ci_low': float(ci_low), 'ci_high': float(ci_high),
+ 't_pval': float(t_pval), 'w_pval': float(w_pval),
+ }
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_eval', type=int, default=200)
+ parser.add_argument('--task', type=str, default='review', choices=['review', 'topic'])
+ parser.add_argument('--setting', type=str, default='user', choices=['user', 'temporal'])
+ parser.add_argument('--methods', type=str, default='all',
+ help='Comma-separated methods or "all"')
+ parser.add_argument('--device', type=str, default='cuda:0')
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--d', type=int, default=64, help='UPH theta dimension')
+ parser.add_argument('--output_dir', type=str, default='outputs/unified')
+ parser.add_argument('--bias_top_m', type=int, default=512,
+ help='Number of user-specific tokens for logit-bias baselines')
+ parser.add_argument('--unigram_scale', type=float, default=0.5,
+ help='Scale for zero-training user unigram logit bias')
+ parser.add_argument('--sparse_bias_lr', type=float, default=0.05,
+ help='Learning rate for learned sparse logit-bias baseline')
+ parser.add_argument('--sparse_bias_steps', type=int, default=30,
+ help='Adaptation steps for learned sparse logit-bias baseline')
+ args = parser.parse_args()
+
+ N = args.num_eval
+ task = args.task
+ setting = args.setting
+ K = args.K
+
+ config_map = {
+ ('review', 'user'): 'product_review_user',
+ ('review', 'temporal'): 'product_review_temporal',
+ ('topic', 'user'): 'topic_writing_user',
+ ('topic', 'temporal'): 'topic_writing_temporal',
+ }
+ config_name = config_map[(task, setting)]
+
+ if args.methods == 'all':
+ methods = ALL_METHODS
+ else:
+ methods = [m.strip() for m in args.methods.split(',')]
+
+ print(f"=== Unified Eval: {task}_{setting}, N={N}, K={K}, d={args.d} ===")
+ print(f"Methods: {methods}")
+ print(f"Decode: greedy, min=128, max=512")
+
+ print("\nLoading data...")
+ examples = load_longlamp(config_name, split='val')[:N]
+ support_sets = [select_k_profile_items(ex['profile_items'], K, seed=0) for ex in examples]
+ references = [ex['target_output'] for ex in examples]
+ support_texts = [[s['support_output'] for s in ss] for ss in support_sets]
+
+ print(f"Loading model on {args.device}...")
+ wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device)
+
+ runner = MethodRunner(
+ wrapper,
+ args.device,
+ uph_d=args.d,
+ bias_top_m=args.bias_top_m,
+ unigram_scale=args.unigram_scale,
+ sparse_bias_lr=args.sparse_bias_lr,
+ sparse_bias_steps=args.sparse_bias_steps,
+ )
+ all_per_user = {}
+
+ for method in methods:
+ method_dir, method_label = get_method_dir(
+ args.output_dir, task, setting, K, method, args.d
+ )
+
+ # Skip if already complete
+ if is_method_complete(method_dir, N):
+ print(f"\n--- {method} --- COMPLETE (loading from disk)")
+ with open(os.path.join(method_dir, 'per_user.json')) as f:
+ data = json.load(f)
+ all_per_user[method] = data['per_user'][:N]
+ avg_rl = np.mean([u['metrics']['rougeL'] for u in all_per_user[method]])
+ print(f" Mean R-L: {avg_rl:.4f}")
+ continue
+
+ per_user = runner.run(
+ method, examples, support_sets, references, support_texts,
+ N, method_dir, method_label, task, setting, K, args.d,
+ )
+ all_per_user[method] = per_user
+
+ # Summary table
+ print("\n" + "=" * 90)
+ print(f"{'Method':<15} {'R-L':<8} {'METEOR':<8} {'SFD_-len':<9} {'Len':<6}")
+ print("-" * 90)
+ for method in methods:
+ if method not in all_per_user:
+ continue
+ pu = all_per_user[method]
+ rl = np.mean([u['metrics']['rougeL'] for u in pu])
+ mt = np.mean([u['metrics']['meteor'] for u in pu])
+ sf = np.mean([u['metrics']['sfd_nolen'] for u in pu])
+ ln = np.mean([u['metrics']['length'] for u in pu])
+ print(f"{method:<15} {rl:<8.4f} {mt:<8.4f} {sf:<9.4f} {ln:<6.0f}")
+
+ # Significance tests (UPH vs all others)
+ sig_results = {}
+ if 'uph' in all_per_user:
+ print("\n" + "=" * 90)
+ print("Significance (UPH vs each, paired t-test p-value)")
+ print("=" * 90)
+ uph_rl = [u['metrics']['rougeL'] for u in all_per_user['uph']]
+ uph_sf = [u['metrics']['sfd_nolen'] for u in all_per_user['uph']]
+ for method in methods:
+ if method == 'uph' or method not in all_per_user:
+ continue
+ other_rl = [u['metrics']['rougeL'] for u in all_per_user[method]]
+ other_sf = [u['metrics']['sfd_nolen'] for u in all_per_user[method]]
+ rl_t = paired_test(uph_rl, other_rl, 'uph', method, 'R-L')
+ sf_t = paired_test(uph_sf, other_sf, 'uph', method, 'SFD')
+ sig_results[method] = {'rougeL': rl_t, 'sfd_nolen': sf_t}
+ print(f" vs {method:<12} R-L: diff={rl_t['mean_diff']:+.4f} p={rl_t['t_pval']:.2e} "
+ f"SFD: diff={sf_t['mean_diff']:+.4f} p={sf_t['t_pval']:.2e}")
+
+ # Save summary
+ exp_dir = os.path.join(args.output_dir, f"{task}_{setting}_K{K}")
+ summary = {}
+ for method in methods:
+ if method not in all_per_user:
+ continue
+ pu = all_per_user[method]
+ summary[method] = {
+ 'rougeL': float(np.mean([u['metrics']['rougeL'] for u in pu])),
+ 'meteor': float(np.mean([u['metrics']['meteor'] for u in pu])),
+ 'sfd_nolen': float(np.mean([u['metrics']['sfd_nolen'] for u in pu])),
+ 'avg_len': float(np.mean([u['metrics']['length'] for u in pu])),
+ }
+ summary_path = os.path.join(exp_dir, 'summary.json')
+ with open(summary_path, 'w') as f:
+ json.dump({
+ 'aggregate': summary,
+ 'significance': sig_results,
+ 'num_examples': N, 'task': task, 'setting': setting, 'K': K,
+ 'methods': methods,
+ }, f, indent=2, default=str)
+
+ print(f"\nSummary: {summary_path}")
+
+
+if __name__ == '__main__':
+ main()