Diffstat (limited to 'train.py')
 train.py | 31 +++++++++++++++++++++--------
 1 file changed, 23 insertions(+), 8 deletions(-)
diff --git a/train.py b/train.py
index e23b23d..9506d76 100644
--- a/train.py
+++ b/train.py
@@ -91,15 +91,23 @@ def main():
"train_batch_size": eff_bs,
"gradient_accumulation_steps": accum_steps,
"bf16": {"enabled": True},
- "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}, "offload_param": {"device": "none"}},
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {"device": "cpu"},
+ "offload_param": {"device": "none"}
+ },
"gradient_clipping": 1.0,
}
- accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=accum_steps, deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+ accelerator = Accelerator(mixed_precision="bf16",
+ gradient_accumulation_steps=accum_steps,
+ deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
print = accelerator.print
model_path = args.model_path or f"/volume/ailab4sci/models/{args.model_name}"
- config = AutoConfig.from_pretrained(model_path); config.use_cache = False
- model = AutoModelForCausalLM.from_pretrained(model_path, config=config); model.gradient_checkpointing_enable()
+ config = AutoConfig.from_pretrained(model_path)
+ config.use_cache = False
+ model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
+ model.gradient_checkpointing_enable()
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
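For reference, a minimal standalone sketch of the ZeRO-2 setup this hunk builds. The batch-size values are placeholders; the real script derives eff_bs and accum_steps from its CLI args:

    from accelerate import Accelerator
    from accelerate.utils import DeepSpeedPlugin

    eff_bs, accum_steps = 64, 8  # placeholder values

    ds_config = {
        "train_batch_size": eff_bs,                  # global effective batch size
        "gradient_accumulation_steps": accum_steps,
        "bf16": {"enabled": True},                   # bfloat16 mixed precision
        "zero_optimization": {
            "stage": 2,                              # shard grads + optimizer state
            "offload_optimizer": {"device": "cpu"},  # optimizer state in host RAM
            "offload_param": {"device": "none"},     # parameters stay on GPU
        },
        "gradient_clipping": 1.0,
    }
    accelerator = Accelerator(
        mixed_precision="bf16",
        gradient_accumulation_steps=accum_steps,
        deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config),
    )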
@@ -121,7 +129,11 @@ def main():
break
with accelerator.accumulate(model):
- enc = tokenizer(batch["input"], return_tensors="pt", padding="longest", truncation=True, max_length=2048).to(accelerator.device)
+ enc = tokenizer(batch["input"],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=2048).to(accelerator.device)
with torch.no_grad():
gen_ids = accelerator.unwrap_model(model).generate(**enc,
@@ -133,7 +145,8 @@ def main():
use_cache=False)
seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
- pad_mask = seq.ne(tokenizer.pad_token_id); prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
+ pad_mask = seq.ne(tokenizer.pad_token_id)
+ prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
token_idx = torch.arange(seq.size(1), device=seq.device)
gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
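A self-contained sketch of the masking arithmetic in this hunk, with toy tensors standing in for the real prompt and generation ids (pad id 0 is an assumption). Note that `token_idx >= prompt_len` isolates generated tokens only when prompts are right-padded; the script's tokenizer pads on the left, where the prompt window width, not the non-pad count, marks where generation starts:

    import torch

    pad_id = 0  # assumed pad token id for illustration
    # Toy stand-ins: two right-padded prompts of window width 4, plus the
    # tokens generate() appended after them.
    input_ids = torch.tensor([[5, 6, 7, 8],
                              [5, 6, 0, 0]])
    gen_ids   = torch.tensor([[9, 9, 0],
                              [9, 9, 9]])

    seq = torch.cat([input_ids, gen_ids], dim=1)[:, :4096]
    pad_mask = seq.ne(pad_id)                               # True at real tokens
    prompt_len = pad_mask[:, :input_ids.shape[1]].sum(-1)   # non-pad prompt tokens
    token_idx = torch.arange(seq.size(1), device=seq.device)
    # A position counts as "generated" if it lies past the prompt tokens
    # and is not padding.
    gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
    # gen_mask -> [[F, F, F, F, T, T, F],
    #              [F, F, F, F, T, T, T]]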
@@ -146,8 +159,10 @@ def main():
loss = loss + args.self_distill_gamma * F.mse_loss(logits, prev_logits)
prev_logits = logits.detach()
- accelerator.backward(loss); accelerator.clip_grad_norm_(model.parameters(), 1.0)
- optimizer.step(); optimizer.zero_grad()
+ accelerator.backward(loss)
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ optimizer.zero_grad()
if accelerator.is_main_process:
if step % args.log_steps == 0:
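The update pattern in this last hunk, as a runnable miniature with a toy model and data. The real script pairs this loop with the DeepSpeed plugin above, whose config already sets gradient_clipping:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(gradient_accumulation_steps=4)
    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    for step in range(8):
        x = torch.randn(2, 8, device=accelerator.device)
        with accelerator.accumulate(model):
            loss = model(x).pow(2).mean()
            accelerator.backward(loss)    # accumulates until the sync step
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()              # prepared optimizer skips non-sync steps
            optimizer.zero_grad()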