| author | blackhao <13851610112@163.com> | 2025-08-23 13:35:13 -0500 |
|---|---|---|
| committer | blackhao <13851610112@163.com> | 2025-08-23 13:35:13 -0500 |
| commit | 4f81a87ef95b190450ed5202bfa725dbb0a539f4 (patch) | |
| tree | 875f5966cdaaa526d85ff49a13cd6bf27ab4a723 | |
| parent | ad3e216afd066375219ef8b3928ef4096237fbf6 (diff) | |
init
| -rw-r--r-- | .gitignore | 26 | |
| -rw-r--r-- | 2505.20282v4.pdf | bin | 0 -> 381009 bytes |
| -rw-r--r-- | Group-Entropy-Equalization/README.md | 90 | |
| -rw-r--r-- | Group-Entropy-Equalization/requirements.txt | 280 | |
| -rw-r--r-- | Group-Entropy-Equalization/train.py | 197 | |
5 files changed, 593 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..28ec88c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,26 @@
+# Byte-compiled / cache
+__pycache__/
+*.py[cod]
+*.so
+*.pyd
+
+# Data
+*.parquet
+*.json
+*.csv
+
+# Checkpoints/weights
+*.bin
+*.safetensors
+*.pt
+checkpoints/
+
+# Envs
+.env
+.venv/
+
+# Misc
+.DS_Store
+.ipynb_checkpoints/
+wandb/
+
diff --git a/2505.20282v4.pdf b/2505.20282v4.pdf
new file mode 100644
index 0000000..c012b46
--- /dev/null
+++ b/2505.20282v4.pdf
Binary files differ
diff --git a/Group-Entropy-Equalization/README.md b/Group-Entropy-Equalization/README.md
new file mode 100644
index 0000000..804af95
--- /dev/null
+++ b/Group-Entropy-Equalization/README.md
@@ -0,0 +1,90 @@
+# One-shot Entropy Minimization
+
+[](https://arxiv.org/abs/2505.20282)
+[](https://huggingface.co/zgao3186/qwen25math7b-one-shot-em/)
+[](https://www.notion.so/One-shot-Entropy-Minimization-202606db813b80639773f850f39246a5)
+
+### Installation
+
+```bash
+conda create -n one-shot-em python=3.10 -y
+pip install -r requirements.txt
+```
+
+---
+
+### Reproducing One-shot EM Training (SOTA)
+
+```bash
+accelerate launch train.py \
+  --model_name Qwen2.5-Math-7B \
+  --model_path /path/to/Qwen2.5-Math-7B \
+  --train_data dataset/1shot_rlvr/pi1_r1280.parquet \
+  --effective_batch 64 \
+  --micro_batch_size 2 \
+  --temperature 0.5 \
+  --learning_rate 2e-5 \
+  --max_steps 50 \
+  --log_steps 1 \
+  --save_steps 1 \
+  --run_name one_shot \
+  --wandb_project one-shot-em
+```
+
+---
+
+### Reproducing Multi-shot EM Training
+
+```bash
+accelerate launch train.py \
+  --model_name Qwen2.5-Math-7B \
+  --model_path /path/to/Qwen2.5-Math-7B \
+  --train_data dataset/numina/numina_00.parquet \
+  --effective_batch 64 \
+  --micro_batch_size 2 \
+  --temperature 0.5 \
+  --learning_rate 2e-5 \
+  --max_steps 50 \
+  --log_steps 1 \
+  --save_steps 1 \
+  --run_name multi_shot \
+  --wandb_project one-shot-em
+```
+
+---
+
+### Evaluation
+
+```bash
+cd Qwen2.5-Eval/evaluation
+bash sh/eval_all_math.sh
+```
+
+---
+
+### Acknowledgements
+
+Our dataset references and builds upon the following open-source contributions:
+
+- [NuminaMath-CoT](https://huggingface.co/datasets/AI-MO/NuminaMath-CoT)
+- [DeepScaler](https://github.com/agentica-project/deepscaler)
+- [One-shot RLVR](https://github.com/ypwang61/One-Shot-RLVR/) – for data selection strategies
+- [Qwen2.5-Eval](https://github.com/QwenLM/Qwen2.5-Math/) – for evaluation benchmarks
+
+We sincerely thank the authors and maintainers of these projects for their excellent contributions to the research community!
+
+
+---
+
+### Citation
+```
+@misc{gao2025oneshotentropyminimization,
+      title={One-shot Entropy Minimization},
+      author={Zitian Gao and Lynx Chen and Haoming Luo and Joey Zhou and Bryan Dai},
+      year={2025},
+      eprint={2505.20282},
+      archivePrefix={arXiv},
+      primaryClass={cs.CL},
+      url={https://arxiv.org/abs/2505.20282},
+}
+```
diff --git a/Group-Entropy-Equalization/requirements.txt b/Group-Entropy-Equalization/requirements.txt
new file mode 100644
index 0000000..7fb330a
--- /dev/null
+++ b/Group-Entropy-Equalization/requirements.txt
@@ -0,0 +1,280 @@
+absl-py==2.1.0
+accelerate==0.33.0
+aiofiles==23.2.1
+annotated-types==0.6.0
+anyio==4.4.0
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+asttokens==2.4.1
+astunparse==1.6.3
+async-lru==2.0.4
+attrs==23.2.0
+audioread==3.0.1
+babel==2.16.0
+beautifulsoup4==4.12.3
+bitsandbytes==0.43.3
+bleach==6.1.0
+blis==0.7.11
+cachetools==5.3.2
+catalogue==2.0.10
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpathlib==0.16.0
+cmake==3.28.1
+comm==0.2.1
+confection==0.1.4
+contourpy==1.2.0
+cycler==0.12.1
+cymem==2.0.8
+Cython==3.0.8
+datasets==2.21.0
+debugpy==1.8.1
+decorator==5.1.1
+deepspeed==0.14.4
+defusedxml==0.7.1
+dill==0.3.8
+diskcache==5.6.3
+distro==1.9.0
+dm-tree==0.1.8
+docstring_parser==0.16
+einops==0.7.0
+exceptiongroup==1.2.0
+execnet==2.0.2
+executing==2.0.1
+expecttest==0.1.3
+fastapi==0.112.2
+fastjsonschema==2.19.1
+ffmpy==0.4.0
+filelock==3.13.1
+fire==0.6.0
+flash-attn==2.6.3
+fonttools==4.48.1
+fqdn==1.5.1
+gast==0.5.4
+gguf==0.9.1
+google-auth==2.27.0
+google-auth-oauthlib==0.4.6
+gradio==4.42.0
+gradio_client==1.3.0
+grpcio==1.60.1
+h11==0.14.0
+hjson==3.1.0
+httpcore==1.0.5
+httptools==0.6.1
+httpx==0.27.2
+huggingface-hub==0.24.6
+hypothesis==5.35.1
+idna==3.6
+importlib_resources==6.4.4
+iniconfig==2.0.0
+intel-openmp==2021.4.0
+interegular==0.3.3
+ipykernel==6.29.2
+ipython==8.21.0
+ipython-genutils==0.2.0
+isoduration==20.11.0
+jedi==0.19.1
+jieba==0.42.1
+Jinja2==3.1.3
+jiter==0.5.0
+joblib==1.3.2
+json5==0.9.14
+jsonpointer==3.0.0
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+jupyter-events==0.10.0
+jupyter-lsp==2.2.5
+jupyter_client==8.6.0
+jupyter_core==5.7.1
+jupyter_server==2.14.2
+jupyter_server_terminals==0.5.3
+jupyterlab==4.1.6
+jupyterlab_pygments==0.3.0
+jupyterlab_server==2.27.3
+jupytext==1.16.1
+kiwisolver==1.4.5
+langcodes==3.3.0
+lark==1.2.2
+lazy_loader==0.3
+librosa==0.10.1
+lm-format-enforcer==0.10.6
+Markdown==3.5.2
+markdown-it-py==3.0.0
+matplotlib==3.8.2
+matplotlib-inline==0.1.6
+mdit-py-plugins==0.4.0
+mdurl==0.1.2
+mistral_common==1.3.4
+mistune==3.0.2
+mkl==2021.1.1
+mkl-devel==2021.1.1
+mkl-include==2021.1.1
+mock==5.1.0
+mpmath==1.3.0
+msgpack==1.0.7
+msgspec==0.18.6
+multiprocess==0.70.16
+murmurhash==1.0.10
+nbclient==0.9.0
+nbconvert==7.16.0
+nbformat==5.9.2
+nest-asyncio==1.6.0
+networkx==2.6.3
+ninja==1.11.1.1
+nltk==3.9.1
+notebook==6.4.10
+notebook_shim==0.2.4
+numpy==1.24.4
+nvfuser==0.1.4a0+d0bb811
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-dali-cuda120==1.34.0
+nvidia-ml-py==12.560.30
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvtx-cu12==12.1.105
+nvidia-pyindex==1.0.9
+nvitop==1.5.1
+oauthlib==3.2.2
+openai==1.43.0
+optree==0.10.0
+orjson==3.10.7
+outlines==0.0.46
+overrides==7.7.0
+packaging==23.2
+pandas==2.2.2
+pandocfilters==1.5.1
+parso==0.8.3
+partial-json-parser==0.2.1.1.post4
+peft==0.12.0
+pexpect==4.9.0
+platformdirs==4.2.0
+pluggy==1.4.0
+polygraphy==0.49.4
+pooch==1.8.0
+preshed==3.0.9
+prettytable==3.9.0
+prometheus-client==0.19.0
+prometheus-fastapi-instrumentator==7.0.0
+prompt-toolkit==3.0.43
+protobuf==4.24.4
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py-cpuinfo==9.0.0
+pyairports==2.1.1
+pyarrow==17.0.0
+pyasn1==0.5.1
+pyasn1-modules==0.3.0
+pybind11==2.11.1
+pybind11-global==2.11.1
+pycountry==24.6.1
+pycparser==2.21
+pydantic==2.8.2
+pydantic_core==2.20.1
+pydub==0.25.1
+Pygments==2.17.2
+PyJWT==2.8.0
+pyparsing==3.1.1
+pytest==8.0.0
+pytest-flakefinder==1.1.0
+pytest-rerunfailures==13.0
+pytest-shard==0.1.2
+pytest-xdist==3.5.0
+python-dateutil==2.8.2
+python-dotenv==1.0.1
+python-hostlist==1.23.0
+python-json-logger==2.0.7
+python-multipart==0.0.9
+pytorch-quantization==2.1.2
+PyYAML==6.0.1
+pyzmq==25.1.2
+ray==2.35.0
+referencing==0.33.0
+regex==2023.12.25
+requests==2.32.3
+requests-oauthlib==1.3.1
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rouge-chinese==1.0.3
+rpds-py==0.17.1
+rsa==4.9
+ruff==0.6.3
+safetensors==0.4.4
+semantic-version==2.10.0
+Send2Trash==1.8.2
+sentencepiece==0.2.0
+shellingham==1.5.4
+shtab==1.7.1
+six==1.16.0
+smart-open==6.4.0
+sniffio==1.3.1
+sortedcontainers==2.4.0
+soundfile==0.12.1
+soupsieve==2.5
+soxr==0.3.7
+spacy==3.7.2
+spacy-legacy==3.0.12
+spacy-loggers==1.0.5
+sphinx_glpi_theme==0.6
+srsly==2.4.8
+sse-starlette==2.1.3
+stack-data==0.6.3
+starlette==0.38.4
+sympy==1.12
+tabulate==0.9.0
+tbb==2021.11.0
+tensorboard==2.9.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+termcolor==2.4.0
+terminado==0.18.0
+thinc==8.2.3
+threadpoolctl==3.2.0
+tiktoken==0.7.0
+tinycss2==1.2.1
+tokenizers==0.19.1
+toml==0.10.2
+tomli==2.0.1
+tomlkit==0.12.0
+torch==2.4.0
+torchvision==0.19.0
+tornado==6.4
+tqdm==4.66.5
+traitlets==5.9.0
+transformers==4.44.2
+triton==3.0.0
+trl==0.9.6
+typer==0.12.5
+types-dataclasses==0.6.6
+types-python-dateutil==2.9.0.20240821
+typing_extensions==4.12.2
+tyro==0.8.10
+tzdata==2024.1
+uri-template==1.3.0
+urllib3==2.2.2
+uvicorn==0.30.6
+uvloop==0.20.0
+vllm==0.6.0
+vllm-flash-attn==2.6.1
+wasabi==1.1.2
+watchfiles==0.24.0
+wcwidth==0.2.13
+weasel==0.3.4
+webcolors==24.8.0
+webencodings==0.5.1
+websocket-client==1.8.0
+websockets==12.0
+Werkzeug==3.0.1
+xdoctest==1.0.2
+xformers==0.0.27.post2
+xxhash==3.5.0
diff --git a/Group-Entropy-Equalization/train.py b/Group-Entropy-Equalization/train.py
new file mode 100644
index 0000000..11f658a
--- /dev/null
+++ b/Group-Entropy-Equalization/train.py
@@ -0,0 +1,197 @@
+import argparse
+import os
+import random
+import time
+from pathlib import Path
+
+import psutil
+import torch
+import torch.nn.functional as F
+from torch.optim import AdamW
+import pandas as pd
+import numpy as np
+from torch.utils.data import Dataset, DataLoader
+
+import wandb
+
+from accelerate import Accelerator, DeepSpeedPlugin
+from accelerate.utils import set_seed
+from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
+
+
+os.environ.setdefault("NCCL_TIMEOUT", "2700")
+os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "2700")
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model_name', type=str, default='Qwen2.5-Math-7B', help='Model name')
+    parser.add_argument('--model_path', type=str, default=None, help='Local model path')
+    parser.add_argument('--train_data', type=str, default='dataset/1shot_rlvr/pi1_r1280.parquet', help='Training data file path')
+    parser.add_argument('--save_root', type=str, default=None, help='Checkpoint save root directory')
+    parser.add_argument('--effective_batch', type=int, default=64, help='Global batch size')
+    parser.add_argument('--micro_batch_size', type=str, default=2, help='Micro batch size or "auto"')
+    parser.add_argument('--temperature', type=float, default=0.5, help='Temperature coefficient')
+    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
+    parser.add_argument('--log_steps', type=int, default=1, help='Logging step interval')
+    parser.add_argument('--save_steps', type=int, default=1, help='Checkpoint saving step interval')
+    parser.add_argument('--max_steps', type=int, default=1000, help='Maximum training steps')
+    parser.add_argument('--sample_temp', type=float, default=0.5, help='Generation temperature parameter')
+    parser.add_argument('--run_name', type=str, default=None, help='Experiment run name')
+    parser.add_argument('--wandb_project', type=str, default='entropy-maximization-ft', help='W&B project name')
+    parser.add_argument('--wandb_name', type=str, default=None, help='W&B run name')
+    parser.add_argument('--seed', type=int, default=15, help='Random seed')
+    parser.add_argument('--no_deepspeed', action='store_true', help='Disable DeepSpeed and use plain Accelerator (Colab-friendly)')
+    parser.add_argument('--mixed_precision', type=str, default='bf16', choices=['bf16', 'fp16', 'no'], help='Mixed precision mode')
+    return parser.parse_args()
+
+class FTDataset(Dataset):
+    def __init__(self, rows): self.rows = rows
+    def __len__(self): return len(self.rows)
+    def __getitem__(self, idx): return self.rows[idx]
+
+def custom_collate(batch):
+    return {"input": [item["input"] for item in batch]}
+
+def get_optimal_micro_batch_size(model_name: str, world_size: int = 1) -> int:
+    model_configs = {
+        "1.5B": {"base_batch": 4, "keywords": ["1.5B", "1B"]},
+        "2B": {"base_batch": 4, "keywords": ["2B"]},
+        "3B": {"base_batch": 2, "keywords": ["3B"]},
+        "7B": {"base_batch": 2, "keywords": ["7B"]},
+        "8B+": {"base_batch": 1, "keywords": ["8B", "9B", "10B", "11B", "12B", "13B", "14B"]},
+    }
+    model_name_upper = model_name.upper()
+    detected = next((cfg for cfg in model_configs.values() if any(k in model_name_upper for k in cfg["keywords"])), None)
+    base_batch = detected["base_batch"] if detected else 2
+    if world_size > 1:
+        return min(base_batch + 1, int(base_batch * 1.5))
+    return base_batch
+
+def apply_chat_template(tokenizer, problem: str) -> str:
+    return tokenizer.apply_chat_template(
+        [{"role": "user", "content": problem}],
+        tokenize=False, add_generation_prompt=True
+    )
+
+def main():
+    args = parse_args()
+    set_seed(args.seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    # "auto" uses the model-size heuristic above; otherwise take the given integer
+    if str(args.micro_batch_size).lower() == "auto":
+        micro_bs = get_optimal_micro_batch_size(args.model_name, world_size)
+    else:
+        micro_bs = int(args.micro_batch_size)
+    eff_bs = args.effective_batch
+    accum_steps = max(1, eff_bs // (micro_bs * world_size))
+    temp = args.temperature
+    lr = args.learning_rate
+
+    save_root = args.save_root or (f"checkpoints/{args.model_name}/{args.run_name}" if args.run_name else f"checkpoints/{args.model_name}")
+    # Resolve mixed precision automatically if requested bf16 is unsupported
+    mp = args.mixed_precision
+    if mp == "bf16":
+        if not torch.cuda.is_available() or not torch.cuda.is_bf16_supported():
+            mp = "fp16" if torch.cuda.is_available() else "no"
+
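+    # Accelerator setup: with --no_deepspeed use plain Accelerate (single-GPU/Colab friendly);
+    # otherwise build a DeepSpeed ZeRO-2 config with optimizer state offloaded to CPU.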
+    if args.no_deepspeed:
+        accelerator = Accelerator(mixed_precision=mp, gradient_accumulation_steps=accum_steps)
+    else:
+        ds_config = {
+            "train_micro_batch_size_per_gpu": micro_bs,
+            "train_batch_size": eff_bs,
+            "gradient_accumulation_steps": accum_steps,
+            "bf16": {"enabled": mp == "bf16"},
+            "zero_optimization": {
+                "stage": 2,
+                "offload_optimizer": {"device": "cpu"},
+                "offload_param": {"device": "none"}
+            },
+            "gradient_clipping": 1.0,
+        }
+        accelerator = Accelerator(mixed_precision=mp,
+                                  gradient_accumulation_steps=accum_steps,
+                                  deepspeed_plugin=DeepSpeedPlugin(hf_ds_config=ds_config))
+    print = accelerator.print
+
+    model_path = args.model_path or f"/volume/pt-train/models/{args.model_name}"
+    config = AutoConfig.from_pretrained(model_path)
+    config.use_cache = False
+    model = AutoModelForCausalLM.from_pretrained(model_path, config=config)
+    model.gradient_checkpointing_enable()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
+    tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
+
+    if accelerator.is_main_process:
+        wandb.init(project=args.wandb_project, name=args.run_name or args.wandb_name or args.model_name, config=vars(args))
+
+    df = pd.read_parquet(args.train_data)
+    train_data = [{"input": apply_chat_template(tokenizer, p)} for p in df["problem"].dropna().tolist()]
+    train_loader = DataLoader(FTDataset(train_data), batch_size=micro_bs, shuffle=True, collate_fn=custom_collate)
+
+    optimizer = AdamW(model.parameters(), lr=lr)
+    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)
+    prev_logits = None
+    model.train()
+
+    for step, batch in enumerate(train_loader, start=1):
+        if step > args.max_steps:
+            print(f"Exceed max step {args.max_steps}, training stopped.")
+            break
+
+        with accelerator.accumulate(model):
+            enc = tokenizer(batch["input"],
+                            return_tensors="pt",
+                            padding="longest",
+                            truncation=True,
+                            max_length=2048).to(accelerator.device)
+
+            with torch.no_grad():
+                gen_ids = accelerator.unwrap_model(model).generate(**enc,
+                                                                   max_new_tokens=512,
+                                                                   do_sample=True,
+                                                                   top_p=0.95,
+                                                                   temperature=args.sample_temp,
+                                                                   synced_gpus=True,
+                                                                   repetition_penalty=1.15,
+                                                                   pad_token_id=tokenizer.pad_token_id,
+                                                                   use_cache=False)
+
+            seq = torch.cat([enc.input_ids, gen_ids[:, enc.input_ids.shape[1]:]], dim=1)[:, :4096]
+            pad_mask = seq.ne(tokenizer.pad_token_id)
+            prompt_len = pad_mask[:, :enc.input_ids.shape[1]].sum(-1)
+            token_idx = torch.arange(seq.size(1), device=seq.device)
+            gen_mask = (token_idx.unsqueeze(0) >= prompt_len.unsqueeze(1)) & pad_mask
+
+            logits = model(seq, attention_mask=pad_mask).logits
+            probs = F.softmax(logits / temp, dim=-1)
+            H_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)
+            loss = (H_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
+
+            prev_logits = logits.detach()
+            accelerator.backward(loss)
+            accelerator.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+            optimizer.zero_grad()
+
+        if accelerator.is_main_process:
+            if step % args.log_steps == 0:
+                print(f"Step {step} | loss={loss.item():.6f}")
+                wandb.log({"step": step, "loss": loss.item()})
+
+            if step % args.save_steps == 0:
+                ckpt = Path(save_root) / f"step_{step}"
+                ckpt.mkdir(parents=True, exist_ok=True)
+                accelerator.unwrap_model(model).save_pretrained(ckpt, safe_serialization=True)
+                tokenizer.save_pretrained(ckpt)
+                print(f"Checkpoint saved to {ckpt}")
+
+    if accelerator.is_main_process:
+        final = Path(save_root) / "final"
+        final.mkdir(parents=True, exist_ok=True)
+        accelerator.unwrap_model(model).save_pretrained(final, safe_serialization=True)
+        tokenizer.save_pretrained(final)
+        print(f"Final checkpoint saved to {final}")
+        wandb.finish()
+
+if __name__ == "__main__":
+    main()
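
The objective minimized in train.py is the mean token-level entropy of the model's own sampled completions: probabilities come from a temperature-scaled softmax over the logits, per-token entropies are masked to generated (non-prompt, non-pad) positions, and their mean is the loss. A minimal standalone sketch of that computation (the function name and signature are illustrative, not part of the repository):

```python
import torch
import torch.nn.functional as F

def mean_generation_entropy(logits: torch.Tensor,
                            gen_mask: torch.Tensor,
                            temperature: float = 0.5) -> torch.Tensor:
    """Mean token-level entropy over generated positions, mirroring train.py's loop.

    logits:   [batch, seq_len, vocab] model outputs for the full prompt+completion
    gen_mask: [batch, seq_len] bool, True only at generated, non-pad positions
    """
    probs = F.softmax(logits / temperature, dim=-1)        # temperature-scaled distribution
    h_tok = -(probs * torch.log(probs + 1e-12)).sum(-1)    # per-token entropy, [batch, seq_len]
    return (h_tok * gen_mask).sum() / gen_mask.sum().clamp_min(1)
```

Gradient descent on this scalar sharpens the distributions the model places on its own samples, which is the entropy-minimization step that the README's training commands run for up to `--max_steps` updates.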
