| author | Zitian Gao <zitian.gao@outlook.com> | 2025-05-29 21:10:56 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-29 21:10:56 +0800 |
| commit | 6bd72b87d9d9e6b2832e1799b403b23202e54204 (patch) | |
| tree | 60f1b00e6c66512ab01cf3ecb7ba17f6e006cc20 /train.py | |
| parent | a82af6165d54d4e162d6301fc211f8728fc2a442 (diff) | |
update code format style
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 31 |
1 file changed, 23 insertions(+), 8 deletions(-)
```diff
@@ -91,15 +91,23 @@ def main():
         "train_batch_size": eff_bs,
         "gradient_accumulation_steps": accum_steps,
         "bf16": {"enabled": True},
-        "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}, "offload_param": {"device": "none"}},
+        "zero_optimization": {
+            "stage": 2,
+            "offload_optimizer": {"device": "cpu"},
+            "offload_param": {"device": "none"}
+        },
         "gradient_clipping": 1.0,
     }
 
-    accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=accum_steps, deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+    accelerator = Accelerator(mixed_precision="bf16",
+                              gradient_accumulation_steps=accum_steps,
+                              deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
     print = accelerator.print
 
     model_path = args.model_path or f"/volume/ailab4sci/models/{args.model_name}"
-    config = AutoConfig.from_pretrained(model_path); config.use_cache = False
-    model = AutoModelForCausalLM.from_pretrained(model_path, config=config); model.gradient_checkpointing_enable()
+    config = AutoConfig.from_pretrained(model_path)
+    config.use_cache = False
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
+    model.gradient_checkpointing_enable()
     tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
     tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
@@ -121,7 +129,11 @@ def main():
             break
 
         with accelerator.accumulate(model):
-            enc = tokenizer(batch["input"], return_tensors="pt", padding="longest", truncation=True, max_length=2048).to(accelerator.device)
+            enc = tokenizer(batch["input"],
+                            return_tensors="pt",
+                            padding="longest",
+                            truncation=True,
+                            max_length=2048).to(accelerator.device)
 
             with torch.no_grad():
                 gen_ids = accelerator.unwrap_model(model).generate(**enc,
@@ -133,7 +145,8 @@ def main():
                                                                    use_cache=False)
 
             seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
-            pad_mask = seq.ne(tokenizer.pad_token_id); prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
+            pad_mask = seq.ne(tokenizer.pad_token_id)
+            prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
             token_idx = torch.arange(seq.size(1), device=seq.device)
             gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
 
@@ -146,8 +159,10 @@ def main():
                 loss = loss + args.self_distill_gamma * F.mse_loss(logits, prev_logits)
                 prev_logits = logits.detach()
 
-            accelerator.backward(loss); accelerator.clip_grad_norm_(model.parameters(), 1.0)
-            optimizer.step(); optimizer.zero_grad()
+            accelerator.backward(loss)
+            accelerator.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad()
 
             if accelerator.is_main_process:
                 if step % args.log_steps == 0:
```
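
For context, the first hunk's DeepSpeed settings can be exercised standalone: ZeRO stage 2 shards optimizer state and gradients, offloads the optimizer state to CPU, and keeps parameters on GPU. Below is a minimal sketch of that setup; the values for `eff_bs` and `accum_steps` are placeholders, since the real ones are computed earlier in `train.py`.

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Placeholder values; train.py derives these before building ds_config.
eff_bs, accum_steps = 64, 8

ds_config = {
    "train_batch_size": eff_bs,
    "gradient_accumulation_steps": accum_steps,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,                              # shard optimizer state + gradients
        "offload_optimizer": {"device": "cpu"},  # optimizer state lives in CPU RAM
        "offload_param": {"device": "none"},     # parameters stay on the GPU
    },
    "gradient_clipping": 1.0,
}

accelerator = Accelerator(
    mixed_precision="bf16",
    gradient_accumulation_steps=accum_steps,
    deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config),
)
```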
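The third hunk's mask construction is compact, and a toy run makes its behavior concrete. This sketch fabricates a tiny left-padded batch (all token ids below are made up) and applies the same three lines of mask logic:

```python
import torch

pad_token_id = 0
prompt_width = 4  # stands in for enc.input_ids.shape[1]

# Two sequences: left-padded prompt (first 4 columns) + generated tail.
seq = torch.tensor([
    [0, 0, 11, 12, 21, 22, 23],   # 2 pads, 2 prompt tokens, 3 generated
    [0, 13, 14, 15, 24, 25, 0],   # 1 pad, 3 prompt tokens, 2 generated + pad
])

pad_mask = seq.ne(pad_token_id)                     # True at real tokens
prompt_len = pad_mask[:, :prompt_width].sum(-1)     # tensor([2, 3])
token_idx = torch.arange(seq.size(1))
gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask

print(gen_mask.int())
# tensor([[0, 0, 1, 1, 1, 1, 1],
#         [0, 0, 0, 1, 1, 1, 0]])
```

Note that with left padding, trailing real prompt tokens (ids 11 and 12 in the first row) also land inside `gen_mask`, since their column indices are already >= `prompt_len`.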
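Finally, the last hunk unpacks the backward/step sequence onto separate lines. In isolation it is the usual Accelerate accumulation pattern; the sketch below continues from the `accelerator` built in the first sketch, assumes `model`, `optimizer`, and `dataloader` have gone through `accelerator.prepare(...)`, and uses `compute_loss` as a hypothetical stand-in for the loss computation in the diff.

```python
for batch in dataloader:
    with accelerator.accumulate(model):
        loss = compute_loss(model, batch)                     # hypothetical helper
        accelerator.backward(loss)                            # delegates to the DeepSpeed engine
        accelerator.clip_grad_norm_(model.parameters(), 1.0)  # matches "gradient_clipping": 1.0
        optimizer.step()                                      # no-op until gradients sync
        optimizer.zero_grad()
```

Inside `accumulate()`, the prepared optimizer skips the actual parameter update until `accum_steps` micro-batches have been processed, so `step()` and `zero_grad()` can be called unconditionally on every iteration.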
