summaryrefslogtreecommitdiff
path: root/scripts/run_eval.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
committerYurenHao0426 <Blackhao0426@gmail.com>2026-02-16 14:44:42 -0600
commit09d50e47860da0035e178a442dc936028808a0b3 (patch)
tree9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/run_eval.py
parentc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff)
Add memory centering, grid search experiments, and energy visualizationsHEADmaster
- Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove centroid attractor in Hopfield dynamics - Add center flag to MemoryBankConfig, device field to PipelineConfig - Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings) - Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs uncentered dynamics - Experiment log (note.md) documenting 4 rounds of results and root cause analysis of centroid attractor problem - Key finding: β_critical ≈ 37.6 for centered memory; best configs beat FAISS baseline by +3-4% F1 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/run_eval.py')
-rw-r--r--scripts/run_eval.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
index 713b3c2..144fc2f 100644
--- a/scripts/run_eval.py
+++ b/scripts/run_eval.py
@@ -36,6 +36,7 @@ def main() -> None:
parser.add_argument("--split", type=str, default="validation")
parser.add_argument("--max-samples", type=int, default=500)
parser.add_argument("--output", type=str, default="results.json")
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -47,15 +48,16 @@ def main() -> None:
encoder=EncoderConfig(**cfg.get("encoder", {})),
generator=GeneratorConfig(**cfg.get("generator", {})),
retriever_type=cfg.get("retriever_type", "hopfield"),
+ device=args.device,
)
# Load memory bank
mb = MemoryBank(pipeline_config.memory)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
# Build pipeline
- encoder = Encoder(pipeline_config.encoder)
- generator = Generator(pipeline_config.generator)
+ encoder = Encoder(pipeline_config.encoder, device=args.device)
+ generator = Generator(pipeline_config.generator, device=args.device)
pipeline = RAGPipeline(
config=pipeline_config,
encoder=encoder,