"""Offline script: encode corpus passages into a memory bank.

Usage:
    python scripts/build_memory_bank.py --config configs/default.yaml --corpus data/corpus.jsonl --output data/memory_bank.pt
"""

import argparse
import json
import logging

import torch
import yaml
from tqdm import tqdm

from hag.config import EncoderConfig, MemoryBankConfig
from hag.encoder import Encoder
from hag.memory_bank import MemoryBank

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def load_corpus(path: str) -> list[str]:
    """Load passages from a JSONL file.

    Each non-blank line must be a JSON object with a ``"text"`` field.
    Whitespace-only lines (e.g. a trailing newline at end of file) are skipped.

    Args:
        path: Path to the JSONL corpus file, read as UTF-8.

    Returns:
        The ``"text"`` value of every record, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has no ``"text"`` field.
    """
    passages: list[str] = []
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Blank lines would otherwise make json.loads raise.
                continue
            obj = json.loads(line)
            passages.append(obj["text"])
    return passages


def main() -> None:
    """Parse CLI arguments, encode the corpus, and save the resulting memory bank.

    Builds ``EncoderConfig`` / ``MemoryBankConfig`` from the ``encoder`` and
    ``memory`` sections of the YAML config, encodes the corpus in batches of
    ``encoder_config.batch_size``, concatenates the per-batch embeddings on the
    CPU, and saves a ``MemoryBank`` built from them to ``--output``.
    Exits with a usage error if the corpus contains no passages.
    """
    parser = argparse.ArgumentParser(description="Build memory bank from corpus")
    parser.add_argument("--config", type=str, default="configs/default.yaml")
    parser.add_argument("--corpus", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    # NOTE(review): --device is parsed but never forwarded to Encoder in this
    # script. Confirm how Encoder selects its device (possibly via
    # EncoderConfig) and wire this argument through.
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    with open(args.config, encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat that as "no settings".
        cfg = yaml.safe_load(f) or {}
    # `or {}` also covers a section key that is present but empty (None).
    encoder_config = EncoderConfig(**(cfg.get("encoder") or {}))
    memory_config = MemoryBankConfig(**(cfg.get("memory") or {}))

    # Load corpus
    logger.info("Loading corpus from %s", args.corpus)
    passages = load_corpus(args.corpus)
    logger.info("Loaded %d passages", len(passages))
    if not passages:
        # torch.cat below would otherwise fail with an opaque RuntimeError.
        parser.error(f"corpus {args.corpus!r} contains no passages")

    # Encode passages in batches
    encoder = Encoder(encoder_config)
    batch_size = encoder_config.batch_size
    all_embeddings = []
    # Inference only: disable autograd in case Encoder.encode does not already,
    # so no computation graph is kept alive across batches.
    with torch.no_grad():
        for start in tqdm(range(0, len(passages), batch_size), desc="Encoding"):
            batch = passages[start : start + batch_size]
            emb = encoder.encode(batch)  # presumably (len(batch), d) — TODO confirm
            # Collect on CPU so the concatenated tensor does not occupy device memory.
            all_embeddings.append(emb.cpu())
    embeddings = torch.cat(all_embeddings, dim=0)  # rows == len(passages)
    logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)

    # Build and save memory bank
    mb = MemoryBank(memory_config)
    mb.build_from_embeddings(embeddings, passages)
    mb.save(args.output)
    logger.info("Memory bank saved to %s", args.output)


if __name__ == "__main__":
    main()