summary | refs | log | tree | commit | diff
path: root/scripts/build_memory_bank.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/build_memory_bank.py')
-rw-r--r-- scripts/build_memory_bank.py 72
1 files changed, 72 insertions, 0 deletions
diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
new file mode 100644
index 0000000..2aff828
--- /dev/null
+++ b/scripts/build_memory_bank.py
@@ -0,0 +1,72 @@
+"""Offline script: encode corpus passages into a memory bank.
+
+Usage:
+ python scripts/build_memory_bank.py --config configs/default.yaml --corpus data/corpus.jsonl --output data/memory_bank.pt
+"""
+
+import argparse
+import json
+import logging
+
+import torch
+import yaml
+from tqdm import tqdm
+
+from hag.config import EncoderConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)  # configure root logging so INFO-level progress messages are emitted
+logger = logging.getLogger(__name__)  # module-level logger named after this module
+
+
def load_corpus(path: str) -> list[str]:
    """Load passages from a JSONL file.

    Each non-blank line must be a JSON object with a ``text`` field. Blank or
    whitespace-only lines (e.g. a trailing empty line) are skipped instead of
    being handed to the JSON parser.

    Args:
        path: Path to the JSONL corpus file, read as UTF-8.

    Returns:
        The ``text`` value of every record, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``text`` field.
    """
    passages: list[str] = []
    # Explicit UTF-8 so the result does not depend on the platform's locale.
    with open(path, encoding="utf-8") as f:
        for raw_line in f:
            record = raw_line.strip()
            if not record:
                # json.loads("") raises; blank lines carry no passage.
                continue
            passages.append(json.loads(record)["text"])
    return passages
+
+
def main() -> None:
    """Encode a corpus into embeddings and save them as a memory bank.

    Command-line arguments:
        --config: YAML config path (default ``configs/default.yaml``); its
            ``encoder`` and ``memory`` sections are passed as keyword
            arguments to ``EncoderConfig`` and ``MemoryBankConfig``.
        --corpus: JSONL corpus path (required).
        --output: Destination path for the saved memory bank (required).
        --device: Parsed but not used by this function.

    Raises:
        ValueError: If the corpus contains no passages.
    """
    parser = argparse.ArgumentParser(description="Build memory bank from corpus")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--corpus", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    # NOTE(review): --device is accepted but never forwarded anywhere below;
    # confirm whether Encoder should receive it.
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        # An empty YAML document parses to None; treat it as an empty mapping.
        cfg = yaml.safe_load(f) or {}

    # `or {}` also covers a section that is present but null (e.g. "encoder:").
    encoder_config = EncoderConfig(**(cfg.get("encoder") or {}))
    memory_config = MemoryBankConfig(**(cfg.get("memory") or {}))

    # Load corpus
    logger.info("Loading corpus from %s", args.corpus)
    passages = load_corpus(args.corpus)
    logger.info("Loaded %d passages", len(passages))
    if not passages:
        # Fail early with a clear message; torch.cat on an empty list would
        # otherwise raise an opaque error after the encoder was built.
        raise ValueError(f"Corpus {args.corpus} contains no passages")

    # Encode passages in batches
    encoder = Encoder(encoder_config)
    batch_size = encoder_config.batch_size
    all_embeddings = []

    # Gradients are not needed for an offline build; disabling autograd avoids
    # retaining graphs in case encode() does not already do so.
    with torch.no_grad():
        for start in tqdm(range(0, len(passages), batch_size), desc="Encoding"):
            batch = passages[start : start + batch_size]
            emb = encoder.encode(batch)  # (batch_size, d)
            all_embeddings.append(emb.cpu())

    embeddings = torch.cat(all_embeddings, dim=0)  # (N, d)
    logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)

    # Build and save memory bank
    mb = MemoryBank(memory_config)
    mb.build_from_embeddings(embeddings, passages)
    mb.save(args.output)
    logger.info("Memory bank saved to %s", args.output)
+
+if __name__ == "__main__":  # run main() only when executed as a script, not on import
+ main()