diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
| commit | 09d50e47860da0035e178a442dc936028808a0b3 (patch) | |
| tree | 9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/run_hag.py | |
| parent | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff) | |
- Add centering support to MemoryBank (center_query, apply_centering, mean
persistence in save/load) to remove centroid attractor in Hopfield dynamics
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids
with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap
comparing centered vs uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and root cause
analysis of centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; best configs beat
FAISS baseline by +3-4% F1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/run_hag.py')
| -rw-r--r-- | scripts/run_hag.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/scripts/run_hag.py b/scripts/run_hag.py index 4cacd1a..b6c9004 100644 --- a/scripts/run_hag.py +++ b/scripts/run_hag.py @@ -33,6 +33,7 @@ def main() -> None: parser.add_argument("--beta", type=float, default=None) parser.add_argument("--max-iter", type=int, default=None) parser.add_argument("--top-k", type=int, default=None) + parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() with open(args.config) as f: @@ -52,13 +53,14 @@ def main() -> None: encoder=EncoderConfig(**cfg.get("encoder", {})), generator=GeneratorConfig(**cfg.get("generator", {})), retriever_type="hopfield", + device=args.device, ) mb = MemoryBank(pipeline_config.memory) - mb.load(args.memory_bank) + mb.load(args.memory_bank, device=args.device) - encoder = Encoder(pipeline_config.encoder) - generator = Generator(pipeline_config.generator) + encoder = Encoder(pipeline_config.encoder, device=args.device) + generator = Generator(pipeline_config.generator, device=args.device) pipeline = RAGPipeline( config=pipeline_config, |
