summaryrefslogtreecommitdiff
path: root/scripts/analyze_energy.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/analyze_energy.py')
-rw-r--r--scripts/analyze_energy.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py
index fd044a4..cd93b15 100644
--- a/scripts/analyze_energy.py
+++ b/scripts/analyze_energy.py
@@ -32,6 +32,7 @@ def main() -> None:
parser.add_argument("--memory-bank", type=str, required=True)
parser.add_argument("--questions", type=str, required=True)
parser.add_argument("--output", type=str, default="energy_analysis.json")
+ parser.add_argument("--device", type=str, default="cpu")
args = parser.parse_args()
with open(args.config) as f:
@@ -43,13 +44,13 @@ def main() -> None:
# Load memory bank
mb = MemoryBank(memory_config)
- mb.load(args.memory_bank)
+ mb.load(args.memory_bank, device=args.device)
# Load questions
with open(args.questions) as f:
questions = [json.loads(line)["question"] for line in f]
- encoder = Encoder(encoder_config)
+ encoder = Encoder(encoder_config, device=args.device)
hopfield = HopfieldRetrieval(hopfield_config)
analyses = []