Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 30 |
1 file changed, 14 insertions(+), 16 deletions(-)
--- a/train.py
+++ b/train.py
@@ -7,15 +7,16 @@ from pathlib import Path
 import psutil
 
 import torch
 import torch.nn.functional as F
+from torch.optim import AdamW
 import pandas as pd
 import numpy as np
 from torch.utils.data import Dataset, DataLoader
 
-import wandb
+# import wandb
 from accelerate import Accelerator, DeepSpeedPlugin
 from accelerate.utils import set_seed
-from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 
 
 os.environ.setdefault("NCCL_TIMEOUT", "2700")
@@ -23,9 +24,9 @@ os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+    parser.add_argument('--model_name', type=str, default='Qwen3-8B', help='Model name')
     parser.add_argument('--model_path', type=str, default=None, help='Local model path')
-    parser.add_argument('--train_data', type=str, default='', help='Training data file path')
+    parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
     parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
     parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
     parser.add_argument('--micro_batch_size', type=str, default=2, help='Micro batch size or "auto"')
@@ -33,12 +34,12 @@ def parse_args():
     parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
     parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
     parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
-    parser.add_argument('--max_steps', type=int, default=50, help='Maximum training steps')
+    parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
     parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
     parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
-    parser.add_argument('--wandb_project', type=str, default='one-shot-em', help='W&B project name')
-    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
-    parser.add_argument('--seed', type=int, default=15, help='Random seed')
+    # parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+    # parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+    parser.add_argument('--seed', type=int, default=42, help='Random seed')
     return parser.parse_args()
 
 class FTDataset(Dataset):
@@ -101,7 +102,7 @@ def main():
                               deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
     print = accelerator.print
 
-    model_path = args.model_path or f"/volume/ailab4sci/models/{args.model_name}"
+    model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
     config = AutoConfig.from_pretrained(model_path)
     config.use_cache = False
     model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
@@ -109,8 +110,8 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
     tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
 
-    if accelerator.is_main_process:
-        wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+    # if accelerator.is_main_process:
+    #     wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
 
     df = pd.read_parquet(args.train_data)
     train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
@@ -154,9 +155,6 @@ def main():
 
                 H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
                 loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
-                if args.enable_self_distill and prev_logits is not None:
-                    loss = loss + args.self_distill_gamma * F.mse_loss(logits, prev_logits)
-                prev_logits = logits.detach()
 
                 accelerator.backward(loss)
                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
@@ -166,7 +164,7 @@ def main():
         if accelerator.is_main_process:
             if step % args.log_steps == 0:
                 print(f"Step {step} | loss={loss.item():.6f}")
-                wandb.log({"step": step, "loss": loss.item()})
+                # wandb.log({"step": step, "loss": loss.item()})
 
             if step % args.save_steps == 0:
                 ckpt = Path(save_root) / f"step_{step}"
@@ -181,7 +179,7 @@ def main():
     accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
     tokenizer.save_pretrained(final)
     print(f"Final checkpoint saved to {final}")
-    wandb.finish()
+    # wandb.finish()
 
 if __name__ == "__main__":
     main()
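For reference, the training objective that remains after this patch (the self-distillation term is deleted) is the masked mean of per-token entropy, as computed in the hunk at old lines 154-162. The snippet below is a minimal, self-contained sketch of that computation on dummy tensors; the softmax step, tensor names, and shapes are illustrative assumptions, not code from this repository.

import torch
import torch.nn.functional as F

# Dummy shapes: batch of 2 sequences, 5 generated positions each, vocabulary of 11.
batch, seq_len, vocab = 2, 5, 11
logits = torch.randn(batch, seq_len, vocab)              # model logits over generated positions
gen_mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]], dtype=torch.float)  # 1 = generated token, 0 = padding

probs = F.softmax(logits, dim=-1)                        # assumed: probs are a softmax over the logits
H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)      # per-token entropy, shape (batch, seq_len)
loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)  # mean entropy over generated tokens only
print(loss.item())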
