author     Yuren Hao <yurenh2@timan108.cs.illinois.edu>  2025-09-10 14:58:32 -0500
committer  Yuren Hao <yurenh2@timan108.cs.illinois.edu>  2025-09-10 14:58:32 -0500
commit     9bdd1091a8045d74823d1f5de94b041805e42c3f (patch)
tree       b63843e7bce35b95bc3bd87aaee19059fdf2982c /train_runner.py
parent     9d5f2379ac25b4b58e2600544f61172dbb15b67a (diff)
fuck nvidia drivers (HEAD, main)
Diffstat (limited to 'train_runner.py')
-rw-r--r--  train_runner.py  212
1 file changed, 212 insertions(+), 0 deletions(-)
diff --git a/train_runner.py b/train_runner.py
new file mode 100644
index 0000000..f7c8c25
--- /dev/null
+++ b/train_runner.py
@@ -0,0 +1,212 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Few-step LoRA training runner for EM / Group-EM / JSD.
+- Generates a short greedy continuation for each prompt (T=0), then teacher-forces
+ on [prompt + continuation] to compute stepwise logits and apply losses.
+- Gating: Top-K membership of {F ∪ M} at each step (k configurable).
+- Guards: mass parity (lambda) and stability KL to a frozen base model (beta).
+Outputs:
+ adapter weights + a JSON log of loss components.
+This script is single-GPU; to parallelize, launch multiple processes with different CUDA_VISIBLE_DEVICES.
+"""
+import os, json, math, time, random, pathlib, argparse
+from typing import List, Dict, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup
+from peft import LoraConfig, get_peft_model
+from src.losses import (
+ map_words_to_token_ids, probs_from_logits, topk_gate,
+ loss_em, loss_group_em, loss_jsd
+)
+
+def set_seed(s: int):
+ random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
+
+def read_jsonl(p: str) -> List[Dict]:
+ return [json.loads(x) for x in open(p, "r", encoding="utf-8") if x.strip()]
+
+def save_json(p: str, obj):
+ pathlib.Path(p).parent.mkdir(parents=True, exist_ok=True)
+ open(p, "w", encoding="utf-8").write(json.dumps(obj, indent=2))
+
+@torch.no_grad()
+def greedy_generate_ids(model, tok, prompt_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
+ """
+ prompt_ids: [1, L]
+ return gen_ids: [1, G] (without BOS/EOS trimming)
+ """
+ out = model.generate(
+ input_ids=prompt_ids,
+ attention_mask=torch.ones_like(prompt_ids),
+        do_sample=False,  # greedy decoding (T=0); temperature/top_p have no effect without sampling
+ max_new_tokens=max_new_tokens,
+ eos_token_id=tok.eos_token_id
+ )
+    # slice off the prompt tokens so only the newly generated continuation remains
+ gen_ids = out[:, prompt_ids.size(1):] # [1, G]
+ return gen_ids
+
+def build_lora(model) -> nn.Module:
+ lconf = LoraConfig(
+ r=8, lora_alpha=16, lora_dropout=0.0,
+ bias="none", task_type="CAUSAL_LM",
+ target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"] # Qwen-style
+ )
+ return get_peft_model(model, lconf)
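+# Note: get_peft_model freezes the base weights, so only the LoRA adapters receive
+# gradients; model.print_trainable_parameters() can be used to double-check this.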
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct")
+ ap.add_argument("--loss_type", type=str, choices=["em","group_em","jsd"], required=True)
+ ap.add_argument("--train_file", type=str, help="for em/group_em: JSONL with {'prompt':...}")
+ ap.add_argument("--train_pairs", type=str, help="for jsd: JSONL with {'prompt','prompt_swap'}")
+ ap.add_argument("--groups_dir", type=str, default="assets/groups")
+ ap.add_argument("--output_dir", type=str, required=True)
+
+ # optimization
+ ap.add_argument("--max_steps", type=int, default=10)
+ ap.add_argument("--learning_rate", type=float, default=2e-5)
+ ap.add_argument("--warmup_steps", type=int, default=0)
+ ap.add_argument("--weight_decay", type=float, default=0.0)
+ ap.add_argument("--grad_accum", type=int, default=32)
+ ap.add_argument("--gen_len", type=int, default=64)
+ ap.add_argument("--seed", type=int, default=2025)
+
+ # guards / gating
+ ap.add_argument("--lambda_mass", type=float, default=0.0)
+ ap.add_argument("--beta_kl", type=float, default=0.0)
+ ap.add_argument("--topk_gate", type=int, default=20)
+ ap.add_argument("--topk_jsd", type=int, default=0) # 0 => full vocab JSD
+
+ # dtype / device
+ ap.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16","float16","float32"])
+ ap.add_argument("--device", type=str, default="cuda")
+
+ args = ap.parse_args()
+ set_seed(args.seed)
+
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
+
+ # Load tokenizer/model + LoRA (trainable) and base (frozen for KL)
+ tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
+ base_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype).to(device)
+ base_model.eval()
+ model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype).to(device)
+ model = build_lora(model)
+ model.train()
+
+ # Build optimizer/scheduler
+    # only the LoRA parameters require grad after get_peft_model
+    opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=args.learning_rate, weight_decay=args.weight_decay)
+ sch = get_cosine_schedule_with_warmup(opt, num_warmup_steps=args.warmup_steps, num_training_steps=args.max_steps)
+
+ # Build gender token id sets (for group and jsd; EM ignores)
+    fem_words = [w.strip().lower() for w in open(os.path.join(args.groups_dir,"en_female.txt"),"r",encoding="utf-8") if w.strip()]
+    male_words = [w.strip().lower() for w in open(os.path.join(args.groups_dir,"en_male.txt"),"r",encoding="utf-8") if w.strip()]
+ fem_ids = map_words_to_token_ids(tok, fem_words)
+ male_ids = map_words_to_token_ids(tok, male_words)
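+    # map_words_to_token_ids is expected to return the vocabulary ids covering each word
+    # list; depending on the tokenizer, a single word (e.g. "she") may map to several ids
+    # (with/without a leading space, different casing), so the id sets act as the groups F and M.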
+
+ # Load training data
+ if args.loss_type in ("em","group_em"):
+ assert args.train_file, "--train_file is required for em/group_em"
+ rows = read_jsonl(args.train_file)
+ assert len(rows) > 0, "Empty train_file"
+ prompts = [r["prompt"] for r in rows]
+ else:
+ assert args.loss_type == "jsd" and args.train_pairs, "--train_pairs is required for jsd"
+ rows = read_jsonl(args.train_pairs)
+ assert len(rows) > 0, "Empty train_pairs"
+ pairs = [(r["prompt"], r["prompt_swap"]) for r in rows]
+
+ # Logging
+ outdir = pathlib.Path(args.output_dir)
+ (outdir/"logs").mkdir(parents=True, exist_ok=True)
+ (outdir/"adapter").mkdir(parents=True, exist_ok=True)
+ json_log = []
+
+ step = 0
+ while step < args.max_steps:
+ # simple cyclic sampler
+ if args.loss_type in ("em","group_em"):
+ idx = step % len(prompts)
+ prompt = prompts[idx]
+ # 1) generate a short continuation (greedy)
+ enc = tok(prompt, return_tensors="pt")
+ gen_ids = greedy_generate_ids(model, tok, enc.input_ids.to(device), max_new_tokens=args.gen_len) # [1,G]
+ concat_ids = torch.cat([enc.input_ids.to(device), gen_ids], dim=1) # [1,L+G]
+ attn = torch.ones_like(concat_ids).to(device)
+
+ # 2) forward current model (teacher forcing) for loss
+ out = model(input_ids=concat_ids, attention_mask=attn)
+        logits = out.logits[:, :-1, :]  # logits[:, t] predicts token t+1; the last position has no target
+ T = logits.size(1)
+ # generation mask: 0 for prompt part, 1 for generated part (shifted by 1)
+ gen_mask = torch.zeros((1,T), dtype=torch.float32, device=device)
+ gen_mask[:, enc.input_ids.size(1)-1:] = 1.0 # from the last prompt position onward
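+            # Worked example of the offset: with a 5-token prompt (L=5) and 3 generated tokens
+            # (G=3), concat_ids has 8 tokens and logits has T=7 positions; positions 4..6
+            # (i.e. L-1 onward) are the ones whose next-token predictions correspond to the
+            # generated continuation, so only they get gen_mask=1.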
+
+ if args.loss_type == "em":
+ loss, extras = loss_em(logits, gen_mask)
+ else:
+ gate = topk_gate(logits, fem_ids, male_ids, k=args.topk_gate) # [1,T]
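+                # Per the docstring, the gate marks steps where any token from F ∪ M appears
+                # among the model's top-k next-token candidates, so the group loss can
+                # restrict itself to those steps.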
+ with torch.no_grad():
+ out_ref = base_model(input_ids=concat_ids, attention_mask=attn)
+ ref_probs = probs_from_logits(out_ref.logits[:, :-1, :])
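+                # The frozen base model supplies the reference distribution for the
+                # stability-KL guard (beta_kl); lambda_mass weights the F/M mass-parity
+                # guard described in the module docstring.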
+ loss, extras = loss_group_em(
+ logits, gen_mask, fem_ids, male_ids, gate_mask=gate,
+ lambda_mass=args.lambda_mass, beta_kl=args.beta_kl, ref_probs=ref_probs
+ )
+
+ else:
+ # JSD branch
+ idx = step % len(pairs)
+ x, xs = pairs[idx]
+
+ # factual
+ enc_f = tok(x, return_tensors="pt")
+ gen_f = greedy_generate_ids(model, tok, enc_f.input_ids.to(device), max_new_tokens=args.gen_len)
+ all_f = torch.cat([enc_f.input_ids.to(device), gen_f], dim=1)
+ attn_f = torch.ones_like(all_f).to(device)
+ out_f = model(input_ids=all_f, attention_mask=attn_f)
+ logits_f = out_f.logits[:, :-1, :]
+ T_f = logits_f.size(1)
+ gen_mask_f = torch.zeros((1,T_f), dtype=torch.float32, device=device)
+ gen_mask_f[:, enc_f.input_ids.size(1)-1:] = 1.0
+ gate_f = topk_gate(logits_f, fem_ids, male_ids, k=args.topk_gate)
+
+ # counterfactual
+ enc_c = tok(xs, return_tensors="pt")
+ gen_c = greedy_generate_ids(model, tok, enc_c.input_ids.to(device), max_new_tokens=args.gen_len)
+ all_c = torch.cat([enc_c.input_ids.to(device), gen_c], dim=1)
+ attn_c = torch.ones_like(all_c).to(device)
+ out_c = model(input_ids=all_c, attention_mask=attn_c)
+ logits_c = out_c.logits[:, :-1, :]
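+            # Note that the factual and counterfactual prompts each get their own greedy
+            # continuation, while (per the loss_jsd call below) the gate and the KL
+            # reference are computed from the factual side only.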
+
+ with torch.no_grad():
+ out_ref_f = base_model(input_ids=all_f, attention_mask=attn_f)
+ ref_probs_f = probs_from_logits(out_ref_f.logits[:, :-1, :])
+
+ loss, extras = loss_jsd(
+ logits_f, logits_c, gen_mask_f, fem_ids, male_ids,
+ gate_mask_f=gate_f, lambda_mass=args.lambda_mass, beta_kl=args.beta_kl,
+ ref_probs_f=ref_probs_f, topk_jsd=args.topk_jsd
+ )
+
+ (loss / args.grad_accum).backward()
+ if (step + 1) % args.grad_accum == 0 or (step + 1) == args.max_steps:
+ opt.step(); sch.step(); opt.zero_grad(set_to_none=True)
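+        # With the defaults (grad_accum=32, max_steps=10) the accumulation condition only
+        # fires at the final step, so the optimizer and the cosine schedule step exactly
+        # once; lower --grad_accum (or raise --max_steps) for more frequent updates.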
+
+ step += 1
+ json_log.append({"step": step, "loss": float(loss.item()), **extras})
+        # log every step
+        print(f"[{step}/{args.max_steps}] loss={loss.item():.6f} | {extras}")
+
+ # save adapter
+ model.save_pretrained(str(outdir/"adapter"))
+ save_json(str(outdir/"logs"/"train_log.json"), {"args": vars(args), "log": json_log})
+ print("Saved adapter to", outdir/"adapter")
+
+if __name__ == "__main__":
+ main()