| field | value | date |
|---|---|---|
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-10 14:50:22 -0500 |
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-10 14:50:22 -0500 |
| commit | 112c5d354f36d6ea6e8049cf1aeaebeb9944aa02 (patch) | |
| tree | 7b56f786ee1450aa545caf565f039e7144ed6b7c /adapt/fit_theta.py | |
| parent | 26c899101dbb192981cc67d73fc00a2d158b503e (diff) | |
Fix two bugs: PEFT cleanup model corruption and K=16 OOM
Bug 1: PEFTBaseline.cleanup() corrupted wrapper.model after LoRA unload,
causing 'Qwen2Model has no attribute prepare_inputs_for_generation' for
subsequent methods. Fix: save reference to original model before wrapping,
restore it directly in cleanup() instead of relying on unload().
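A minimal sketch of that fix, assuming a wrapper object that exposes the HF model as `wrapper.model` as described above; `PEFTBaseline`, `cleanup()`, and `unload()` come from the commit message, while `setup()`, `_original_model`, and the constructor argument are illustrative names, and `LoraConfig`/`get_peft_model` are the standard PEFT APIs:

```python
from peft import LoraConfig, get_peft_model

class PEFTBaseline:
    def __init__(self, wrapper):
        self.wrapper = wrapper
        self._original_model = None  # reference to the unwrapped HF model

    def setup(self, lora_config: LoraConfig):
        # Save the original model BEFORE wrapping it with LoRA adapters.
        self._original_model = self.wrapper.model
        self.wrapper.model = get_peft_model(self.wrapper.model, lora_config)

    def cleanup(self):
        # Restore the saved reference directly. The buggy version relied on
        # unload(), which left wrapper.model pointing at a bare Qwen2Model
        # without prepare_inputs_for_generation.
        if self._original_model is not None:
            self.wrapper.model = self._original_model
            self._original_model = None
```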
Bug 2: fit_theta OOM at K=16 due to large logit chunks (128 × 151936 vocab).
Fix: reduce CHUNK_SIZE from 128 to 32 (~4x less memory per chunk).
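A rough back-of-the-envelope check of why the chunk size matters (the vocab size 151936 is from the commit message; fp32 logits and the number of same-shaped temporaries per chunk are assumptions):

```python
# Estimate the size of a single logit tensor per chunk. The actual peak is
# several times this (adapted and base logits, softmax/log-softmax buffers).
vocab = 151936
bytes_per_elem = 4  # assuming fp32 logits
for chunk in (128, 32):
    logits_mib = chunk * vocab * bytes_per_elem / 2**20
    print(f"CHUNK_SIZE={chunk:<4d} one logit tensor ≈ {logits_mib:.0f} MiB")
# CHUNK_SIZE=128 -> ~74 MiB per logit tensor; CHUNK_SIZE=32 -> ~19 MiB (≈4x less).
```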
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'adapt/fit_theta.py')
| -rw-r--r-- | adapt/fit_theta.py | 3 |
1 file changed, 2 insertions, 1 deletion
diff --git a/adapt/fit_theta.py b/adapt/fit_theta.py
index f5b047b..e8bff28 100644
--- a/adapt/fit_theta.py
+++ b/adapt/fit_theta.py
@@ -7,7 +7,8 @@
 import torch
 import torch.nn.functional as F

 # Maximum chunk size for logit computation to avoid OOM
-CHUNK_SIZE = 128
+# Reduced from 128 to 32 to handle K=16 (longer sequences)
+CHUNK_SIZE = 32

 def _chunked_ce_kl(h_prime, h_base, lm_w, lm_bias, y, beta):
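The diff only shows the constant and the signature of `_chunked_ce_kl`. For context, here is a hypothetical sketch of what a chunked CE+KL helper with that signature could look like, assuming `h_prime`/`h_base` are flattened `(tokens, hidden)` hidden states, `lm_w` is the `(vocab, hidden)` LM-head weight, `y` holds the target token ids, and `beta` weights the KL term; the actual body in the repository may differ:

```python
import torch
import torch.nn.functional as F

CHUNK_SIZE = 32  # reduced from 128 to survive K=16 sequences

def _chunked_ce_kl(h_prime, h_base, lm_w, lm_bias, y, beta):
    """Cross-entropy on adapted hidden states plus a beta-weighted KL term to
    the base model, computed CHUNK_SIZE tokens at a time to bound logit memory."""
    total_ce = h_prime.new_zeros(())
    total_kl = h_prime.new_zeros(())
    n = 0
    for start in range(0, h_prime.size(0), CHUNK_SIZE):
        hp = h_prime[start:start + CHUNK_SIZE]
        hb = h_base[start:start + CHUNK_SIZE]
        yt = y[start:start + CHUNK_SIZE]
        # Materialize logits only for this chunk: (chunk, vocab).
        logits_p = F.linear(hp, lm_w, lm_bias)
        logits_b = F.linear(hb, lm_w, lm_bias)
        total_ce = total_ce + F.cross_entropy(logits_p, yt, reduction="sum")
        total_kl = total_kl + F.kl_div(
            F.log_softmax(logits_p, dim=-1),
            F.log_softmax(logits_b, dim=-1),
            log_target=True, reduction="sum",
        )
        n += hp.size(0)
    return (total_ce + beta * total_kl) / max(n, 1)
```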
