author     blackhao <13851610112@163.com>  2025-08-23 13:35:13 -0500
committer  blackhao <13851610112@163.com>  2025-08-23 13:35:13 -0500
commit     4f81a87ef95b190450ed5202bfa725dbb0a539f4 (patch)
tree       875f5966cdaaa526d85ff49a13cd6bf27ab4a723 /Group-Entropy-Equalization
parent     ad3e216afd066375219ef8b3928ef4096237fbf6 (diff)

    init

Diffstat (limited to 'Group-Entropy-Equalization')
 -rw-r--r--  Group-Entropy-Equalization/README.md          90
 -rw-r--r--  Group-Entropy-Equalization/requirements.txt  280
 -rw-r--r--  Group-Entropy-Equalization/train.py          197
 3 files changed, 567 insertions, 0 deletions
diff --git a/Group-Entropy-Equalization/README.md b/Group-Entropy-Equalization/README.md
new file mode 100644
index 0000000..804af95
--- /dev/null
+++ b/Group-Entropy-Equalization/README.md
@@ -0,0 +1,90 @@
+# One-shot Entropy Minimization
+
+[![paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2505.20282)
+[![Model](https://img.shields.io/badge/Models/Dataset-fcd022?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/zgao3186/qwen25math7b-one-shot-em/)
+[![Notion](https://img.shields.io/badge/Site-000000.svg?style=for-the-badge&logo=notion&logoColor=white)](https://www.notion.so/One-shot-Entropy-Minimization-202606db813b80639773f850f39246a5)
+
+### Installation
+
+```bash
+conda create -n one-shot-em python=3.10 -y
+conda activate one-shot-em
+pip install -r requirements.txt
+```
+
+---
+
+### Reproducing One-shot EM Training (SOTA)
+
+```bash
+accelerate launch train.py \
+ --model_name Qwen2.5-Math-7B \
+ --model_path /path/to/Qwen2.5-Math-7B \
+ --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
+ --effective_batch 64 \
+ --micro_batch_size 2 \
+ --temperature 0.5 \
+ --learning_rate 2e-5 \
+ --max_steps 50 \
+ --log_steps 1 \
+ --save_steps 1 \
+ --run_name one_shot \
+ --wandb_project one-shot-em
+```
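+
+The `--train_data` parquet is expected to contain a `problem` column (one prompt per row); `train.py` reads it with pandas, drops empty rows, and wraps each problem with the tokenizer's chat template before training.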
+
+---
+
+### Reproducing Multi-shot EM Training
+
+```bash
+accelerate launch train.py \
+ --model_name Qwen2.5-Math-7B \
+ --model_path /path/to/Qwen2.5-Math-7B \
+ --train_data dataset/numina/numina_00.parquet \
+ --effective_batch 64 \
+ --micro_batch_size 2 \
+ --temperature 0.5 \
+ --learning_rate 2e-5 \
+ --max_steps 50 \
+ --log_steps 1 \
+ --save_steps 1 \
+ --run_name multi_shot \
+ --wandb_project one-shot-em
+```
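+
+Both recipes assume a multi-GPU `accelerate` launch backed by DeepSpeed ZeRO-2 (the default in `train.py`). For a single-GPU run, `train.py` also exposes `--no_deepspeed` and `--mixed_precision {bf16,fp16,no}`; the sketch below is illustrative only (the 1.5B model and paths are placeholders):
+
+```bash
+accelerate launch --num_processes 1 train.py \
+  --model_name Qwen2.5-Math-1.5B \
+  --model_path /path/to/Qwen2.5-Math-1.5B \
+  --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
+  --no_deepspeed \
+  --mixed_precision fp16 \
+  --micro_batch_size 1 \
+  --effective_batch 16 \
+  --max_steps 50 \
+  --run_name one_shot_single_gpu \
+  --wandb_project one-shot-em
+```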
+
+---
+
+### Evaluation
+
+```bash
+cd Qwen2.5-Eval/evaluation
+bash sh/eval_all_math.sh
+```
+
+---
+
+### Acknowledgements
+
+Our dataset references and builds upon the following open-source contributions:
+
+- [NuminaMath-CoT](https://huggingface.co/datasets/AI-MO/NuminaMath-CoT)
+- [DeepScaler](https://github.com/agentica-project/deepscaler)
+- [One-shot RLVR](https://github.com/ypwang61/One-Shot-RLVR/) – for data selection strategies
+- [Qwen2.5-Eval](https://github.com/QwenLM/Qwen2.5-Math/) – for evaluation benchmarks
+
+We sincerely thank the authors and maintainers of these projects for their excellent contributions to the research community!
+
+
+---
+
+### Citation
+
+```bibtex
+@misc{gao2025oneshotentropyminimization,
+ title={One-shot Entropy Minimization},
+ author={Zitian Gao and Lynx Chen and Haoming Luo and Joey Zhou and Bryan Dai},
+ year={2025},
+ eprint={2505.20282},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ url={https://arxiv.org/abs/2505.20282},
+}
+```
diff --git a/Group-Entropy-Equalization/requirements.txt b/Group-Entropy-Equalization/requirements.txt
new file mode 100644
index 0000000..7fb330a
--- /dev/null
+++ b/Group-Entropy-Equalization/requirements.txt
@@ -0,0 +1,280 @@
+absl-py==2.1.0
+accelerate==0.33.0
+aiofiles==23.2.1
+annotated-types==0.6.0
+anyio==4.4.0
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+asttokens==2.4.1
+astunparse==1.6.3
+async-lru==2.0.4
+attrs==23.2.0
+audioread==3.0.1
+babel==2.16.0
+beautifulsoup4==4.12.3
+bitsandbytes==0.43.3
+bleach==6.1.0
+blis==0.7.11
+cachetools==5.3.2
+catalogue==2.0.10
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpathlib==0.16.0
+cmake==3.28.1
+comm==0.2.1
+confection==0.1.4
+contourpy==1.2.0
+cycler==0.12.1
+cymem==2.0.8
+Cython==3.0.8
+datasets==2.21.0
+debugpy==1.8.1
+decorator==5.1.1
+deepspeed==0.14.4
+defusedxml==0.7.1
+dill==0.3.8
+diskcache==5.6.3
+distro==1.9.0
+dm-tree==0.1.8
+docstring_parser==0.16
+einops==0.7.0
+exceptiongroup==1.2.0
+execnet==2.0.2
+executing==2.0.1
+expecttest==0.1.3
+fastapi==0.112.2
+fastjsonschema==2.19.1
+ffmpy==0.4.0
+filelock==3.13.1
+fire==0.6.0
+flash-attn==2.6.3
+fonttools==4.48.1
+fqdn==1.5.1
+gast==0.5.4
+gguf==0.9.1
+google-auth==2.27.0
+google-auth-oauthlib==0.4.6
+gradio==4.42.0
+gradio_client==1.3.0
+grpcio==1.60.1
+h11==0.14.0
+hjson==3.1.0
+httpcore==1.0.5
+httptools==0.6.1
+httpx==0.27.2
+huggingface-hub==0.24.6
+hypothesis==5.35.1
+idna==3.6
+importlib_resources==6.4.4
+iniconfig==2.0.0
+intel-openmp==2021.4.0
+interegular==0.3.3
+ipykernel==6.29.2
+ipython==8.21.0
+ipython-genutils==0.2.0
+isoduration==20.11.0
+jedi==0.19.1
+jieba==0.42.1
+Jinja2==3.1.3
+jiter==0.5.0
+joblib==1.3.2
+json5==0.9.14
+jsonpointer==3.0.0
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+jupyter-events==0.10.0
+jupyter-lsp==2.2.5
+jupyter_client==8.6.0
+jupyter_core==5.7.1
+jupyter_server==2.14.2
+jupyter_server_terminals==0.5.3
+jupyterlab==4.1.6
+jupyterlab_pygments==0.3.0
+jupyterlab_server==2.27.3
+jupytext==1.16.1
+kiwisolver==1.4.5
+langcodes==3.3.0
+lark==1.2.2
+lazy_loader==0.3
+librosa==0.10.1
+lm-format-enforcer==0.10.6
+Markdown==3.5.2
+markdown-it-py==3.0.0
+matplotlib==3.8.2
+matplotlib-inline==0.1.6
+mdit-py-plugins==0.4.0
+mdurl==0.1.2
+mistral_common==1.3.4
+mistune==3.0.2
+mkl==2021.1.1
+mkl-devel==2021.1.1
+mkl-include==2021.1.1
+mock==5.1.0
+mpmath==1.3.0
+msgpack==1.0.7
+msgspec==0.18.6
+multiprocess==0.70.16
+murmurhash==1.0.10
+nbclient==0.9.0
+nbconvert==7.16.0
+nbformat==5.9.2
+nest-asyncio==1.6.0
+networkx==2.6.3
+ninja==1.11.1.1
+nltk==3.9.1
+notebook==6.4.10
+notebook_shim==0.2.4
+numpy==1.24.4
+nvfuser==0.1.4a0+d0bb811
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-dali-cuda120==1.34.0
+nvidia-ml-py==12.560.30
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvtx-cu12==12.1.105
+nvidia-pyindex==1.0.9
+nvitop==1.5.1
+oauthlib==3.2.2
+openai==1.43.0
+optree==0.10.0
+orjson==3.10.7
+outlines==0.0.46
+overrides==7.7.0
+packaging==23.2
+pandas==2.2.2
+pandocfilters==1.5.1
+parso==0.8.3
+partial-json-parser==0.2.1.1.post4
+peft==0.12.0
+pexpect==4.9.0
+platformdirs==4.2.0
+pluggy==1.4.0
+polygraphy==0.49.4
+pooch==1.8.0
+preshed==3.0.9
+prettytable==3.9.0
+prometheus-client==0.19.0
+prometheus-fastapi-instrumentator==7.0.0
+prompt-toolkit==3.0.43
+protobuf==4.24.4
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py-cpuinfo==9.0.0
+pyairports==2.1.1
+pyarrow==17.0.0
+pyasn1==0.5.1
+pyasn1-modules==0.3.0
+pybind11==2.11.1
+pybind11-global==2.11.1
+pycountry==24.6.1
+pycparser==2.21
+pydantic==2.8.2
+pydantic_core==2.20.1
+pydub==0.25.1
+Pygments==2.17.2
+PyJWT==2.8.0
+pyparsing==3.1.1
+pytest==8.0.0
+pytest-flakefinder==1.1.0
+pytest-rerunfailures==13.0
+pytest-shard==0.1.2
+pytest-xdist==3.5.0
+python-dateutil==2.8.2
+python-dotenv==1.0.1
+python-hostlist==1.23.0
+python-json-logger==2.0.7
+python-multipart==0.0.9
+pytorch-quantization==2.1.2
+PyYAML==6.0.1
+pyzmq==25.1.2
+ray==2.35.0
+referencing==0.33.0
+regex==2023.12.25
+requests==2.32.3
+requests-oauthlib==1.3.1
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rouge-chinese==1.0.3
+rpds-py==0.17.1
+rsa==4.9
+ruff==0.6.3
+safetensors==0.4.4
+semantic-version==2.10.0
+Send2Trash==1.8.2
+sentencepiece==0.2.0
+shellingham==1.5.4
+shtab==1.7.1
+six==1.16.0
+smart-open==6.4.0
+sniffio==1.3.1
+sortedcontainers==2.4.0
+soundfile==0.12.1
+soupsieve==2.5
+soxr==0.3.7
+spacy==3.7.2
+spacy-legacy==3.0.12
+spacy-loggers==1.0.5
+sphinx_glpi_theme==0.6
+srsly==2.4.8
+sse-starlette==2.1.3
+stack-data==0.6.3
+starlette==0.38.4
+sympy==1.12
+tabulate==0.9.0
+tbb==2021.11.0
+tensorboard==2.9.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+termcolor==2.4.0
+terminado==0.18.0
+thinc==8.2.3
+threadpoolctl==3.2.0
+tiktoken==0.7.0
+tinycss2==1.2.1
+tokenizers==0.19.1
+toml==0.10.2
+tomli==2.0.1
+tomlkit==0.12.0
+torch==2.4.0
+torchvision==0.19.0
+tornado==6.4
+tqdm==4.66.5
+traitlets==5.9.0
+transformers==4.44.2
+triton==3.0.0
+trl==0.9.6
+typer==0.12.5
+types-dataclasses==0.6.6
+types-python-dateutil==2.9.0.20240821
+typing_extensions==4.12.2
+tyro==0.8.10
+tzdata==2024.1
+uri-template==1.3.0
+urllib3==2.2.2
+uvicorn==0.30.6
+uvloop==0.20.0
+vllm==0.6.0
+vllm-flash-attn==2.6.1
+wasabi==1.1.2
+watchfiles==0.24.0
+wcwidth==0.2.13
+weasel==0.3.4
+webcolors==24.8.0
+webencodings==0.5.1
+websocket-client==1.8.0
+websockets==12.0
+Werkzeug==3.0.1
+xdoctest==1.0.2
+xformers==0.0.27.post2
+xxhash==3.5.0
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py
new file mode 100644
index 0000000..11f658a
--- /dev/null
+++ b/Group-Entropy-Equalization/train.py
@@ -0,0 +1,197 @@
+import argparse
+import os
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from torch.optim import AdamW
+import pandas as pd
+from torch.utils.data import Dataset, DataLoader
+
+import wandb
+
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import set_seed
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+
+os.environ.setdefault("NCCL_TIMEOUT", "2700")
+os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+ parser.add_argument('--model_path', type=str, default=None, help='Local model path')
+ parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
+ parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
+ parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
+    parser.add_argument('--micro_batch_size', type=str, default='2', help='Micro batch size or "auto"')
+ parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
+ parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
+ parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
+ parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
+ parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
+ parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
+ parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
+ parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+ parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+ parser.add_argument('--seed', type=int, default=15, help='Random seed')
+ parser.add_argument('--no_deepspeed', action='store_true', help='Disable DeepSpeed and use plain Accelerator (Colab-friendly)')
+ parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['bf16', 'fp16', 'no'], help='Mixed precision mode')
+ return parser.parse_args()
+
+class FTDataset(Dataset):
+ def __init__(self, rows): self.rows = rows
+ def __len__(self): return len(self.rows)
+ def __getitem__(self, idx): return self.rows[idx]
+
+def custom_collate(batch):
+ return {"input": [item["input"] for item in batch]}
+
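+# Heuristic default micro-batch size keyed on the parameter count implied by the
+# model name; used when --micro_batch_size is set to "auto".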
+def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
+ model_configs = {
+ "1.5B": {"base_batch": 4, "keywords": ["1.5B", "1B"]},
+ "2B": {"base_batch": 4, "keywords": ["2B"]},
+ "3B": {"base_batch": 2, "keywords": ["3B"]},
+ "7B": {"base_batch": 2, "keywords": ["7B"]},
+ "8B+": {"base_batch": 1, "keywords": ["8B", "9B", "10B", "11B", "12B", "13B", "14B"]},
+ }
+ model_name_upper = model_name.upper()
+ detected = next((cfg for cfg in model_configs.values() if any(k in model_name_upper for k in cfg["keywords"])), None)
+ base_batch = detected["base_batch"] if detected else 2
+ if world_size > 1:
+ return min(base_batch + 1, int(base_batch * 1.5))
+ return base_batch
+
+def apply_chat_template(tokenizer, problem: str) -> str:
+ return tokenizer.apply_chat_template(
+ [{"role": "user", "content": problem}],
+ tokenize=False, add_generation_prompt=True
+ )
+
+def main():
+ args = parse_args()
+ set_seed(args.seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
+    if str(args.micro_batch_size).lower() == "auto":
+        micro_bs = get_optimal_micro_batch_size(args.model_name, world_size)
+    else:
+        micro_bs = int(args.micro_batch_size)
+ eff_bs = args.effective_batch
+ accum_steps = max(1, eff_bs // (micro_bs * world_size))
+ temp = args.temperature
+ lr = args.learning_rate
+
+ save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
+ # Resolve mixed precision automatically if requested bf16 is unsupported
+ mp = args.mixed_precision
+ if mp == "bf16":
+ if not torch.cuda.is_available() or not torch.cuda.is_bf16_supported():
+ mp = "fp16" if torch.cuda.is_available() else "no"
+
+ if args.no_deepspeed:
+ accelerator = Accelerator(mixed_precision=mp, gradient_accumulation_steps=accum_steps)
+ else:
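+        # Default path: DeepSpeed ZeRO-2 with optimizer states offloaded to CPU;
+        # parameters stay on GPU, and gradient clipping is set in the DS config.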
+ ds_config = {
+ "train_micro_batch_size_per_gpu": micro_bs,
+ "train_batch_size": eff_bs,
+ "gradient_accumulation_steps": accum_steps,
+ "bf16": {"enabled": mp == "bf16"},
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {"device": "cpu"},
+ "offload_param": {"device": "none"}
+ },
+ "gradient_clipping": 1.0,
+ }
+ accelerator = Accelerator(mixed_precision=mp,
+ gradient_accumulation_steps=accum_steps,
+ deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+ print = accelerator.print
+
+ model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
+ config = AutoConfig.from_pretrained(model_path)
+ config.use_cache = False
+ model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
+ model.gradient_checkpointing_enable()
+ tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+ if accelerator.is_main_process:
+ wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+
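+    # The training parquet supplies one prompt per row in a "problem" column;
+    # each prompt is wrapped with the tokenizer's chat template.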
+ df = pd.read_parquet(args.train_data)
+ train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
+ train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
+
+ optimizer = AdamW(model.parameters(), lr=lr)
+ model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+ prev_logits = None
+ model.train()
+
+ for step, batch in enumerate(train_loader, start=1):
+ if step > args.max_steps:
+ print(f"Exceed max step {args.max_steps}, training stopped.")
+ break
+
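+        # One EM step: sample completions from the current policy, then minimize
+        # the token-level entropy of the model's predictions on those completions.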
+ with accelerator.accumulate(model):
+ enc = tokenizer(batch["input"],
+ return_tensors="pt",
+ padding="longest",
+ truncation=True,
+ max_length=2048).to(accelerator.device)
+
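+            # Roll out completions without tracking gradients; sampling uses
+            # --sample_temp, while the entropy loss below uses --temperature.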
+            with torch.no_grad():
+                gen_ids = accelerator.unwrap_model(model).generate(
+                    **enc,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    top_p=0.95,
+                    temperature=args.sample_temp,
+                    # Sync the generation loop across ranks only when actually distributed;
+                    # with --no_deepspeed on a single process this would otherwise fail.
+                    synced_gpus=accelerator.num_processes > 1,
+                    repetition_penalty=1.15,
+                    pad_token_id=tokenizer.pad_token_id,
+                    use_cache=False,
+                )
+
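+            # Re-assemble prompt + generated tokens (capped at 4096) and build a mask
+            # selecting the non-pad tokens beyond each row's prompt length.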
+ seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
+ pad_mask = seq.ne(tokenizer.pad_token_id)
+ prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
+ token_idx = torch.arange(seq.size(1), device=seq.device)
+ gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
+
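+            # Entropy-minimization loss: temperature-scaled softmax over the vocabulary,
+            # per-token Shannon entropy, averaged over the masked (generated) tokens.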
+ logits = model(seq, attention_mask=pad_mask).logits
+ probs = F.softmax(logits / temp, dim=-1)
+ H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
+ loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+
+ prev_logits = logits.detach()
+ accelerator.backward(loss)
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ optimizer.zero_grad()
+
+ if accelerator.is_main_process:
+ if step % args.log_steps == 0:
+ print(f"Step {step} | loss={loss.item():.6f}")
+ wandb.log({"step": step, "loss": loss.item()})
+
+ if step % args.save_steps == 0:
+ ckpt = Path(save_root) / f"step_{step}"
+ ckpt.mkdir(parents=True, exist_ok=True)
+ accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
+ tokenizer.save_pretrained(ckpt)
+ print(f"Checkpoint saved to {ckpt}")
+
+ if accelerator.is_main_process:
+ final = Path(save_root) / "final"
+ final.mkdir(parents=True, exist_ok=True)
+ accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
+ tokenizer.save_pretrained(final)
+ print(f"Final checkpoint saved to {final}")
+ wandb.finish()
+
+if __name__ == "__main__":
+ main()