diff options
Diffstat (limited to 'scripts/analyze_energy.py')
| -rw-r--r-- | scripts/analyze_energy.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py index fd044a4..cd93b15 100644 --- a/scripts/analyze_energy.py +++ b/scripts/analyze_energy.py @@ -32,6 +32,7 @@ def main() -> None: parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--questions", type=str, required=True) parser.add_argument("--output", type=str, default="energy_analysis.json") + parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() with open(args.config) as f: @@ -43,13 +44,13 @@ def main() -> None: # Load memory bank mb = MemoryBank(memory_config) - mb.load(args.memory_bank) + mb.load(args.memory_bank, device=args.device) # Load questions with open(args.questions) as f: questions = [json.loads(line)["question"] for line in f] - encoder = Encoder(encoder_config) + encoder = Encoder(encoder_config, device=args.device) hopfield = HopfieldRetrieval(hopfield_config) analyses = [] |
