Diffstat (limited to 'Group-Entropy-Equalization/train.py')
 Group-Entropy-Equalization/train.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py
index 11f658a..d1ba4f0 100644
--- a/Group-Entropy-Equalization/train.py
+++ b/Group-Entropy-Equalization/train.py
@@ -124,6 +124,9 @@ def main():
     if accelerator.is_main_process:
         wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+    # Friendly error if the parquet path is missing
+    if not os.path.exists(args.train_data):
+        raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.")
     df = pd.read_parquet(args.train_data)
     train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
     train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
@@ -146,15 +149,18 @@ def main():
                              max_length=2048).to(accelerator.device)
         with torch.no_grad():
-            gen_ids = accelerator.unwrap_model(model).generate(**enc,
-                max_new_tokens=512,
-                do_sample=True,
-                top_p=0.95,
-                temperature=args.sample_temp,
-                synced_gpus=True,
-                repetition_penalty=1.15,
-                pad_token_id=tokenizer.pad_token_id,
-                use_cache=False)
+            use_synced = getattr(accelerator, "num_processes", 1) > 1
+            gen_ids = accelerator.unwrap_model(model).generate(
+                **enc,
+                max_new_tokens=512,
+                do_sample=True,
+                top_p=0.95,
+                temperature=args.sample_temp,
+                synced_gpus=use_synced,
+                repetition_penalty=1.15,
+                pad_token_id=tokenizer.pad_token_id,
+                use_cache=False,
+            )
             seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
             pad_mask = seq.ne(tokenizer.pad_token_id)
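
The guard in the first hunk can be exercised in isolation. Below is a standalone sketch that mirrors it; load_problems is a hypothetical helper written for illustration, not a function in train.py, and the path is illustrative only.

# Standalone sketch of the parquet guard from the first hunk.
# load_problems and the path below are hypothetical, for illustration.
import os

import pandas as pd

def load_problems(train_data: str) -> list[str]:
    if not os.path.exists(train_data):
        raise FileNotFoundError(
            f"Training data not found: {train_data}. Create/upload the parquet "
            "under the project folder or pass --train_data to an existing path."
        )
    df = pd.read_parquet(train_data)
    return df["problem"].dropna().tolist()

# The guard turns an opaque parquet-reader error into an actionable message:
try:
    load_problems("does/not/exist.parquet")
except FileNotFoundError as e:
    print(e)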
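
For context on the second hunk: synced_gpus keeps every rank stepping generate in lockstep so ranks that finish early do not hang collective ops, but it is pointless overhead for a single-process run, which the old hard-coded True always paid. A minimal sketch of the gating logic follows; DummyAccelerator is a hypothetical stand-in, since a real accelerate.Accelerator always exposes num_processes and the getattr default only guards unusual mocks.

# Minimal sketch of the use_synced gating introduced above.
# DummyAccelerator is a hypothetical stand-in for accelerate.Accelerator.
class DummyAccelerator:
    def __init__(self, num_processes: int) -> None:
        self.num_processes = num_processes

def should_sync(accelerator) -> bool:
    # Sync generation across ranks only when more than one process is running.
    return getattr(accelerator, "num_processes", 1) > 1

assert should_sync(DummyAccelerator(1)) is False  # single-GPU / CPU run
assert should_sync(DummyAccelerator(4)) is True   # multi-GPU run
assert should_sync(object()) is False             # attribute missing entirely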