summaryrefslogtreecommitdiff
path: root/trm/config/arch/hrm.yaml
blob: 0d183bed981d7196e76d9b17c9583c5ba06a976e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Architecture hyperparameters for the HRM model configuration.
# NOTE(review): the `name` values below look like `<module>@<class>`
# identifiers resolved by the code that loads this file; confirm against
# the loader before renaming anything.
name: recursive_reasoning.hrm@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

# Halting settings.
# NOTE(review): presumably an exploration probability and an upper bound on
# halting steps; verify exact semantics in the model code.
halt_exploration_prob: 0.1
halt_max_steps: 16

# Cycle counts for the H and L parts of the model.
H_cycles: 2
L_cycles: 2

# Layer counts for the H and L parts of the model.
H_layers: 4
L_layers: 4

hidden_size: 512
# 8 == hidden_size // 64 for hidden_size 512.
# NOTE(review): this was previously annotated as `min(2, hidden_size // 64)`,
# which evaluates to 2, not 8; `max(...)` was presumably intended. Keep this
# value in sync by hand if hidden_size changes.
num_heads: 8
expansion: 4

# Relative interpolation: resolves to the sibling `hidden_size` key (512 here),
# assuming an OmegaConf-style loader.
puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: false  # use mlp on L instead of transformer