summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore7
-rw-r--r--README.md2
-rw-r--r--dataset/1shot_rlvr/pi1_r128.parquetbin0 -> 7221 bytes
-rw-r--r--dataset/1shot_rlvr/pi1_r1280.parquetbin0 -> 7222 bytes
-rw-r--r--dataset/numina/numina_00.parquetbin0 -> 48471723 bytes
-rw-r--r--dataset/numina/numina_01.parquetbin0 -> 48531583 bytes
-rw-r--r--dataset/numina/numina_02.parquetbin0 -> 48550309 bytes
-rw-r--r--dataset/numina/numina_03.parquetbin0 -> 48543544 bytes
-rw-r--r--dataset/numina/numina_04.parquetbin0 -> 48406308 bytes
-rw-r--r--dataset/numina/numina_05.parquetbin0 -> 48623275 bytes
-rw-r--r--dataset/numina/numina_06.parquetbin0 -> 48233835 bytes
-rw-r--r--dataset/numina/numina_07.parquetbin0 -> 48385429 bytes
-rw-r--r--dataset/numina/numina_08.parquetbin0 -> 48653268 bytes
-rw-r--r--dataset/numina/numina_09.parquetbin0 -> 48536371 bytes
-rw-r--r--train.py155
15 files changed, 164 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..556fac0
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+checkpoints/*
+wandb/*
+archived/*
+*.log
+*.sh
+*.ipynb
+log/* \ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..44da686
--- /dev/null
+++ b/README.md
@@ -0,0 +1,2 @@
+# one-shot-em
+One-shot Entropy Minimization
diff --git a/dataset/1shot_rlvr/pi1_r128.parquet b/dataset/1shot_rlvr/pi1_r128.parquet
new file mode 100644
index 0000000..bdf9a0f
--- /dev/null
+++ b/dataset/1shot_rlvr/pi1_r128.parquet
Binary files differ
diff --git a/dataset/1shot_rlvr/pi1_r1280.parquet b/dataset/1shot_rlvr/pi1_r1280.parquet
new file mode 100644
index 0000000..faca841
--- /dev/null
+++ b/dataset/1shot_rlvr/pi1_r1280.parquet
Binary files differ
diff --git a/dataset/numina/numina_00.parquet b/dataset/numina/numina_00.parquet
new file mode 100644
index 0000000..265fbcb
--- /dev/null
+++ b/dataset/numina/numina_00.parquet
Binary files differ
diff --git a/dataset/numina/numina_01.parquet b/dataset/numina/numina_01.parquet
new file mode 100644
index 0000000..80d7ec4
--- /dev/null
+++ b/dataset/numina/numina_01.parquet
Binary files differ
diff --git a/dataset/numina/numina_02.parquet b/dataset/numina/numina_02.parquet
new file mode 100644
index 0000000..943969f
--- /dev/null
+++ b/dataset/numina/numina_02.parquet
Binary files differ
diff --git a/dataset/numina/numina_03.parquet b/dataset/numina/numina_03.parquet
new file mode 100644
index 0000000..0336392
--- /dev/null
+++ b/dataset/numina/numina_03.parquet
Binary files differ
diff --git a/dataset/numina/numina_04.parquet b/dataset/numina/numina_04.parquet
new file mode 100644
index 0000000..ee310b8
--- /dev/null
+++ b/dataset/numina/numina_04.parquet
Binary files differ
diff --git a/dataset/numina/numina_05.parquet b/dataset/numina/numina_05.parquet
new file mode 100644
index 0000000..5fad140
--- /dev/null
+++ b/dataset/numina/numina_05.parquet
Binary files differ
diff --git a/dataset/numina/numina_06.parquet b/dataset/numina/numina_06.parquet
new file mode 100644
index 0000000..90f1267
--- /dev/null
+++ b/dataset/numina/numina_06.parquet
Binary files differ
diff --git a/dataset/numina/numina_07.parquet b/dataset/numina/numina_07.parquet
new file mode 100644
index 0000000..6483a5e
--- /dev/null
+++ b/dataset/numina/numina_07.parquet
Binary files differ
diff --git a/dataset/numina/numina_08.parquet b/dataset/numina/numina_08.parquet
new file mode 100644
index 0000000..4817f3a
--- /dev/null
+++ b/dataset/numina/numina_08.parquet
Binary files differ
diff --git a/dataset/numina/numina_09.parquet b/dataset/numina/numina_09.parquet
new file mode 100644
index 0000000..6d69876
--- /dev/null
+++ b/dataset/numina/numina_09.parquet
Binary files differ
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..557b209
--- /dev/null
+++ b/train.py
@@ -0,0 +1,155 @@
+import argparse
+import random
+import string
+from pathlib import Path
+import torch.nn.functional as F
+import os, math, torch, pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import set_seed
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
+
+
# ---------------------------------------------------------------------------
# Command-line hyperparameters.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Train model with configurable hyperparameters")
parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
parser.add_argument("--effective_batch", type=int, default=64, help="Effective batch size")
parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
args = parser.parse_args()

# Reproducibility: seed python/numpy/torch and force deterministic cuDNN kernels.
set_seed(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Batch geometry: gradient-accumulation steps are derived so that
# micro_batch_size * world_size * accum_steps ~= effective_batch.
temperature = args.temperature
learning_rate = args.lr
effective_batch = args.effective_batch
micro_batch_size = 2
world_size = int(os.environ.get("WORLD_SIZE", 1))
accum_steps = max(1, effective_batch // (micro_batch_size * world_size))

# Checkpoint root encodes the run's hyperparameters, e.g. t06_lr2e-05_bsz64_seed42.
temp_str = str(args.temperature).replace(".", "")
lr_str = f"{args.lr:.0e}"
bsz_str = str(args.effective_batch)
save_root = (
    "/volume/ailab4sci/ztgao/em/checkpoints_32b/"
    f"t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
)
+
# DeepSpeed ZeRO stage-2 configuration: bf16 training, optimizer state
# offloaded to CPU, parameters kept on GPU, gradients clipped at 1.0.
# NOTE(review): DeepSpeed validates train_batch_size ==
# micro_batch * accum_steps * world_size; with the floor division above this
# only holds when effective_batch divides evenly — confirm before passing
# non-default batch sizes.
DEEPSPEED_CONFIG = {
    "train_micro_batch_size_per_gpu": micro_batch_size,
    "train_batch_size": effective_batch,
    "gradient_accumulation_steps": accum_steps,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
        "offload_param": {"device": "none"}
    },
    "gradient_clipping": 1.0,
}

# Hand the raw HF-DeepSpeed config dict to Accelerate; accumulation count must
# match the config so Accelerate and DeepSpeed agree on step boundaries.
ds_plugin = DeepSpeedPlugin(hf_ds_config=DEEPSPEED_CONFIG)
accelerator = Accelerator(
    mixed_precision="bf16",
    gradient_accumulation_steps=accum_steps,
    deepspeed_plugin=ds_plugin,
)
# Shadow builtin print with accelerator.print so only the main process logs.
print = accelerator.print
+
# Load Qwen2.5-32B-Instruct from the local model store. The KV cache is
# disabled because it is incompatible with gradient checkpointing, which is
# enabled to fit the 32B model during training.
model_name = "Qwen2.5-32B-Instruct"
model_path = f"/volume/ailab4sci/models/{model_name}"
config = AutoConfig.from_pretrained(model_path)
config.use_cache = False
model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
model.gradient_checkpointing_enable()

# Left padding so batched generation appends directly after each prompt.
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
+
class FTDataset(Dataset):
    """Thin map-style dataset over a pre-built list of row dicts.

    Rows are returned untouched; all batching/collation logic lives in the
    DataLoader's collate function.
    """

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]
+
def custom_collate(batch):
    """Collate rows of {"input": str} into {"input": [str, ...]}.

    Keeps the raw prompt strings as-is (no tensorization); tokenization
    happens later, inside the training loop.
    """
    texts = []
    for row in batch:
        texts.append(row["input"])
    return {"input": texts}
+
def apply_chat_template(problem: str) -> str:
    """Render a raw problem as a single-turn chat prompt string.

    Uses the module-level tokenizer's chat template, with the assistant turn
    opened (add_generation_prompt=True) so generation continues the answer.
    """
    messages = [{"role": "user", "content": problem}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
+
# Load the single-prompt RLVR training set and render each problem through the
# chat template once, up front.
df = pd.read_parquet("/volume/ailab4sci/ztgao/em/dataset/1shot_rlvr/pi1_r1280.parquet")
problems = df["problem"].dropna().tolist()
data = [{"input": apply_chat_template(problem)} for problem in problems]

dataset = FTDataset(data)
data_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=micro_batch_size,
    collate_fn=custom_collate,
)
+
# Use torch.optim.AdamW rather than the AdamW imported from transformers: the
# transformers implementation is deprecated and removed in recent releases,
# and torch's applies decoupled weight decay correctly.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)
+
model.train()
for step, batch in enumerate(data_loader, start=1):
    # One-shot entropy minimization converges in a handful of updates.
    if step > 30:
        break

    with accelerator.accumulate(model):
        # Tokenize the rendered chat prompts (left-padded, capped at 2048 tokens).
        enc = tokenizer(
            batch["input"],
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=2048,
        ).to(accelerator.device)

        # Sample continuations without gradients; the differentiable pass over
        # the full sequence happens below.
        with torch.no_grad():
            gen_ids = accelerator.unwrap_model(model).generate(
                **enc,
                max_new_tokens=2048,
                do_sample=True,
                top_p=0.95,
                temperature=temperature,
                synced_gpus=True,
                repetition_penalty=1.15,
                # Fix: use the tokenizer's pad id instead of hard-coded 151643
                # (Qwen-specific <|endoftext|>); identical for Qwen, correct
                # for any other checkpoint.
                pad_token_id=tokenizer.pad_token_id,
            )

        # Prompt + generated tokens, truncated to a 4096-token window.
        seq = torch.cat(
            [enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]],
            dim=1
        )[:, :4096]

        pad_mask = seq.ne(tokenizer.pad_token_id)
        # Generated positions start right after the (left-padded) prompt block.
        # Fix: the original thresholded on the per-sample count of non-pad
        # prompt tokens, which leaks trailing prompt tokens into the loss when
        # a batch mixes prompt lengths; thresholding on the prompt width is
        # exact (and identical when all prompts in a batch have equal length,
        # as with this single-prompt dataset).
        prompt_width = enc.input_ids.shape[1]
        token_idx = torch.arange(seq.size(1), device=seq.device)
        gen_mask = (token_idx.unsqueeze(0) >= prompt_width) & pad_mask

        # Differentiable forward pass. The loss is the mean token-level entropy
        # of the temperature-scaled distribution over generated tokens only;
        # minimizing it sharpens the model's own predictions.
        logits = model(seq, attention_mask=pad_mask).logits
        probs = F.softmax(logits / temperature, dim=-1)
        H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
        loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)

        accelerator.backward(loss)
        accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        print(f"Step {step} | loss={loss.item():.8f}")

    # Save a checkpoint after every step (the run is only 30 steps long).
    if accelerator.is_main_process:
        ckpt_dir = Path(save_root) / f"step_{step}"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(ckpt_dir, safe_serialization=True)
        tokenizer.save_pretrained(ckpt_dir)
        print(f"Checkpoint saved to {ckpt_dir}")
    # Fix: this barrier must be reached by every rank; the original invoked it
    # only on the main process, which deadlocks any multi-process run.
    accelerator.wait_for_everyone()
+
# Fix: synchronize all ranks before the main process serializes the final
# weights, so no worker exits while the save is still using shared state
# (standard Accelerate checkpointing pattern).
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    final_dir = Path(save_root) / "final"
    final_dir.mkdir(parents=True, exist_ok=True)
    accelerator.unwrap_model(model).save_pretrained(final_dir, safe_serialization=True)
    tokenizer.save_pretrained(final_dir)
    print(f"Final checkpoint saved to {final_dir}")