summaryrefslogtreecommitdiff
path: root/diag/train_cycle.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-06-29 12:04:47 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-06-29 12:04:47 -0500
commitc54ddb88b532be28ca3096e21de405d90163ecfa (patch)
tree3270ec9269dbee14ea915963f0d28e933303d5a7 /diag/train_cycle.py
parentd12722525fc010a3910b5152c72654a2ade5eac4 (diff)
Package full RRoG GNN project
Diffstat (limited to 'diag/train_cycle.py')
-rw-r--r--diag/train_cycle.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/diag/train_cycle.py b/diag/train_cycle.py
index d2342f3..598e349 100644
--- a/diag/train_cycle.py
+++ b/diag/train_cycle.py
@@ -24,9 +24,14 @@ from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.nn import GINConv, GCNConv, global_add_pool
-ROOT = '/home/yurenh2/rrog/data/zinc'
-CACHE = '/home/yurenh2/rrog/data/cycle_cache'
-OUT = '/home/yurenh2/rrog/runs'
+PROJECT_ROOT = os.environ.get(
+ 'RROG_ROOT',
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '..')),
+)
+DATA_ROOT = os.environ.get('RROG_DATA_DIR', os.path.join(PROJECT_ROOT, 'data'))
+OUT = os.environ.get('RROG_RUNS_DIR', os.path.join(PROJECT_ROOT, 'runs'))
+ROOT = os.path.join(DATA_ROOT, 'zinc')
+CACHE = os.path.join(DATA_ROOT, 'cycle_cache')
RWSE_K = 16