path: root/baselines/dense_retrieval.py
author    YurenHao0426 <Blackhao0426@gmail.com> 2026-04-05 10:31:36 -0500
committer YurenHao0426 <Blackhao0426@gmail.com> 2026-04-05 10:31:36 -0500
commit    ea4a8f837e81b5e5fab6086cb3014c711c5e58e9 (patch)
tree      11638546dc91c97815e5bdab8fa0b587481d0a3c /baselines/dense_retrieval.py
parent    8fe28101366dd32562b8c5534d7fe359b252bdf3 (diff)
Add PEFT baselines, ICL baselines, profile-based, and unified pipeline
New baselines:
- baselines/peft_baseline.py: LoRA, Tiny LoRA, VeRA (per-user PEFT adaptation)
- baselines/dense_retrieval.py: Dense retrieval ICL (sentence-transformers)
- baselines/profile_based.py: LLM-generated user profile conditioned generation

New scripts:
- scripts/run_all_methods.py: Unified pipeline running all 9 methods with per-method directory output structure (method/per_user.json)
- scripts/run_peft_baselines.py: PEFT-only evaluation (legacy)
- scripts/run_significance.py: Significance tests (UPH+Base per-user)
- scripts/run_uph_base_per_user.py: UPH+Base with full per-user data
- scripts/compute_bertscore.py: BERTScore from saved predictions
- scripts/significance_test.py: Standalone significance test framework

Updated .gitignore to exclude outputs/ directory.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'baselines/dense_retrieval.py')
-rw-r--r--  baselines/dense_retrieval.py  42
1 file changed, 42 insertions, 0 deletions
diff --git a/baselines/dense_retrieval.py b/baselines/dense_retrieval.py
new file mode 100644
index 0000000..db403a8
--- /dev/null
+++ b/baselines/dense_retrieval.py
@@ -0,0 +1,42 @@
+"""Dense Retrieval ICL baseline.
+
+Uses sentence-transformers for dense retrieval over the user support set,
+then places top-K retrieved items as in-context examples.
+"""
+
+import torch
+from sentence_transformers import SentenceTransformer
+
+
+class DenseRetriever:
+    """Dense retriever using sentence-transformers embeddings."""
+
+    def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'):
+        self.model = SentenceTransformer(model_name, device=device)
+
+    def retrieve_top_k(self, query: str, support_items: list, k: int = 1):
+        """Retrieve the top-k support items most relevant to the query.
+
+        Args:
+            query: query input text
+            support_items: list of dicts with 'support_input', 'support_output'
+            k: number of items to retrieve
+
+        Returns:
+            List of top-k support items (sorted by relevance)
+        """
+        if len(support_items) <= k:
+            return support_items
+
+        texts = [item['support_input'] for item in support_items]
+        embeddings = self.model.encode([query] + texts, convert_to_tensor=True)
+
+        query_emb = embeddings[0]
+        item_embs = embeddings[1:]
+
+        similarities = torch.nn.functional.cosine_similarity(
+            query_emb.unsqueeze(0), item_embs, dim=1
+        )
+
+        top_indices = similarities.argsort(descending=True)[:k].tolist()
+        return [support_items[i] for i in top_indices]
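The ranking step in `retrieve_top_k` can be exercised in isolation with toy vectors. The sketch below assumes nothing about sentence-transformers: `top_k_by_cosine` is a hypothetical helper that mirrors only the cosine-similarity selection logic, with hard-coded 2-d embeddings standing in for real model output.

```python
import torch
import torch.nn.functional as F


def top_k_by_cosine(query_emb: torch.Tensor, item_embs: torch.Tensor, k: int) -> list:
    """Return indices of the k items most cosine-similar to the query.

    Mirrors the selection logic in DenseRetriever.retrieve_top_k:
    broadcast cosine similarity, then argsort descending and truncate.
    """
    sims = F.cosine_similarity(query_emb.unsqueeze(0), item_embs, dim=1)
    return sims.argsort(descending=True)[:k].tolist()


# Toy embeddings: item 2 is parallel to the query, item 1 nearly so,
# item 0 orthogonal (hard-coded stand-ins for encoder output).
query = torch.tensor([1.0, 0.0])
items = torch.tensor([[0.0, 1.0], [0.9, 0.1], [1.0, 0.0]])
print(top_k_by_cosine(query, items, k=2))  # → [2, 1]
```

Because ties and near-ties are resolved by `argsort`, the returned indices are already ordered best-first, which is why the diff's `retrieve_top_k` can index straight back into `support_items` without re-sorting.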