diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-05 10:31:36 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-05 10:31:36 -0500 |
| commit | ea4a8f837e81b5e5fab6086cb3014c711c5e58e9 (patch) | |
| tree | 11638546dc91c97815e5bdab8fa0b587481d0a3c /baselines/dense_retrieval.py | |
| parent | 8fe28101366dd32562b8c5534d7fe359b252bdf3 (diff) | |
Add PEFT baselines, ICL baselines, profile-based, and unified pipeline
New baselines:
- baselines/peft_baseline.py: LoRA, Tiny LoRA, VeRA (per-user PEFT adaptation)
- baselines/dense_retrieval.py: Dense retrieval ICL (sentence-transformers)
- baselines/profile_based.py: LLM-generated user profile conditioned generation
New scripts:
- scripts/run_all_methods.py: Unified pipeline running all 9 methods with
per-method directory output structure (method/per_user.json)
- scripts/run_peft_baselines.py: PEFT-only evaluation (legacy)
- scripts/run_significance.py: Significance tests (UPH+Base per-user)
- scripts/run_uph_base_per_user.py: UPH+Base with full per-user data
- scripts/compute_bertscore.py: BERTScore from saved predictions
- scripts/significance_test.py: Standalone significance test framework
Updated .gitignore to exclude outputs/ directory.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'baselines/dense_retrieval.py')
| -rw-r--r-- | baselines/dense_retrieval.py | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/baselines/dense_retrieval.py b/baselines/dense_retrieval.py new file mode 100644 index 0000000..db403a8 --- /dev/null +++ b/baselines/dense_retrieval.py @@ -0,0 +1,42 @@ +"""Dense Retrieval ICL baseline. + +Uses sentence-transformers for dense retrieval over user support set, +then places top-K retrieved items as in-context examples. +""" + +import torch +from sentence_transformers import SentenceTransformer + + +class DenseRetriever: + """Dense retriever using sentence-transformers embeddings.""" + + def __init__(self, model_name='all-MiniLM-L6-v2', device='cpu'): + self.model = SentenceTransformer(model_name, device=device) + + def retrieve_top_k(self, query: str, support_items: list, k: int = 1): + """Retrieve top-k support items most relevant to query. + + Args: + query: query input text + support_items: list of dicts with 'support_input', 'support_output' + k: number of items to retrieve + + Returns: + List of top-k support items (sorted by relevance) + """ + if len(support_items) <= k: + return support_items + + texts = [item['support_input'] for item in support_items] + embeddings = self.model.encode([query] + texts, convert_to_tensor=True) + + query_emb = embeddings[0] + item_embs = embeddings[1:] + + similarities = torch.nn.functional.cosine_similarity( + query_emb.unsqueeze(0), item_embs, dim=1 + ) + + top_indices = similarities.argsort(descending=True)[:k].tolist() + return [support_items[i] for i in top_indices] |
