summaryrefslogtreecommitdiff
path: root/diag/aggregate.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/aggregate.py')
-rw-r--r--diag/aggregate.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/diag/aggregate.py b/diag/aggregate.py
index b0f737a..4ded30f 100644
--- a/diag/aggregate.py
+++ b/diag/aggregate.py
@@ -1,4 +1,4 @@
-"""Aggregate multi-seed coloring results -> mean+/-std per (grad_mode, pe, contract)."""
+"""Aggregate multi-seed coloring results -> mean+/-std per architecture/config."""
import glob, json
import numpy as np
from collections import defaultdict
@@ -22,7 +22,8 @@ def load(pat):
def key(d):
- return (d.get('conv', 'gin'), d.get('pe'), d.get('grad_mode'), 'ctr' if d.get('contract') else '-')
+ return (d.get('arch', 'legacy'), d.get('conv', 'gin'), d.get('pe'),
+ d.get('grad_mode'), 'ctr' if d.get('contract') else '-')
solve, le, ml = defaultdict(list), defaultdict(list), defaultdict(list)