diff options
Diffstat (limited to 'scripts/build_memory_bank.py')
| -rw-r--r-- | scripts/build_memory_bank.py | 72 |
1 files changed, 72 insertions, 0 deletions
# diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
# new file mode 100644, index 0000000..2aff828
"""Offline script: encode corpus passages into a memory bank.

Usage:
    python scripts/build_memory_bank.py --config configs/default.yaml --corpus data/corpus.jsonl --output data/memory_bank.pt
"""

import argparse
import json
import logging

import torch
import yaml
from tqdm import tqdm

from hag.config import EncoderConfig, MemoryBankConfig
from hag.encoder import Encoder
from hag.memory_bank import MemoryBank

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_corpus(path: str) -> list[str]:
    """Load passages from a JSONL file.

    Each non-blank line must be a JSON object with a ``"text"`` field.
    Blank lines (e.g. a trailing empty line) are skipped.

    Args:
        path: Path to the JSONL corpus file, read as UTF-8.

    Returns:
        The ``"text"`` value of every record, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"text"`` field.
    """
    passages = []
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            obj = json.loads(line)
            passages.append(obj["text"])
    return passages


def main() -> None:
    """Encode every corpus passage and save the resulting memory bank.

    Reads the YAML config given by ``--config``, loads passages from
    ``--corpus``, encodes them in batches of ``encoder_config.batch_size``,
    then builds a ``MemoryBank`` from the embeddings and saves it to
    ``--output``.

    Raises:
        ValueError: If the corpus contains no passages.
    """
    parser = argparse.ArgumentParser(description="Build memory bank from corpus")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--corpus", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    # NOTE(review): --device is parsed but never used below; it is kept so
    # existing invocations keep working. Wire it into the encoder if supported.
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        # An empty YAML document parses to None; fall back to an empty mapping.
        cfg = yaml.safe_load(f) or {}

    # A section that is present but null (``encoder:``) also yields None.
    encoder_config = EncoderConfig(**(cfg.get("encoder") or {}))
    memory_config = MemoryBankConfig(**(cfg.get("memory") or {}))

    # Load corpus
    logger.info("Loading corpus from %s", args.corpus)
    passages = load_corpus(args.corpus)
    logger.info("Loaded %d passages", len(passages))

    # Fail early with a clear message: torch.cat rejects an empty list, and
    # there is no point constructing the encoder for an empty corpus.
    if not passages:
        raise ValueError(f"No passages found in corpus: {args.corpus}")

    # Encode passages in batches
    encoder = Encoder(encoder_config)
    batch_size = encoder_config.batch_size
    all_embeddings = []

    for i in tqdm(range(0, len(passages), batch_size), desc="Encoding"):
        batch = passages[i : i + batch_size]
        emb = encoder.encode(batch)  # expected shape: (len(batch), d)
        # Move each batch to CPU so results are gathered in host memory.
        all_embeddings.append(emb.cpu())

    embeddings = torch.cat(all_embeddings, dim=0)  # expected shape: (N, d)
    logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)

    # Build and save memory bank
    mb = MemoryBank(memory_config)
    mb.build_from_embeddings(embeddings, passages)
    mb.save(args.output)
    logger.info("Memory bank saved to %s", args.output)


if __name__ == "__main__":
    main()
