diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-13 00:22:32 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-13 00:22:32 -0500 |
| commit | 3b2b49845a256fcabc55af789562ca034bb69ebe (patch) | |
| tree | 6e3d82e8f4e248963509d091b31a9bf18f6517cc /scripts | |
| parent | bfdcc36c0e31adfa95410ce87e7da646e0b948fe (diff) | |
Fix VeRA crash: reload model fresh before each PEFT method
Root cause: get_peft_model() modifies model in-place. After LoRA/TinyLoRA
cleanup, the model's modules are altered so VeRA can't find target_modules.
Fix: reload AutoModelForCausalLM from scratch before each PEFT method.
Slower but reliable — no more cross-contamination between PEFT methods.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/run_all_methods.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/scripts/run_all_methods.py b/scripts/run_all_methods.py index 73a0a87..7bbcadf 100644 --- a/scripts/run_all_methods.py +++ b/scripts/run_all_methods.py @@ -22,6 +22,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from data.longlamp import load_longlamp, select_k_profile_items from data.templates import build_query_prompt, build_prompt_with_examples from data.style_features import compute_sfd, compute_feature_deltas +from transformers import AutoModelForCausalLM from models.qwen_wrapper import QwenWrapper from models.cvh import UnconditionalHead from adapt.cache_hidden import cache_support_hidden_states @@ -375,6 +376,18 @@ class MethodRunner: config, lr, desc, steps=30, jsonl_path=None, start_idx=0, existing=None): if existing is None: existing = [] + + # Reload model fresh to avoid contamination from previous PEFT methods + print(f" Reloading model for {desc}...") + self.wrapper.model = AutoModelForCausalLM.from_pretrained( + 'Qwen/Qwen2.5-1.5B-Instruct', + torch_dtype=torch.bfloat16, + trust_remote_code=True, + ).to(self.device) + self.wrapper.model.eval() + self.wrapper.lm_head_weight = self.wrapper.model.lm_head.weight.data + torch.cuda.empty_cache() + baseline = PEFTBaseline(self.wrapper, config) print(f" {desc}: {baseline.n_params:,} params ({baseline.n_bytes:,} bytes), steps={steps}, lr={lr}") @@ -401,7 +414,9 @@ class MethodRunner: avg_t = np.mean([u['adapt_time'] for u in per_user]) print(f" {i+1}/{N} (avg R-L: {avg_rl:.4f}, avg time: {avg_t:.1f}s)") - baseline.cleanup() + # No cleanup needed — model will be reloaded fresh for next PEFT method + del baseline + torch.cuda.empty_cache() return per_user |
