From f21f7dd85365b10505bbd1cfa28f6a8648ba1b7e Mon Sep 17 00:00:00 2001 From: blackhao <13851610112@163.com> Date: Sat, 23 Aug 2025 13:56:30 -0500 Subject: docs: add Colab quickstart; feat: safer train data check and single-GPU generate fix --- Group-Entropy-Equalization/train.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'Group-Entropy-Equalization/train.py') diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py index 11f658a..d1ba4f0 100644 --- a/Group-Entropy-Equalization/train.py +++ b/Group-Entropy-Equalization/train.py @@ -124,6 +124,9 @@ def main(): if accelerator.is_main_process: wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args)) + # Friendly error if the parquet path is missing + if not os.path.exists(args.train_data): + raise FileNotFoundError(f"Training data not found: {args.train_data}. Create/upload the parquet under the project folder or pass --train_data to an existing path.") df = pd.read_parquet(args.train_data) train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()] train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate) @@ -146,15 +149,18 @@ def main(): max_length=2048).to(accelerator.device) with torch.no_grad(): - gen_ids = accelerator.unwrap_model(model).generate(**enc, - max_new_tokens=512, - do_sample=True, - top_p=0.95, - temperature=args.sample_temp, - synced_gpus=True, - repetition_penalty=1.15, - pad_token_id=tokenizer.pad_token_id, - use_cache=False) + use_synced = getattr(accelerator, "num_processes", 1) and accelerator.num_processes > 1 + gen_ids = accelerator.unwrap_model(model).generate( + **enc, + max_new_tokens=512, + do_sample=True, + top_p=0.95, + temperature=args.sample_temp, + synced_gpus=use_synced, + repetition_penalty=1.15, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096] pad_mask = seq.ne(tokenizer.pad_token_id) -- cgit v1.2.3