summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--README.md5
-rw-r--r--train.py17
2 files changed, 12 insertions, 10 deletions
diff --git a/README.md b/README.md
index 44da686..7d3b0e6 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,3 @@
-# one-shot-em
-One-shot Entropy Minimization
+## One-shot Entropy Minimization
+
+<a href='https://arxiv.org/abs/2505.20282'><img src='https://img.shields.io/badge/arXiv-2505.20282-b31b1b.svg'></a> &nbsp; \ No newline at end of file
diff --git a/train.py b/train.py
index 557b209..6413e4b 100644
--- a/train.py
+++ b/train.py
@@ -11,11 +11,11 @@ from accelerate.utils import set_seed
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, AdamW
-parser = argparse.ArgumentParser(description="Train model with configurable hyperparameters")
-parser.add_argument("--temperature", type=float, default=0.6, help="Sampling temperature")
-parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
-parser.add_argument("--effective_batch", type=int, default=64, help="Effective batch size")
-parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
+parser = argparse.ArgumentParser()
+parser.add_argument("--temperature", type=float, default=0.5)
+parser.add_argument("--lr", type=float, default=2e-5)
+parser.add_argument("--effective_batch", "--bsz", type=int, default=64)
+parser.add_argument("--seed", type=int, default=15)
args = parser.parse_args()
set_seed(args.seed)
@@ -25,7 +25,7 @@ torch.backends.cudnn.benchmark = False
temp_str = str(args.temperature).replace(".", "")
lr_str = f"{args.lr:.0e}"
bsz_str = str(args.effective_batch)
-save_root = f"/volume/ailab4sci/ztgao/em/checkpoints_32b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
+save_root = f"/volume/ailab4sci/ztgao/em/checkpoints/qwen25math7b/t{temp_str}_lr{lr_str}_bsz{bsz_str}_seed{args.seed}"
temperature = args.temperature
learning_rate = args.lr
@@ -55,7 +55,7 @@ accelerator = Accelerator(
)
print = accelerator.print
-model_name = "Qwen2.5-32B-Instruct"
+model_name = "Qwen2.5-Math-7B"
model_path = f"/volume/ailab4sci/models/{model_name}"
config = AutoConfig.from_pretrained(model_path)
config.use_cache = False
@@ -116,7 +116,8 @@ for step, batch in enumerate(data_loader, start=1):
temperature=temperature,
synced_gpus=True,
repetition_penalty=1.15,
- pad_token_id=151643
+ pad_token_id=151643,
+ use_cache=False
)
seq = torch.cat(