summary | refs | log | tree | commit | diff
path: root/adapt/fit_theta.py
diff options
context:
space:
mode:
Diffstat (limited to 'adapt/fit_theta.py')
-rw-r--r-- adapt/fit_theta.py | 3
1 files changed, 2 insertions, 1 deletions
diff --git a/adapt/fit_theta.py b/adapt/fit_theta.py
index f5b047b..e8bff28 100644
--- a/adapt/fit_theta.py
+++ b/adapt/fit_theta.py
@@ -7,7 +7,8 @@ import torch
import torch.nn.functional as F
# Maximum chunk size for logit computation to avoid OOM
-CHUNK_SIZE = 128
+# Reduced from 128 to 32 to handle K=16 (longer sequences)
+CHUNK_SIZE = 32
def _chunked_ce_kl(h_prime, h_base, lm_w, lm_bias, y, beta):