-rw-r--r--   .gitignore                             |   7
-rw-r--r--   README.md                              |   2
-rw-r--r--   dataset/1shot_rlvr/pi1_r128.parquet    | bin 0 -> 7221 bytes
-rw-r--r--   dataset/1shot_rlvr/pi1_r1280.parquet   | bin 0 -> 7222 bytes
-rw-r--r--   dataset/numina/numina_00.parquet       | bin 0 -> 48471723 bytes
-rw-r--r--   dataset/numina/numina_01.parquet       | bin 0 -> 48531583 bytes
-rw-r--r--   dataset/numina/numina_02.parquet       | bin 0 -> 48550309 bytes
-rw-r--r--   dataset/numina/numina_03.parquet       | bin 0 -> 48543544 bytes
-rw-r--r--   dataset/numina/numina_04.parquet       | bin 0 -> 48406308 bytes
-rw-r--r--   dataset/numina/numina_05.parquet       | bin 0 -> 48623275 bytes
-rw-r--r--   dataset/numina/numina_06.parquet       | bin 0 -> 48233835 bytes
-rw-r--r--   dataset/numina/numina_07.parquet       | bin 0 -> 48385429 bytes
-rw-r--r--   dataset/numina/numina_08.parquet       | bin 0 -> 48653268 bytes
-rw-r--r--   dataset/numina/numina_09.parquet       | bin 0 -> 48536371 bytes
-rw-r--r--   train.py                               | 155
15 files changed, 164 insertions, 0 deletions
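
Note (not part of the commit): the parquet files listed above are consumed directly by train.py below, which reads the pi1_r1280 split with pandas and uses only its "problem" column. A quick, hypothetical way to inspect one of the committed datasets, assuming a relative path from the repository root rather than the absolute cluster path hard-coded in train.py:

    import pandas as pd

    # Illustrative inspection only; train.py reads this file and passes each
    # "problem" string through the chat template before training.
    df = pd.read_parquet("dataset/1shot_rlvr/pi1_r1280.parquet")
    print(len(df))                       # number of prompt rows in the split
    print(df["problem"].iloc[0][:200])   # preview of the first problem statement
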
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..556fac0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+checkpoints/*
+wandb/*
+archived/*
+*.log
+*.sh
+*.ipynb
+log/*
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..44da686
--- /dev/null
+++ b/README.md
@@ -0,0 +1,2 @@
+# one-shot-em
+One-shot Entropy Minimization
diff --git a/dataset/1shot_rlvr/pi1_r128.parquet b/dataset/1shot_rlvr/pi1_r128.parquet
new file mode 100644
index 0000000..bdf9a0f
--- /dev/null
+++ b/dataset/1shot_rlvr/pi1_r128.parquet
Binary files differ
diff --git a/dataset/1shot_rlvr/pi1_r1280.parquet b/dataset/1shot_rlvr/pi1_r1280.parquet
new file mode 100644
index 0000000..faca841
--- /dev/null
+++ b/dataset/1shot_rlvr/pi1_r1280.parquet
Binary files differ
diff --git a/dataset/numina/numina_00.parquet b/dataset/numina/numina_00.parquet
new file mode 100644
index 0000000..265fbcb
--- /dev/null
+++ b/dataset/numina/numina_00.parquet
Binary files differ
diff --git a/dataset/numina/numina_01.parquet b/dataset/numina/numina_01.parquet
new file mode 100644
index 0000000..80d7ec4
--- /dev/null
+++ b/dataset/numina/numina_01.parquet
Binary files differ
diff --git a/dataset/numina/numina_02.parquet b/dataset/numina/numina_02.parquet
new file mode 100644
index 0000000..943969f
--- /dev/null
+++ b/dataset/numina/numina_02.parquet
Binary files differ
diff --git a/dataset/numina/numina_03.parquet b/dataset/numina/numina_03.parquet
new file mode 100644
index 0000000..0336392
--- /dev/null
+++ b/dataset/numina/numina_03.parquet
Binary files differ
diff --git a/dataset/numina/numina_04.parquet b/dataset/numina/numina_04.parquet
new file mode 100644
index 0000000..ee310b8
--- /dev/null
+++ b/dataset/numina/numina_04.parquet
Binary files differ
diff --git a/dataset/numina/numina_05.parquet b/dataset/numina/numina_05.parquet
new file mode 100644
index 0000000..5fad140
--- /dev/null
+++ b/dataset/numina/numina_05.parquet
Binary files differ
diff --git a/dataset/numina/numina_06.parquet b/dataset/numina/numina_06.parquet
new file mode 100644
index 0000000..90f1267
--- /dev/null
+++ b/dataset/numina/numina_06.parquet
Binary files differ
diff --git a/dataset/numina/numina_07.parquet b/dataset/numina/numina_07.parquet
new file mode 100644
index 0000000..6483a5e
--- /dev/null
+++ b/dataset/numina/numina_07.parquet
Binary files differ
diff --git a/dataset/numina/numina_08.parquet b/dataset/numina/numina_08.parquet
new file mode 100644
index 0000000..4817f3a
--- /dev/null
+++ b/dataset/numina/numina_08.parquet
Binary files differ
diff --git a/dataset/numina/numina_09.parquet b/dataset/numina/numina_09.parquet
new file mode 100644
index 0000000..6d69876
--- /dev/null
+++ b/dataset/numina/numina_09.parquet
Binary files differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..557b209
--- /dev/null
+++ b/train.py
@@ -0,0 +1,155 @@
+import argparse
+import random
+import string
+from pathlib import Path
+import torch.nn.functional as F
+import os, math, torch, pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import set_seed
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
+
+
+parser = argparse.ArgumentParser(description="Train model with configurable hyperparameters")
+parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
+parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
+parser.add_argument("--effective_batch", type=int, default=64, help="Effective batch size")
+parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+args = parser.parse_args()
+
+set_seed(args.seed)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+temp_str = str(args.temperature).replace(".", "")
+lr_str = f"{args.lr:.0e}"
+bsz_str = str(args.effective_batch)
+save_root = f"/volume/ailab4sci/ztgao/em/checkpoints_32b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
+
+temperature = args.temperature
+learning_rate = args.lr
+effective_batch = args.effective_batch
+micro_batch_size = 2
+world_size = int(os.environ.get("WORLD_SIZE", 1))
+accum_steps = max(1, effective_batch // (micro_batch_size * world_size))
+
+DEEPSPEED_CONFIG = {
+    "train_micro_batch_size_per_gpu": micro_batch_size,
+    "train_batch_size": effective_batch,
+    "gradient_accumulation_steps": accum_steps,
+    "bf16": {"enabled": True},
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {"device": "cpu"},
+        "offload_param": {"device": "none"}
+    },
+    "gradient_clipping": 1.0,
+}
+
+ds_plugin = DeepSpeedPlugin(hf_ds_config=DEEPSPEED_CONFIG)
+accelerator = Accelerator(
+    mixed_precision="bf16",
+    gradient_accumulation_steps=accum_steps,
+    deepspeed_plugin=ds_plugin,
+)
+print = accelerator.print
+
+model_name = "Qwen2.5-32B-Instruct"
+model_path = f"/volume/ailab4sci/models/{model_name}"
+config = AutoConfig.from_pretrained(model_path)
+config.use_cache = False
+model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
+model.gradient_checkpointing_enable()
+
+tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+class FTDataset(Dataset):
+    def __init__(self, rows): self.rows = rows
+    def __len__(self): return len(self.rows)
+    def __getitem__(self, idx): return self.rows[idx]
+
+def custom_collate(batch):
+    return {"input": [item["input"] for item in batch]}
+
+def apply_chat_template(problem: str) -> str:
+    return tokenizer.apply_chat_template(
+        [{"role": "user", "content": problem}],
+        tokenize=False, add_generation_prompt=True
+    )
+
+df = pd.read_parquet("/volume/ailab4sci/ztgao/em/dataset/1shot_rlvr/pi1_r1280.parquet")
+data = [{"input": apply_chat_template(p)} for p in df["problem"].dropna().tolist()]
+dataset = FTDataset(data)
+data_loader = DataLoader(
+    dataset,
+    batch_size=micro_batch_size,
+    shuffle=True,
+    collate_fn=custom_collate,
+)
+
+optimizer = AdamW(model.parameters(), lr=learning_rate)
+model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
+
+model.train()
+for step, batch in enumerate(data_loader, start=1):
+    if step > 30:
+        break
+
+    with accelerator.accumulate(model):
+        enc = tokenizer(
+            batch["input"],
+            return_tensors="pt",
+            padding="longest",
+            truncation=True,
+            max_length=2048,
+        ).to(accelerator.device)
+
+        with torch.no_grad():
+            gen_ids = accelerator.unwrap_model(model).generate(
+                **enc,
+                max_new_tokens=2048,
+                do_sample=True,
+                top_p=0.95,
+                temperature=temperature,
+                synced_gpus=True,
+                repetition_penalty=1.15,
+                pad_token_id=151643
+            )
+
+        seq = torch.cat(
+            [enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]],
+            dim=1
+        )[:, :4096]
+
+        pad_mask = seq.ne(tokenizer.pad_token_id)
+        prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
+        token_idx = torch.arange(seq.size(1), device=seq.device)
+        gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
+
+        logits = model(seq, attention_mask=pad_mask).logits
+        probs = F.softmax(logits / temperature, dim=-1)
+        H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
+        loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+
+        accelerator.backward(loss)
+        accelerator.clip_grad_norm_(model.parameters(), 1.0)
+        optimizer.step(); optimizer.zero_grad()
+        print(f"Step {step} | loss={loss.item():.8f}")
+
+        if accelerator.is_main_process:
+            ckpt_dir = Path(save_root) / f"step_{step}"
+            ckpt_dir.mkdir(parents=True, exist_ok=True)
+            accelerator.unwrap_model(model).save_pretrained(ckpt_dir, safe_serialization=True)
+            tokenizer.save_pretrained(ckpt_dir)
+        accelerator.wait_for_everyone()
+        print(f"Checkpoint saved to {ckpt_dir}")
+
+if accelerator.is_main_process:
+    final_dir = Path(save_root) / "final"
+    final_dir.mkdir(parents=True, exist_ok=True)
+    accelerator.unwrap_model(model).save_pretrained(final_dir, safe_serialization=True)
+    tokenizer.save_pretrained(final_dir)
+    print(f"Final checkpoint saved to {final_dir}")
\ No newline at end of file
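
For orientation (not part of the commit): train.py samples completions for the one-shot prompt and then minimizes the average Shannon entropy of the model's next-token distributions over the generated, non-padding positions, using the same temperature for scaling as for sampling. Below is a minimal standalone sketch of that objective; the function name and toy tensors are illustrative only, and a real run would use the logits and masks computed in the training loop above.

    import torch
    import torch.nn.functional as F

    def token_entropy_loss(logits: torch.Tensor, gen_mask: torch.Tensor,
                           temperature: float = 0.6) -> torch.Tensor:
        # Mean Shannon entropy of the next-token distribution, averaged over the
        # positions flagged as generated (gen_mask True), mirroring the loss in train.py.
        probs = F.softmax(logits / temperature, dim=-1)          # (batch, seq_len, vocab)
        h_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)      # (batch, seq_len)
        return (h_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)

    # Toy shapes just to show the call pattern.
    logits = torch.randn(2, 8, 32)                 # batch=2, seq_len=8, vocab=32
    gen_mask = torch.zeros(2, 8, dtype=torch.bool)
    gen_mask[:, 4:] = True                         # pretend the last 4 positions were sampled
    print(token_entropy_loss(logits, gen_mask))

The launch scripts themselves are excluded by the .gitignore (*.sh), but given the DeepSpeed plugin and the WORLD_SIZE lookup, the script is presumably started with something like `accelerate launch train.py --temperature 0.6 --lr 2e-5 --effective_batch 64 --seed 42`.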
