summaryrefslogtreecommitdiff
path: root/diag/train_cycle.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/train_cycle.py')
-rw-r--r--diag/train_cycle.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
index d2342f3..598e349 100644
--- a/diag/train_cycle.py
+++ b/diag/train_cycle.py
@@ -24,9 +24,14 @@ from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.nn import GINConv, GCNConv, global_add_pool
-ROOT = '/home/yurenh2/rrog/data/zinc'
-CACHE = '/home/yurenh2/rrog/data/cycle_cache'
-OUT = '/home/yurenh2/rrog/runs'
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+ROOT = os.path.join(DATA_ROOT, 'zinc')
+CACHE = os.path.join(DATA_ROOT, 'cycle_cache')
RWSE_K = 16