diff options
Diffstat (limited to 'Group-Entropy-Equalization/train.py')
| -rw-r--r-- | Group-Entropy-Equalization/train.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py index 11f658a..d1ba4f0 100644 --- a/Group-Entropy-Equalization/train.py +++ b/Group-Entropy-Equalization/train.py @@ -124,6 +124,9 @@ def main(): if accelerator.is_main_process: wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args)) + # Friendly error if the parquet path is missing + if not os.path.exists(args.train_data): + raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.") df = pd.read_parquet(args.train_data) train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()] train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate) @@ -146,15 +149,18 @@ def main(): max_length=2048).to(accelerator.device) with torch.no_grad(): - gen_ids = accelerator.unwrap_model(model).generate(**enc, - max_new_tokens=512, - do_sample=True, - top_p=0.95, - temperature=args.sample_temp, - synced_gpus=True, - repetition_penalty=1.15, - pad_token_id=tokenizer.pad_token_id, - use_cache=False) + use_synced = getattr(accelerator, "num_processes", 1) and accelerator.num_processes > 1 + gen_ids = accelerator.unwrap_model(model).generate( + **enc, + max_new_tokens=512, + do_sample=True, + top_p=0.95, + temperature=args.sample_temp, + synced_gpus=use_synced, + repetition_penalty=1.15, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096] pad_mask = seq.ne(tokenizer.pad_token_id) |
