-rw-r--r--  .tmp_gpu_check.py                             13
-rw-r--r--  .tmp_gpu_check2.py                            16
-rw-r--r--  .tmp_gpu_check3.py                             2
-rw-r--r--  data/train/em_group/train_en_size1.jsonl       1
-rw-r--r--  data/train/em_group/train_en_size20.jsonl     20
-rw-r--r--  data/train/em_group/train_en_size5.jsonl       5
-rw-r--r--  data/train/jsd/train_pairs_en_size100.jsonl  100
-rw-r--r--  requirements.txt                             283
-rw-r--r--  scripts/make_train_sets.py                   200
-rw-r--r--  src/__init__.py                                0
-rw-r--r--  src/losses.py                                248
-rw-r--r--  train_runner.py                              216
12 files changed, 841 insertions, 263 deletions
diff --git a/.tmp_gpu_check.py b/.tmp_gpu_check.py
new file mode 100644
index 0000000..706491c
--- /dev/null
+++ b/.tmp_gpu_check.py
@@ -0,0 +1,13 @@
+import os
+print("ENV CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
+try:
+ import torch
+ print("torch version:", torch.__version__)
+ print("torch.version.cuda:", getattr(torch.version, "cuda", None))
+ print("torch.cuda.is_available():", torch.cuda.is_available())
+ print("torch.cuda.device_count():", torch.cuda.device_count())
+ if torch.cuda.is_available():
+ for i in range(torch.cuda.device_count()):
+ print(f"[{i}]", torch.cuda.get_device_name(i))
+except Exception as e:
+ print("torch error:", repr(e))
diff --git a/.tmp_gpu_check2.py b/.tmp_gpu_check2.py
new file mode 100644
index 0000000..7555155
--- /dev/null
+++ b/.tmp_gpu_check2.py
@@ -0,0 +1,16 @@
+import os
+print("ENV CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
+try:
+ import torch
+ print("torch:", torch.__version__, "cuda:", getattr(torch.version, "cuda", None))
+ print("built_with_cuda:", torch.backends.cuda.is_built())
+ print("device_count:", torch.cuda.device_count())
+ print("is_available:", torch.cuda.is_available())
+ if torch.cuda.device_count() > 0:
+ for i in range(torch.cuda.device_count()):
+ try:
+ print(f"[{i}]", torch.cuda.get_device_name(i))
+ except Exception as e:
+ print(f"[{i}] name error:", e)
+except Exception as e:
+ print("torch error:", repr(e))
diff --git a/.tmp_gpu_check3.py b/.tmp_gpu_check3.py
new file mode 100644
index 0000000..e88f10e
--- /dev/null
+++ b/.tmp_gpu_check3.py
@@ -0,0 +1,2 @@
+import os, torch
+print("env:", os.environ.get("CUDA_VISIBLE_DEVICES"), "| cuda available:", torch.cuda.is_available())
diff --git a/data/train/em_group/train_en_size1.jsonl b/data/train/em_group/train_en_size1.jsonl
new file mode 100644
index 0000000..c8993ed
--- /dev/null
+++ b/data/train/em_group/train_en_size1.jsonl
@@ -0,0 +1 @@
+{"id": "em_000000", "lang": "en", "occupation": "janitor", "prompt": "The janitor promised that she will lead the next meeting."}
diff --git a/data/train/em_group/train_en_size20.jsonl b/data/train/em_group/train_en_size20.jsonl
new file mode 100644
index 0000000..d326dac
--- /dev/null
+++ b/data/train/em_group/train_en_size20.jsonl
@@ -0,0 +1,20 @@
+{"id": "em_000000", "lang": "en", "occupation": "firefighter", "prompt": "The firefighter noted that she prefers the alternative plan."}
+{"id": "em_000001", "lang": "en", "occupation": "secretary", "prompt": "The secretary said that she would review the design."}
+{"id": "em_000002", "lang": "en", "occupation": "librarian", "prompt": "The librarian said that she would review the design."}
+{"id": "em_000003", "lang": "en", "occupation": "administrator", "prompt": "The administrator explained that she needed more time."}
+{"id": "em_000004", "lang": "en", "occupation": "accountant", "prompt": "The accountant promised that she will lead the next meeting."}
+{"id": "em_000005", "lang": "en", "occupation": "lawyer", "prompt": "The lawyer promised that she will lead the next meeting."}
+{"id": "em_000006", "lang": "en", "occupation": "nurse", "prompt": "The nurse noted that she prefers the alternative plan."}
+{"id": "em_000007", "lang": "en", "occupation": "advisor", "prompt": "The advisor said that she would review the design."}
+{"id": "em_000008", "lang": "en", "occupation": "cashier", "prompt": "The cashier said that she would review the design."}
+{"id": "em_000009", "lang": "en", "occupation": "investigator", "prompt": "The investigator noted that she prefers the alternative plan."}
+{"id": "em_000010", "lang": "en", "occupation": "therapist", "prompt": "The therapist said that she would review the design."}
+{"id": "em_000011", "lang": "en", "occupation": "salesperson", "prompt": "The salesperson noted that she prefers the alternative plan."}
+{"id": "em_000012", "lang": "en", "occupation": "surgeon", "prompt": "The surgeon noted that she prefers the alternative plan."}
+{"id": "em_000013", "lang": "en", "occupation": "officer", "prompt": "The officer promised that she will lead the next meeting."}
+{"id": "em_000014", "lang": "en", "occupation": "machinist", "prompt": "The machinist promised that she will lead the next meeting."}
+{"id": "em_000015", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist explained that she needed more time."}
+{"id": "em_000016", "lang": "en", "occupation": "chemist", "prompt": "The chemist promised that she will lead the next meeting."}
+{"id": "em_000017", "lang": "en", "occupation": "mechanic", "prompt": "The mechanic noted that she prefers the alternative plan."}
+{"id": "em_000018", "lang": "en", "occupation": "supervisor", "prompt": "The supervisor promised that she will lead the next meeting."}
+{"id": "em_000019", "lang": "en", "occupation": "engineer", "prompt": "The engineer explained that she needed more time."}
diff --git a/data/train/em_group/train_en_size5.jsonl b/data/train/em_group/train_en_size5.jsonl
new file mode 100644
index 0000000..c08a4e5
--- /dev/null
+++ b/data/train/em_group/train_en_size5.jsonl
@@ -0,0 +1,5 @@
+{"id": "em_000000", "lang": "en", "occupation": "pharmacist", "prompt": "The pharmacist explained that she needed more time."}
+{"id": "em_000001", "lang": "en", "occupation": "cashier", "prompt": "The cashier reported that she completed the task."}
+{"id": "em_000002", "lang": "en", "occupation": "veterinarian", "prompt": "The veterinarian noted that she prefers the alternative plan."}
+{"id": "em_000003", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist explained that she needed more time."}
+{"id": "em_000004", "lang": "en", "occupation": "scientist", "prompt": "The scientist promised that she will lead the next meeting."}
diff --git a/data/train/jsd/train_pairs_en_size100.jsonl b/data/train/jsd/train_pairs_en_size100.jsonl
new file mode 100644
index 0000000..b05e46b
--- /dev/null
+++ b/data/train/jsd/train_pairs_en_size100.jsonl
@@ -0,0 +1,100 @@
+{"id": "jsd_000000", "lang": "en", "occupation": "supervisor", "prompt": "The supervisor promised that she will lead the next meeting.", "prompt_swap": "The supervisor promised that he will lead the next meeting."}
+{"id": "jsd_000001", "lang": "en", "occupation": "nurse", "prompt": "The nurse explained that she needed more time.", "prompt_swap": "The nurse explained that he needed more time."}
+{"id": "jsd_000002", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist said that she would review the design.", "prompt_swap": "The hygienist said that he would review the design."}
+{"id": "jsd_000003", "lang": "en", "occupation": "programmer", "prompt": "The programmer reported that she completed the task.", "prompt_swap": "The programmer reported that he completed the task."}
+{"id": "jsd_000004", "lang": "en", "occupation": "machinist", "prompt": "The machinist noted that she prefers the alternative plan.", "prompt_swap": "The machinist noted that he prefers the alternative plan."}
+{"id": "jsd_000005", "lang": "en", "occupation": "specialist", "prompt": "The specialist noted that she prefers the alternative plan.", "prompt_swap": "The specialist noted that he prefers the alternative plan."}
+{"id": "jsd_000006", "lang": "en", "occupation": "practitioner", "prompt": "The practitioner explained that she needed more time.", "prompt_swap": "The practitioner explained that he needed more time."}
+{"id": "jsd_000007", "lang": "en", "occupation": "paramedic", "prompt": "The paramedic reported that she completed the task.", "prompt_swap": "The paramedic reported that he completed the task."}
+{"id": "jsd_000008", "lang": "en", "occupation": "investigator", "prompt": "The investigator explained that she needed more time.", "prompt_swap": "The investigator explained that he needed more time."}
+{"id": "jsd_000009", "lang": "en", "occupation": "administrator", "prompt": "The administrator explained that she needed more time.", "prompt_swap": "The administrator explained that he needed more time."}
+{"id": "jsd_000010", "lang": "en", "occupation": "pharmacist", "prompt": "The pharmacist said that she would review the design.", "prompt_swap": "The pharmacist said that he would review the design."}
+{"id": "jsd_000011", "lang": "en", "occupation": "surgeon", "prompt": "The surgeon noted that she prefers the alternative plan.", "prompt_swap": "The surgeon noted that he prefers the alternative plan."}
+{"id": "jsd_000012", "lang": "en", "occupation": "officer", "prompt": "The officer said that she would review the design.", "prompt_swap": "The officer said that he would review the design."}
+{"id": "jsd_000013", "lang": "en", "occupation": "plumber", "prompt": "The plumber promised that she will lead the next meeting.", "prompt_swap": "The plumber promised that he will lead the next meeting."}
+{"id": "jsd_000014", "lang": "en", "occupation": "manager", "prompt": "The manager reported that she completed the task.", "prompt_swap": "The manager reported that he completed the task."}
+{"id": "jsd_000015", "lang": "en", "occupation": "nutritionist", "prompt": "The nutritionist promised that she will lead the next meeting.", "prompt_swap": "The nutritionist promised that he will lead the next meeting."}
+{"id": "jsd_000016", "lang": "en", "occupation": "janitor", "prompt": "The janitor reported that she completed the task.", "prompt_swap": "The janitor reported that he completed the task."}
+{"id": "jsd_000017", "lang": "en", "occupation": "planner", "prompt": "The planner said that she would review the design.", "prompt_swap": "The planner said that he would review the design."}
+{"id": "jsd_000018", "lang": "en", "occupation": "veterinarian", "prompt": "The veterinarian said that she would review the design.", "prompt_swap": "The veterinarian said that he would review the design."}
+{"id": "jsd_000019", "lang": "en", "occupation": "technician", "prompt": "The technician promised that she will lead the next meeting.", "prompt_swap": "The technician promised that he will lead the next meeting."}
+{"id": "jsd_000020", "lang": "en", "occupation": "accountant", "prompt": "The accountant explained that she needed more time.", "prompt_swap": "The accountant explained that he needed more time."}
+{"id": "jsd_000021", "lang": "en", "occupation": "baker", "prompt": "The baker said that she would review the design.", "prompt_swap": "The baker said that he would review the design."}
+{"id": "jsd_000022", "lang": "en", "occupation": "hairdresser", "prompt": "The hairdresser explained that she needed more time.", "prompt_swap": "The hairdresser explained that he needed more time."}
+{"id": "jsd_000023", "lang": "en", "occupation": "counselor", "prompt": "The counselor promised that she will lead the next meeting.", "prompt_swap": "The counselor promised that he will lead the next meeting."}
+{"id": "jsd_000024", "lang": "en", "occupation": "therapist", "prompt": "The therapist explained that she needed more time.", "prompt_swap": "The therapist explained that he needed more time."}
+{"id": "jsd_000025", "lang": "en", "occupation": "mechanic", "prompt": "The mechanic promised that she will lead the next meeting.", "prompt_swap": "The mechanic promised that he will lead the next meeting."}
+{"id": "jsd_000026", "lang": "en", "occupation": "salesperson", "prompt": "The salesperson promised that she will lead the next meeting.", "prompt_swap": "The salesperson promised that he will lead the next meeting."}
+{"id": "jsd_000027", "lang": "en", "occupation": "psychologist", "prompt": "The psychologist said that she would review the design.", "prompt_swap": "The psychologist said that he would review the design."}
+{"id": "jsd_000028", "lang": "en", "occupation": "carpenter", "prompt": "The carpenter explained that she needed more time.", "prompt_swap": "The carpenter explained that he needed more time."}
+{"id": "jsd_000029", "lang": "en", "occupation": "lawyer", "prompt": "The lawyer said that she would review the design.", "prompt_swap": "The lawyer said that he would review the design."}
+{"id": "jsd_000030", "lang": "en", "occupation": "dietitian", "prompt": "The dietitian explained that she needed more time.", "prompt_swap": "The dietitian explained that he needed more time."}
+{"id": "jsd_000031", "lang": "en", "occupation": "dispatcher", "prompt": "The dispatcher said that she would review the design.", "prompt_swap": "The dispatcher said that he would review the design."}
+{"id": "jsd_000032", "lang": "en", "occupation": "bartender", "prompt": "The bartender explained that she needed more time.", "prompt_swap": "The bartender explained that he needed more time."}
+{"id": "jsd_000033", "lang": "en", "occupation": "doctor", "prompt": "The doctor explained that she needed more time.", "prompt_swap": "The doctor explained that he needed more time."}
+{"id": "jsd_000034", "lang": "en", "occupation": "teacher", "prompt": "The teacher explained that she needed more time.", "prompt_swap": "The teacher explained that he needed more time."}
+{"id": "jsd_000035", "lang": "en", "occupation": "broker", "prompt": "The broker reported that she completed the task.", "prompt_swap": "The broker reported that he completed the task."}
+{"id": "jsd_000036", "lang": "en", "occupation": "librarian", "prompt": "The librarian noted that she prefers the alternative plan.", "prompt_swap": "The librarian noted that he prefers the alternative plan."}
+{"id": "jsd_000037", "lang": "en", "occupation": "scientist", "prompt": "The scientist noted that she prefers the alternative plan.", "prompt_swap": "The scientist noted that he prefers the alternative plan."}
+{"id": "jsd_000038", "lang": "en", "occupation": "educator", "prompt": "The educator explained that she needed more time.", "prompt_swap": "The educator explained that he needed more time."}
+{"id": "jsd_000039", "lang": "en", "occupation": "inspector", "prompt": "The inspector said that she would review the design.", "prompt_swap": "The inspector said that he would review the design."}
+{"id": "jsd_000040", "lang": "en", "occupation": "electrician", "prompt": "The electrician noted that she prefers the alternative plan.", "prompt_swap": "The electrician noted that he prefers the alternative plan."}
+{"id": "jsd_000041", "lang": "en", "occupation": "firefighter", "prompt": "The firefighter said that she would review the design.", "prompt_swap": "The firefighter said that he would review the design."}
+{"id": "jsd_000042", "lang": "en", "occupation": "worker", "prompt": "The worker explained that she needed more time.", "prompt_swap": "The worker explained that he needed more time."}
+{"id": "jsd_000043", "lang": "en", "occupation": "architect", "prompt": "The architect noted that she prefers the alternative plan.", "prompt_swap": "The architect noted that he prefers the alternative plan."}
+{"id": "jsd_000044", "lang": "en", "occupation": "physician", "prompt": "The physician explained that she needed more time.", "prompt_swap": "The physician explained that he needed more time."}
+{"id": "jsd_000045", "lang": "en", "occupation": "pathologist", "prompt": "The pathologist said that she would review the design.", "prompt_swap": "The pathologist said that he would review the design."}
+{"id": "jsd_000046", "lang": "en", "occupation": "chemist", "prompt": "The chemist said that she would review the design.", "prompt_swap": "The chemist said that he would review the design."}
+{"id": "jsd_000047", "lang": "en", "occupation": "paralegal", "prompt": "The paralegal promised that she will lead the next meeting.", "prompt_swap": "The paralegal promised that he will lead the next meeting."}
+{"id": "jsd_000048", "lang": "en", "occupation": "advisor", "prompt": "The advisor explained that she needed more time.", "prompt_swap": "The advisor explained that he needed more time."}
+{"id": "jsd_000049", "lang": "en", "occupation": "engineer", "prompt": "The engineer promised that she will lead the next meeting.", "prompt_swap": "The engineer promised that he will lead the next meeting."}
+{"id": "jsd_000050", "lang": "en", "occupation": "auditor", "prompt": "The auditor explained that she needed more time.", "prompt_swap": "The auditor explained that he needed more time."}
+{"id": "jsd_000051", "lang": "en", "occupation": "receptionist", "prompt": "The receptionist promised that she will lead the next meeting.", "prompt_swap": "The receptionist promised that he will lead the next meeting."}
+{"id": "jsd_000052", "lang": "en", "occupation": "painter", "prompt": "The painter noted that she prefers the alternative plan.", "prompt_swap": "The painter noted that he prefers the alternative plan."}
+{"id": "jsd_000053", "lang": "en", "occupation": "cashier", "prompt": "The cashier reported that she completed the task.", "prompt_swap": "The cashier reported that he completed the task."}
+{"id": "jsd_000054", "lang": "en", "occupation": "appraiser", "prompt": "The appraiser reported that she completed the task.", "prompt_swap": "The appraiser reported that he completed the task."}
+{"id": "jsd_000055", "lang": "en", "occupation": "chef", "prompt": "The chef explained that she needed more time.", "prompt_swap": "The chef explained that he needed more time."}
+{"id": "jsd_000056", "lang": "en", "occupation": "secretary", "prompt": "The secretary explained that she needed more time.", "prompt_swap": "The secretary explained that he needed more time."}
+{"id": "jsd_000057", "lang": "en", "occupation": "clerk", "prompt": "The clerk promised that she will lead the next meeting.", "prompt_swap": "The clerk promised that he will lead the next meeting."}
+{"id": "jsd_000058", "lang": "en", "occupation": "instructor", "prompt": "The instructor noted that she prefers the alternative plan.", "prompt_swap": "The instructor noted that he prefers the alternative plan."}
+{"id": "jsd_000059", "lang": "en", "occupation": "examiner", "prompt": "The examiner explained that she needed more time.", "prompt_swap": "The examiner explained that he needed more time."}
+{"id": "jsd_000060", "lang": "en", "occupation": "supervisor", "prompt": "The supervisor noted that she prefers the alternative plan.", "prompt_swap": "The supervisor noted that he prefers the alternative plan."}
+{"id": "jsd_000061", "lang": "en", "occupation": "nurse", "prompt": "The nurse explained that she needed more time.", "prompt_swap": "The nurse explained that he needed more time."}
+{"id": "jsd_000062", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist promised that she will lead the next meeting.", "prompt_swap": "The hygienist promised that he will lead the next meeting."}
+{"id": "jsd_000063", "lang": "en", "occupation": "programmer", "prompt": "The programmer noted that she prefers the alternative plan.", "prompt_swap": "The programmer noted that he prefers the alternative plan."}
+{"id": "jsd_000064", "lang": "en", "occupation": "machinist", "prompt": "The machinist explained that she needed more time.", "prompt_swap": "The machinist explained that he needed more time."}
+{"id": "jsd_000065", "lang": "en", "occupation": "specialist", "prompt": "The specialist promised that she will lead the next meeting.", "prompt_swap": "The specialist promised that he will lead the next meeting."}
+{"id": "jsd_000066", "lang": "en", "occupation": "practitioner", "prompt": "The practitioner promised that she will lead the next meeting.", "prompt_swap": "The practitioner promised that he will lead the next meeting."}
+{"id": "jsd_000067", "lang": "en", "occupation": "paramedic", "prompt": "The paramedic explained that she needed more time.", "prompt_swap": "The paramedic explained that he needed more time."}
+{"id": "jsd_000068", "lang": "en", "occupation": "investigator", "prompt": "The investigator noted that she prefers the alternative plan.", "prompt_swap": "The investigator noted that he prefers the alternative plan."}
+{"id": "jsd_000069", "lang": "en", "occupation": "administrator", "prompt": "The administrator reported that she completed the task.", "prompt_swap": "The administrator reported that he completed the task."}
+{"id": "jsd_000070", "lang": "en", "occupation": "pharmacist", "prompt": "The pharmacist promised that she will lead the next meeting.", "prompt_swap": "The pharmacist promised that he will lead the next meeting."}
+{"id": "jsd_000071", "lang": "en", "occupation": "surgeon", "prompt": "The surgeon reported that she completed the task.", "prompt_swap": "The surgeon reported that he completed the task."}
+{"id": "jsd_000072", "lang": "en", "occupation": "officer", "prompt": "The officer said that she would review the design.", "prompt_swap": "The officer said that he would review the design."}
+{"id": "jsd_000073", "lang": "en", "occupation": "plumber", "prompt": "The plumber explained that she needed more time.", "prompt_swap": "The plumber explained that he needed more time."}
+{"id": "jsd_000074", "lang": "en", "occupation": "manager", "prompt": "The manager explained that she needed more time.", "prompt_swap": "The manager explained that he needed more time."}
+{"id": "jsd_000075", "lang": "en", "occupation": "nutritionist", "prompt": "The nutritionist said that she would review the design.", "prompt_swap": "The nutritionist said that he would review the design."}
+{"id": "jsd_000076", "lang": "en", "occupation": "janitor", "prompt": "The janitor explained that she needed more time.", "prompt_swap": "The janitor explained that he needed more time."}
+{"id": "jsd_000077", "lang": "en", "occupation": "planner", "prompt": "The planner said that she would review the design.", "prompt_swap": "The planner said that he would review the design."}
+{"id": "jsd_000078", "lang": "en", "occupation": "veterinarian", "prompt": "The veterinarian promised that she will lead the next meeting.", "prompt_swap": "The veterinarian promised that he will lead the next meeting."}
+{"id": "jsd_000079", "lang": "en", "occupation": "technician", "prompt": "The technician promised that she will lead the next meeting.", "prompt_swap": "The technician promised that he will lead the next meeting."}
+{"id": "jsd_000080", "lang": "en", "occupation": "accountant", "prompt": "The accountant explained that she needed more time.", "prompt_swap": "The accountant explained that he needed more time."}
+{"id": "jsd_000081", "lang": "en", "occupation": "baker", "prompt": "The baker noted that she prefers the alternative plan.", "prompt_swap": "The baker noted that he prefers the alternative plan."}
+{"id": "jsd_000082", "lang": "en", "occupation": "hairdresser", "prompt": "The hairdresser noted that she prefers the alternative plan.", "prompt_swap": "The hairdresser noted that he prefers the alternative plan."}
+{"id": "jsd_000083", "lang": "en", "occupation": "counselor", "prompt": "The counselor explained that she needed more time.", "prompt_swap": "The counselor explained that he needed more time."}
+{"id": "jsd_000084", "lang": "en", "occupation": "therapist", "prompt": "The therapist reported that she completed the task.", "prompt_swap": "The therapist reported that he completed the task."}
+{"id": "jsd_000085", "lang": "en", "occupation": "mechanic", "prompt": "The mechanic noted that she prefers the alternative plan.", "prompt_swap": "The mechanic noted that he prefers the alternative plan."}
+{"id": "jsd_000086", "lang": "en", "occupation": "salesperson", "prompt": "The salesperson reported that she completed the task.", "prompt_swap": "The salesperson reported that he completed the task."}
+{"id": "jsd_000087", "lang": "en", "occupation": "psychologist", "prompt": "The psychologist reported that she completed the task.", "prompt_swap": "The psychologist reported that he completed the task."}
+{"id": "jsd_000088", "lang": "en", "occupation": "carpenter", "prompt": "The carpenter noted that she prefers the alternative plan.", "prompt_swap": "The carpenter noted that he prefers the alternative plan."}
+{"id": "jsd_000089", "lang": "en", "occupation": "lawyer", "prompt": "The lawyer said that she would review the design.", "prompt_swap": "The lawyer said that he would review the design."}
+{"id": "jsd_000090", "lang": "en", "occupation": "dietitian", "prompt": "The dietitian reported that she completed the task.", "prompt_swap": "The dietitian reported that he completed the task."}
+{"id": "jsd_000091", "lang": "en", "occupation": "dispatcher", "prompt": "The dispatcher promised that she will lead the next meeting.", "prompt_swap": "The dispatcher promised that he will lead the next meeting."}
+{"id": "jsd_000092", "lang": "en", "occupation": "bartender", "prompt": "The bartender reported that she completed the task.", "prompt_swap": "The bartender reported that he completed the task."}
+{"id": "jsd_000093", "lang": "en", "occupation": "doctor", "prompt": "The doctor explained that she needed more time.", "prompt_swap": "The doctor explained that he needed more time."}
+{"id": "jsd_000094", "lang": "en", "occupation": "teacher", "prompt": "The teacher said that she would review the design.", "prompt_swap": "The teacher said that he would review the design."}
+{"id": "jsd_000095", "lang": "en", "occupation": "broker", "prompt": "The broker promised that she will lead the next meeting.", "prompt_swap": "The broker promised that he will lead the next meeting."}
+{"id": "jsd_000096", "lang": "en", "occupation": "librarian", "prompt": "The librarian explained that she needed more time.", "prompt_swap": "The librarian explained that he needed more time."}
+{"id": "jsd_000097", "lang": "en", "occupation": "scientist", "prompt": "The scientist said that she would review the design.", "prompt_swap": "The scientist said that he would review the design."}
+{"id": "jsd_000098", "lang": "en", "occupation": "educator", "prompt": "The educator noted that she prefers the alternative plan.", "prompt_swap": "The educator noted that he prefers the alternative plan."}
+{"id": "jsd_000099", "lang": "en", "occupation": "inspector", "prompt": "The inspector promised that she will lead the next meeting.", "prompt_swap": "The inspector promised that he will lead the next meeting."}
diff --git a/requirements.txt b/requirements.txt
index fca3779..06637df 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,274 +1,31 @@
-absl-py==2.1.0
accelerate==0.33.0
-aiofiles==23.2.1
-annotated-types==0.6.0
-anyio==4.4.0
-argon2-cffi==23.1.0
-argon2-cffi-bindings==21.2.0
-arrow==1.3.0
-asttokens==2.4.1
-astunparse==1.6.3
-async-lru==2.0.4
-attrs==23.2.0
-audioread==3.0.1
-babel==2.16.0
-beautifulsoup4==4.12.3
-bitsandbytes==0.43.3
-bleach==6.1.0
-blis==0.7.11
-cachetools==5.3.2
-catalogue==2.0.10
-certifi==2024.2.2
-cffi==1.16.0
-charset-normalizer==3.3.2
-click==8.1.7
-cloudpathlib==0.16.0
-cmake==3.28.1
-comm==0.2.1
-confection==0.1.4
-contourpy==1.2.0
-cycler==0.12.1
-cymem==2.0.8
-Cython==3.0.8
datasets==2.21.0
-debugpy==1.8.1
-decorator==5.1.1
-defusedxml==0.7.1
-dill==0.3.8
-diskcache==5.6.3
-distro==1.9.0
-dm-tree==0.1.8
-docstring_parser==0.16
-einops==0.7.0
-exceptiongroup==1.2.0
-execnet==2.0.2
-executing==2.0.1
-expecttest==0.1.3
-fastapi==0.112.2
-fastjsonschema==2.19.1
-ffmpy==0.4.0
-filelock==3.13.1
-fire==0.6.0
-fonttools==4.48.1
-fqdn==1.5.1
-gast==0.5.4
-gguf==0.9.1
-google-auth==2.27.0
-google-auth-oauthlib==0.4.6
-gradio==4.42.0
-gradio_client==1.3.0
-grpcio==1.60.1
-h11==0.14.0
-hjson==3.1.0
-httpcore==1.0.5
-httptools==0.6.1
-httpx==0.27.2
-huggingface-hub==0.24.6
-hypothesis==5.35.1
-idna==3.6
-importlib_resources==6.4.4
-iniconfig==2.0.0
-intel-openmp==2021.4.0
-interegular==0.3.3
-ipykernel==6.29.2
-ipython==8.21.0
-ipython-genutils==0.2.0
-isoduration==20.11.0
-jedi==0.19.1
-jieba==0.42.1
-Jinja2==3.1.3
-jiter==0.5.0
-joblib==1.3.2
-json5==0.9.14
-jsonpointer==3.0.0
-jsonschema==4.21.1
-jsonschema-specifications==2023.12.1
-jupyter-events==0.10.0
-jupyter-lsp==2.2.5
-jupyter_client==8.6.0
-jupyter_core==5.7.1
-jupyter_server==2.14.2
-jupyter_server_terminals==0.5.3
-jupyterlab==4.1.6
-jupyterlab_pygments==0.3.0
-jupyterlab_server==2.27.3
-jupytext==1.16.1
-kiwisolver==1.4.5
-langcodes==3.3.0
-lark==1.2.2
-lazy_loader==0.3
-librosa==0.10.1
-lm-format-enforcer==0.10.6
-Markdown==3.5.2
-markdown-it-py==3.0.0
-matplotlib==3.8.2
-matplotlib-inline==0.1.6
-mdit-py-plugins==0.4.0
-mdurl==0.1.2
-mistral_common==1.3.4
-mistune==3.0.2
-mkl==2021.1.1
-mkl-devel==2021.1.1
-mkl-include==2021.1.1
-mock==5.1.0
-mpmath==1.3.0
-msgpack==1.0.7
-msgspec==0.18.6
-multiprocess==0.70.16
-murmurhash==1.0.10
-nbclient==0.9.0
-nbconvert==7.16.0
-nbformat==5.9.2
-nest-asyncio==1.6.0
-networkx==2.6.3
-ninja==1.11.1.1
-nltk==3.9.1
-notebook==6.4.10
-notebook_shim==0.2.4
numpy==1.24.4
-nvidia-cublas-cu12==12.1.3.1
-nvidia-cuda-cupti-cu12==12.1.105
-nvidia-cuda-nvrtc-cu12==12.1.105
-nvidia-cuda-runtime-cu12==12.1.105
-nvidia-cudnn-cu12==9.1.0.70
-nvidia-cufft-cu12==11.0.2.54
-nvidia-curand-cu12==10.3.2.106
-nvidia-cusolver-cu12==11.4.5.107
-nvidia-cusparse-cu12==12.1.0.106
-nvidia-dali-cuda120==1.34.0
-nvidia-ml-py==12.560.30
-nvidia-nccl-cu12==2.20.5
-nvidia-nvjitlink-cu12==12.6.68
-nvidia-nvtx-cu12==12.1.105
-nvidia-pyindex==1.0.9
-nvitop==1.5.1
-oauthlib==3.2.2
-openai==1.43.0
-optree==0.10.0
-orjson==3.10.7
-overrides==7.7.0
-packaging==23.2
pandas==2.2.2
-pandocfilters==1.5.1
-parso==0.8.3
-partial-json-parser==0.2.1.1.post4
peft==0.12.0
-pexpect==4.9.0
-platformdirs==4.2.0
-pluggy==1.4.0
-pooch==1.8.0
-preshed==3.0.9
-prettytable==3.9.0
-prometheus-client==0.19.0
-prometheus-fastapi-instrumentator==7.0.0
-prompt-toolkit==3.0.43
-protobuf==4.24.4
-ptyprocess==0.7.0
-pure-eval==0.2.2
-py-cpuinfo==9.0.0
-pyarrow==17.0.0
-pyasn1==0.5.1
-pyasn1-modules==0.3.0
-pybind11==2.11.1
-pybind11-global==2.11.1
-pycountry==24.6.1
-pycparser==2.21
-pydantic==2.8.2
-pydantic_core==2.20.1
-pydub==0.25.1
-Pygments==2.17.2
-PyJWT==2.8.0
-pyparsing==3.1.1
-pytest==8.0.0
-pytest-flakefinder==1.1.0
-pytest-rerunfailures==13.0
-pytest-shard==0.1.2
-pytest-xdist==3.5.0
-python-dateutil==2.8.2
-python-dotenv==1.0.1
-python-hostlist==1.23.0
-python-json-logger==2.0.7
-python-multipart==0.0.9
-pytorch-quantization==2.1.2
-PyYAML==6.0.1
-pyzmq==25.1.2
-ray==2.35.0
-referencing==0.33.0
+psutil
regex==2023.12.25
-requests==2.32.3
-requests-oauthlib==1.3.1
-rfc3339-validator==0.1.4
-rfc3986-validator==0.1.1
-rouge-chinese==1.0.3
-rpds-py==0.17.1
-rsa==4.9
-ruff==0.6.3
-safetensors==0.4.4
-semantic-version==2.10.0
-Send2Trash==1.8.2
-sentencepiece==0.2.0
-shellingham==1.5.4
-shtab==1.7.1
-six==1.16.0
-smart-open==6.4.0
-sniffio==1.3.1
-sortedcontainers==2.4.0
-soundfile==0.12.1
-soupsieve==2.5
-soxr==0.3.7
-spacy==3.7.2
-spacy-legacy==3.0.12
-spacy-loggers==1.0.5
-sphinx_glpi_theme==0.6
-srsly==2.4.8
-sse-starlette==2.1.3
-stack-data==0.6.3
-starlette==0.38.4
+scipy
sympy==1.12
-tabulate==0.9.0
-tbb==2021.11.0
-tensorboard==2.9.0
-tensorboard-data-server==0.6.1
-tensorboard-plugin-wit==1.8.1
-termcolor==2.4.0
-terminado==0.18.0
-thinc==8.2.3
-threadpoolctl==3.2.0
tiktoken==0.7.0
-tinycss2==1.2.1
-tokenizers==0.19.1
-toml==0.10.2
-tomli==2.0.1
-tomlkit==0.12.0
torch==2.4.0
-torchvision==0.19.0
-tornado==6.4
tqdm==4.66.5
-traitlets==5.9.0
transformers==4.44.2
-triton==3.0.0
-trl==0.9.6
-typer==0.12.5
-types-dataclasses==0.6.6
-types-python-dateutil==2.9.0.20240821
-typing_extensions==4.12.2
-tyro==0.8.10
-tzdata==2024.1
-uri-template==1.3.0
-urllib3==2.2.2
-uvicorn==0.30.6
-uvloop==0.20.0
-vllm==0.6.0
-vllm-flash-attn==2.6.1
-wasabi==1.1.2
-watchfiles==0.24.0
-wcwidth==0.2.13
-weasel==0.3.4
-webcolors==24.8.0
-webencodings==0.5.1
-websocket-client==1.8.0
-websockets==12.0
-Werkzeug==3.0.1
-xdoctest==1.0.2
-xformers==0.0.27.post2
-xxhash==3.5.0
+vllm>=0.6.2
+wandb
+openai==1.43.0
+huggingface-hub
+loguru
+multiprocess==0.70.16
+pebble
+python-dateutil==2.8.2
+matplotlib==3.8.2
+sglang
+timeout-decorator
+func-timeout
+antlr4-python3-runtime==4.11.1
+latex2sympy2
+word2number
+pyarrow==17.0.0
+sentencepiece==0.2.0
+
diff --git a/scripts/make_train_sets.py b/scripts/make_train_sets.py
new file mode 100644
index 0000000..218ac81
--- /dev/null
+++ b/scripts/make_train_sets.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Generate training sets for:
+- EM / Group-EM: 1, 5, 20 prompts (English only)
+- JSD: 100 factual/counterfactual pairs (x, x_swap) with gender swap
+
+Data sources:
+ - assets/triggers/occupations_en.txt
+ - assets/groups/en_definitional_pairs.json (female<->male definitional pairs; supports multiple schemas)
+ - assets/groups/en_female.txt / en_male.txt (for pronouns; we mainly use 'she'/'he')
+
+Avoids overlap with eval CTF examples in data/bias/ctf/ctf_en.jsonl.
+
+Outputs:
+ data/train/em_group/train_en_size{1,5,20}.jsonl
+ data/train/jsd/train_pairs_en_size100.jsonl
+"""
+import json, random, pathlib, re, sys
+from typing import List, Tuple, Dict, Any
+
+ROOT = pathlib.Path(__file__).resolve().parents[1]
+ASSETS = ROOT / "assets"
+DATA = ROOT / "data"
+OUT_EM = DATA / "train" / "em_group"
+OUT_JSD = DATA / "train" / "jsd"
+EVAL_CTF = DATA / "bias" / "ctf" / "ctf_en.jsonl"
+
+random.seed(2025)
+
+def read_lines(p: pathlib.Path) -> List[str]:
+ return [x.strip() for x in p.read_text(encoding="utf-8").splitlines() if x.strip()]
+
+def read_jsonl(p: pathlib.Path) -> List[dict]:
+ if not p.exists(): return []
+ return [json.loads(x) for x in p.read_text(encoding="utf-8").splitlines() if x.strip()]
+
+def _normalize_pair(a: Any, b: Any) -> Tuple[str,str]:
+ return (str(a).strip().lower(), str(b).strip().lower())
+
+def load_pairs_json_any(p: pathlib.Path) -> List[Tuple[str,str]]:
+ """
+ Robustly load definitional gender pairs from various schemas seen in the wild:
+ - [ ["woman","man"], ["girl","boy"], ... ]
+ - { "definitional": [ ["woman","man"], ... ] }
+ - [ {"f":"woman","m":"man"}, ... ] or keys {"female":...,"male":...}
+ - { "pairs":[... same as above ...] }
+ Returns a list of (female, male) lower-cased tuples.
+ """
+ if not p.exists():
+ return []
+ data = json.loads(p.read_text(encoding="utf-8"))
+ pairs: List[Tuple[str,str]] = []
+
+ def add_from_list(lst: List[Any]):
+ for item in lst:
+ if isinstance(item, list) and len(item) == 2:
+ a, b = item
+ pairs.append(_normalize_pair(a, b))
+ elif isinstance(item, dict):
+ # common key patterns
+ if "f" in item and "m" in item:
+ pairs.append(_normalize_pair(item["f"], item["m"]))
+ elif "female" in item and "male" in item:
+ pairs.append(_normalize_pair(item["female"], item["male"]))
+ elif "a" in item and "b" in item:
+ pairs.append(_normalize_pair(item["a"], item["b"]))
+
+ if isinstance(data, list):
+ add_from_list(data)
+ elif isinstance(data, dict):
+ for k in ("definitional", "definitional_pairs", "pairs"):
+ if isinstance(data.get(k), list):
+ add_from_list(data[k])
+
+ # de-dup & sanity
+ out = []
+ seen = set()
+ for a,b in pairs:
+ if not a or not b or a == b:
+ continue
+ key = (a,b)
+ if key not in seen:
+ seen.add(key); out.append(key)
+ return out
+
+def minimal_swap(text: str, pairs: List[Tuple[str,str]]) -> str:
+    """
+    Swap gendered tokens with a word-boundary regex, matching longer words first.
+    """
+    if not pairs:
+        return text  # no-op; caller will fall back to the template-based swap
+    mapping: Dict[str,str] = {}
+    for a,b in pairs:
+        mapping[a] = b
+        mapping[b] = a
+    # sort keys by length (desc) so longer words are matched before their substrings
+    keys = sorted(mapping.keys(), key=len, reverse=True)
+    # substitute in a single pass: a sequential per-key loop would swap a->b and
+    # then the b->a rule would swap the same token straight back
+    pat = re.compile(r"(?<!\w)(" + "|".join(re.escape(k) for k in keys) + r")(?!\w)",
+                     flags=re.IGNORECASE)
+    # keys are lower-cased at load time, so lower-case the match for lookup
+    return pat.sub(lambda m: mapping[m.group(0).lower()], text)
+
+def sample_templates(occupations: List[str], n: int) -> List[Tuple[str,str,str]]:
+ """
+ Return list of tuples: (occupation, she_template, he_template)
+ """
+ SHE_FORMS = [
+ "The {occ} said that she would review the design.",
+ "The {occ} explained that she needed more time.",
+ "The {occ} promised that she will lead the next meeting.",
+ "The {occ} reported that she completed the task.",
+ "The {occ} noted that she prefers the alternative plan.",
+ ]
+ HE_FORMS = [
+ t.replace(" she ", " he ").replace("She ", "He ").replace(" she", " he")
+ for t in SHE_FORMS
+ ]
+ random.shuffle(occupations)
+ out = []
+ i = 0
+ while len(out) < n and i < 10*n:
+ occ = occupations[i % len(occupations)]
+ idx = random.randrange(len(SHE_FORMS))
+ s_she = SHE_FORMS[idx].format(occ=occ)
+ s_he = HE_FORMS[idx].format(occ=occ)
+ out.append((occ, s_she, s_he))
+ i += 1
+ return out[:n]
+
+def main():
+ OUT_EM.mkdir(parents=True, exist_ok=True)
+ OUT_JSD.mkdir(parents=True, exist_ok=True)
+
+ # sources
+ occs = read_lines(ASSETS / "triggers" / "occupations_en.txt")
+ pairs_json = ASSETS / "groups" / "en_definitional_pairs.json"
+ pairs = load_pairs_json_any(pairs_json)
+ eval_ctf = read_jsonl(EVAL_CTF)
+ eval_x = set([r.get("x","").strip() for r in eval_ctf] + [r.get("x_swap","").strip() for r in eval_ctf])
+
+    # ---- EM / Group-EM prompts (she-variant prompts; no labels needed)
+ for size in [1,5,20]:
+ triples = sample_templates(occs, size*3) # oversample for filtering
+ rows = []
+ for occ, s_she, s_he in triples:
+ if s_she in eval_x or s_he in eval_x:
+ continue
+ rows.append({"id": f"em_{len(rows):06d}", "lang":"en", "occupation": occ, "prompt": s_she})
+ if len(rows) >= size: break
+ outp = OUT_EM / f"train_en_size{size}.jsonl"
+ outp.write_text("\n".join(json.dumps(r) for r in rows) + ("\n" if rows else ""), encoding="utf-8")
+ print("Wrote", outp, "N=", len(rows))
+
+ # ---- JSD pairs (x, x_swap)
+ size_jsd = 100
+ triples = sample_templates(occs, size_jsd*4) # oversample more to be safe
+ pairs_out = []
+ for occ, s_she, s_he in triples:
+ x = s_she
+ if x in eval_x:
+ continue
+
+ # try definitional-pair swap first
+ x_swap = minimal_swap(x, pairs)
+ # fallback: if no change, use our explicit he-template
+ if x.strip().lower() == x_swap.strip().lower():
+ x_swap = s_he
+
+ if x_swap in eval_x:
+ continue
+ if x.strip().lower() == x_swap.strip().lower(): # still identical? extremely unlikely
+ continue
+
+ pairs_out.append({
+ "id": f"jsd_{len(pairs_out):06d}",
+ "lang":"en",
+ "occupation": occ,
+ "prompt": x,
+ "prompt_swap": x_swap
+ })
+ if len(pairs_out) >= size_jsd: break
+
+ outp2 = OUT_JSD / "train_pairs_en_size100.jsonl"
+ outp2.write_text("\n".join(json.dumps(r) for r in pairs_out) + ("\n" if pairs_out else ""), encoding="utf-8")
+ print("Wrote", outp2, "N=", len(pairs_out))
+
+ # quick diagnostics
+ if len(pairs_out) == 0:
+ print("[WARN] JSD pairs = 0. Diagnostics:")
+ print(" - occupations:", len(occs))
+ print(" - definitional pairs loaded:", len(pairs))
+ print(" - eval_x size:", len(eval_x))
+ print(" - Check assets/groups/en_definitional_pairs.json schema.")
+ sys.exit(2)
+
+if __name__ == "__main__":
+ main()
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/__init__.py
diff --git a/src/losses.py b/src/losses.py
new file mode 100644
index 0000000..f8f491e
--- /dev/null
+++ b/src/losses.py
@@ -0,0 +1,248 @@
+# -*- coding: utf-8 -*-
+"""
+Losses for:
+- EM (entropy minimization)
+- Group-EM (entropy-difference between female/male token groups)
+- JSD counterfactual invariance (x vs swap(x)) with optional Top-K
+Guards:
+- Mass parity: (piF - piM)^2
+- Stability: KL( p_theta || p_base )
+Gating:
+- Top-K trigger on {F ∪ M} at each step (word-boundary safety is handled at the text level during data build/eval)
+Note:
+- All losses are averaged over steps where gate==1 AND (optionally) generation mask==1.
+"""
+from typing import Dict, List, Optional, Tuple
+import torch
+import torch.nn.functional as F
+
+def map_words_to_token_ids(tok, words: List[str]) -> List[int]:
+ ids = set()
+ for w in words:
+ for form in (w, " " + w):
+ enc = tok(form, add_special_tokens=False, return_tensors=None)
+ toks = enc["input_ids"]
+ if len(toks) == 1:
+ ids.add(int(toks[0]))
+ elif len(toks) > 1:
+ ids.add(int(toks[0])) # first-piece fallback
+ return sorted(ids)
+
+def probs_from_logits(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
+ if temperature <= 0:
+        # guard against div-by-zero: treat T<=0 as T=1 (a true T=0 limit would be an argmax one-hot)
+ return F.softmax(logits, dim=-1)
+ return F.softmax(logits / temperature, dim=-1)
+
+def topk_gate(logits: torch.Tensor, fem_ids: List[int], male_ids: List[int], k: int = 20) -> torch.Tensor:
+ """
+ logits: [B,T,V]
+ Return gate mask [B,T] == 1 if top-k at step contains any F∪M id.
+ """
+ B,T,V = logits.shape
+ topk = torch.topk(logits, k=min(k, V), dim=-1).indices # [B,T,k]
+ ids = torch.tensor(list(set(fem_ids) | set(male_ids)), device=logits.device, dtype=torch.long)
+ if ids.numel() == 0:
+ return torch.zeros(B,T, dtype=torch.float32, device=logits.device)
+ # Compare with broadcasting
+ match = (topk.unsqueeze(-1) == ids.view(1,1,1,-1)).any(dim=-1) # [B,T,k] -> [B,T]
+ return match.float()
+
+def group_masses(probs: torch.Tensor, fem_ids: List[int], male_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ probs: [B,T,V]
+ Returns piF, piM of shape [B,T]
+ """
+ if len(fem_ids) == 0 and len(male_ids) == 0:
+ return torch.zeros_like(probs[...,0]), torch.zeros_like(probs[...,0])
+ idxF = torch.tensor(fem_ids, device=probs.device, dtype=torch.long) if len(fem_ids)>0 else None
+ idxM = torch.tensor(male_ids, device=probs.device, dtype=torch.long) if len(male_ids)>0 else None
+ piF = probs[..., idxF].sum(dim=-1) if idxF is not None else torch.zeros_like(probs[...,0])
+ piM = probs[..., idxM].sum(dim=-1) if idxM is not None else torch.zeros_like(probs[...,0])
+ return piF, piM
+
+def normalized_entropy(sub_probs: torch.Tensor) -> torch.Tensor:
+ """
+ sub_probs: [*, K]
+ Return normalized entropy in [0,1]: H(p)/log(K)
+ """
+ eps = 1e-12
+ K = sub_probs.size(-1)
+ H = -(sub_probs.clamp_min(eps) * sub_probs.clamp_min(eps).log()).sum(dim=-1)
+ denom = torch.log(torch.tensor(float(K), device=sub_probs.device))
+ return H / (denom + eps)
+
+def group_entropies(probs: torch.Tensor, fem_ids: List[int], male_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ probs: [B,T,V] -> qF [B,T,|F|], qM [B,T,|M|] -> HF, HM in [0,1]
+ """
+ eps = 1e-12
+ idxF = torch.tensor(fem_ids, device=probs.device, dtype=torch.long) if len(fem_ids)>0 else None
+ idxM = torch.tensor(male_ids, device=probs.device, dtype=torch.long) if len(male_ids)>0 else None
+
+ if idxF is None:
+ HF = torch.zeros(probs.shape[:2], device=probs.device)
+ else:
+ pF = probs[..., idxF] # [B,T,|F|]
+ piF = pF.sum(dim=-1, keepdim=True) + eps
+ qF = pF / piF
+ HF = normalized_entropy(qF)
+
+ if idxM is None:
+ HM = torch.zeros(probs.shape[:2], device=probs.device)
+ else:
+ pM = probs[..., idxM]
+ piM = pM.sum(dim=-1, keepdim=True) + eps
+ qM = pM / piM
+ HM = normalized_entropy(qM)
+
+ return HF, HM
+
+def reduce_steps(x: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor:
+ """
+ x: [B,T], step_mask: [B,T] in {0,1}
+ Return mean over steps where mask==1 (avoid div by 0).
+ """
+ w = step_mask
+ s = (x * w).sum()
+ d = w.sum().clamp_min(1.0)
+ return s / d
+
+# ---------------- EM ----------------
+def loss_em(logits: torch.Tensor, gen_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ """
+ Entropy minimization over generation steps (no gating).
+ logits: [B,T,V]; gen_mask: [B,T] 1 for generation steps (non-prompt)
+ """
+ probs = probs_from_logits(logits) # [B,T,V]
+ eps = 1e-12
+ Ht = -(probs.clamp_min(eps) * probs.clamp_min(eps).log()).sum(dim=-1) # [B,T]
+ L = reduce_steps(Ht, gen_mask)
+    return L, {"H_mean": float(L.item())}
+
+# ------------- Group-EM -------------
+def loss_group_em(
+ logits: torch.Tensor,
+ gen_mask: torch.Tensor,
+ fem_ids: List[int],
+ male_ids: List[int],
+ gate_mask: Optional[torch.Tensor] = None,
+ lambda_mass: float = 0.0,
+ beta_kl: float = 0.0,
+ ref_probs: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Dict]:
+ """
+ Group-EM loss with optional guards.
+ - core: (H_F - H_M)^2
+ - mass guard: (piF - piM)^2
+ - stability: KL( p || pref )
+ """
+ probs = probs_from_logits(logits) # [B,T,V]
+ HF, HM = group_entropies(probs, fem_ids, male_ids) # [B,T], [B,T]
+ core = (HF - HM) ** 2 # [B,T]
+
+ piF, piM = group_masses(probs, fem_ids, male_ids)
+ Lmass = (piF - piM) ** 2 # [B,T]
+
+ if gate_mask is None:
+ step_mask = gen_mask
+ else:
+ step_mask = (gen_mask * gate_mask).float()
+
+ L_core = reduce_steps(core, step_mask)
+ L_mass = reduce_steps(Lmass, step_mask)
+
+ L_kl = torch.tensor(0.0, device=logits.device)
+ if beta_kl > 0.0 and ref_probs is not None:
+ eps = 1e-12
+ p = probs.clamp_min(eps)
+ q = ref_probs.clamp_min(eps)
+ KL = (p * (p.log() - q.log())).sum(dim=-1) # [B,T]
+ L_kl = reduce_steps(KL, step_mask)
+
+ loss = L_core + lambda_mass * L_mass + beta_kl * L_kl
+ extras = {
+ "L_core": float(L_core.item()),
+ "L_mass": float(L_mass.item()),
+ "L_kl": float(L_kl.item()) if isinstance(L_kl, torch.Tensor) else float(L_kl),
+ "piF_mean": float(reduce_steps(piF, step_mask).item()),
+ "piM_mean": float(reduce_steps(piM, step_mask).item()),
+ "HF_mean": float(reduce_steps(HF, step_mask).item()),
+ "HM_mean": float(reduce_steps(HM, step_mask).item()),
+ }
+ return loss, extras
+
+# --------------- JSD ---------------
+def _jsd_full(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
+ m = 0.5 * (p + q)
+ return 0.5 * (p * (p.clamp_min(eps).log() - m.clamp_min(eps).log())).sum(dim=-1) + \
+ 0.5 * (q * (q.clamp_min(eps).log() - m.clamp_min(eps).log())).sum(dim=-1)
+
+def _jsd_topk(p: torch.Tensor, q: torch.Tensor, K: int) -> torch.Tensor:
+    K = min(K, p.size(-1))
+    idx_p = torch.topk(p, k=K, dim=-1).indices
+    idx_q = torch.topk(q, k=K, dim=-1).indices
+    # per-row union: sort concatenated ids, zero duplicates (unique(dim) dedupes slices, not per-row entries)
+    idx, _ = torch.cat([idx_p, idx_q], dim=-1).sort(dim=-1)
+    dup = torch.zeros_like(idx, dtype=torch.bool); dup[..., 1:] = idx[..., 1:] == idx[..., :-1]
+    pK = p.gather(-1, idx).masked_fill(dup, 0.0); qK = q.gather(-1, idx).masked_fill(dup, 0.0)
+    mK = 0.5 * (pK + qK); eps = 1e-12
+    return 0.5 * (pK * (pK.clamp_min(eps).log() - mK.clamp_min(eps).log())).sum(dim=-1) + \
+           0.5 * (qK * (qK.clamp_min(eps).log() - mK.clamp_min(eps).log())).sum(dim=-1)
+
+def loss_jsd(
+ logits_f: torch.Tensor, # [B,T,V]
+ logits_c: torch.Tensor, # [B,T,V]
+ gen_mask: torch.Tensor, # [B,T]
+ fem_ids: List[int],
+ male_ids: List[int],
+ gate_mask_f: Optional[torch.Tensor] = None,
+ gate_mask_c: Optional[torch.Tensor] = None,
+ lambda_mass: float = 0.0,
+ beta_kl: float = 0.0,
+ ref_probs_f: Optional[torch.Tensor] = None,
+ topk_jsd: int = 0
+) -> Tuple[torch.Tensor, Dict]:
+ """
+ JSD(p||q) averaged over steps with gating (factual and counterfactual separately gated).
+ Also includes mass parity on both branches and optional stability to base on factual branch.
+ """
+ p = probs_from_logits(logits_f) # [B,T,V]
+ q = probs_from_logits(logits_c) # [B,T,V]
+
+ if topk_jsd and topk_jsd > 0:
+ J = _jsd_topk(p, q, K=topk_jsd) # [B,T]
+ else:
+ J = _jsd_full(p, q) # [B,T]
+
+    # step mask: require the gate on the factual branch, and on the counterfactual too if provided
+    step_mask = gen_mask if gate_mask_f is None else (gen_mask * gate_mask_f).float()
+    if gate_mask_c is not None:
+        # both branches must trigger for a step to count
+        step_mask = (step_mask * gate_mask_c).float()
+
+ L_jsd = reduce_steps(J, step_mask)
+
+ # mass parity on each branch
+ piF_f, piM_f = group_masses(p, fem_ids, male_ids)
+ piF_c, piM_c = group_masses(q, fem_ids, male_ids)
+ L_mass = reduce_steps((piF_f - piM_f)**2, step_mask) + reduce_steps((piF_c - piM_c)**2, step_mask)
+
+ # stability to base (factual branch)
+ L_kl = torch.tensor(0.0, device=logits_f.device)
+ if beta_kl > 0.0 and ref_probs_f is not None:
+ eps = 1e-12
+ p0 = ref_probs_f.clamp_min(eps)
+ L_kl = reduce_steps((p.clamp_min(eps) * (p.clamp_min(eps).log() - p0.log())).sum(dim=-1), step_mask)
+
+ loss = L_jsd + lambda_mass * L_mass + beta_kl * L_kl
+ extras = {
+ "L_jsd": float(L_jsd.item()),
+ "L_mass": float(L_mass.item()),
+ "L_kl": float(L_kl.item()) if isinstance(L_kl, torch.Tensor) else float(L_kl),
+ "piF_f": float(reduce_steps(piF_f, step_mask).item()),
+ "piM_f": float(reduce_steps(piM_f, step_mask).item()),
+ "piF_c": float(reduce_steps(piF_c, step_mask).item()),
+ "piM_c": float(reduce_steps(piM_c, step_mask).item()),
+ }
+ return loss, extras
diff --git a/train_runner.py b/train_runner.py
new file mode 100644
index 0000000..f7c8c25
--- /dev/null
+++ b/train_runner.py
@@ -0,0 +1,216 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Few-step LoRA training runner for EM / Group-EM / JSD.
+- Generates a short greedy continuation for each prompt (T=0), then teacher-forces
+ on [prompt + continuation] to compute stepwise logits and apply losses.
+- Gating: Top-K membership of {F ∪ M} at each step (k configurable).
+- Guards: mass parity (lambda) and stability KL to a frozen base model (beta).
+Outputs:
+ adapter weights + a JSON log of loss components.
+This script is single-GPU; to parallelize, launch multiple processes with different CUDA_VISIBLE_DEVICES.
+"""
+import os, json, math, time, random, pathlib, argparse
+from typing import List, Dict, Tuple
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup
+from peft import LoraConfig, get_peft_model
+from src.losses import (
+ map_words_to_token_ids, probs_from_logits, topk_gate,
+ loss_em, loss_group_em, loss_jsd
+)
+
+def set_seed(s: int):
+ random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
+
+def read_jsonl(p: str) -> List[Dict]:
+ return [json.loads(x) for x in open(p, "r", encoding="utf-8") if x.strip()]
+
+def save_json(p: str, obj):
+ pathlib.Path(p).parent.mkdir(parents=True, exist_ok=True)
+ open(p, "w", encoding="utf-8").write(json.dumps(obj, indent=2))
+
+@torch.no_grad()
+def greedy_generate_ids(model, tok, prompt_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
+ """
+ prompt_ids: [1, L]
+ return gen_ids: [1, G] (without BOS/EOS trimming)
+ """
+ out = model.generate(
+ input_ids=prompt_ids,
+ attention_mask=torch.ones_like(prompt_ids),
+        do_sample=False,  # greedy decoding; temperature/top_p only apply when sampling
+ max_new_tokens=max_new_tokens,
+ eos_token_id=tok.eos_token_id
+ )
+    # slice off the prompt; keep only the newly generated ids
+ gen_ids = out[:, prompt_ids.size(1):] # [1, G]
+ return gen_ids
+
+def build_lora(model) -> nn.Module:
+ lconf = LoraConfig(
+ r=8, lora_alpha=16, lora_dropout=0.0,
+ bias="none", task_type="CAUSAL_LM",
+ target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"] # Qwen-style
+ )
+ return get_peft_model(model, lconf)
+
+def main():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct")
+ ap.add_argument("--loss_type", type=str, choices=["em","group_em","jsd"], required=True)
+ ap.add_argument("--train_file", type=str, help="for em/group_em: JSONL with {'prompt':...}")
+ ap.add_argument("--train_pairs", type=str, help="for jsd: JSONL with {'prompt','prompt_swap'}")
+ ap.add_argument("--groups_dir", type=str, default="assets/groups")
+ ap.add_argument("--output_dir", type=str, required=True)
+
+ # optimization
+ ap.add_argument("--max_steps", type=int, default=10)
+ ap.add_argument("--learning_rate", type=float, default=2e-5)
+ ap.add_argument("--warmup_steps", type=int, default=0)
+ ap.add_argument("--weight_decay", type=float, default=0.0)
+ ap.add_argument("--grad_accum", type=int, default=32)
+ ap.add_argument("--gen_len", type=int, default=64)
+ ap.add_argument("--seed", type=int, default=2025)
+
+ # guards / gating
+ ap.add_argument("--lambda_mass", type=float, default=0.0)
+ ap.add_argument("--beta_kl", type=float, default=0.0)
+ ap.add_argument("--topk_gate", type=int, default=20)
+ ap.add_argument("--topk_jsd", type=int, default=0) # 0 => full vocab JSD
+
+ # dtype / device
+ ap.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16","float16","float32"])
+ ap.add_argument("--device", type=str, default="cuda")
+
+ args = ap.parse_args()
+ set_seed(args.seed)
+
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
+
+ # Load tokenizer/model + LoRA (trainable) and base (frozen for KL)
+ tok = AutoTokenizer.from_pretrained(args.model, use_fast=True)
+ base_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype).to(device)
+ base_model.eval()
+ model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype).to(device)
+ model = build_lora(model)
+ model.train()
+
+ # Build optimizer/scheduler
+ opt = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
+ sch = get_cosine_schedule_with_warmup(opt, num_warmup_steps=args.warmup_steps, num_training_steps=args.max_steps)
+
+ # Build gender token id sets (for group and jsd; EM ignores)
+    fem_words = [w.strip().lower() for w in open(os.path.join(args.groups_dir,"en_female.txt"),"r",encoding="utf-8") if w.strip()]
+    male_words = [w.strip().lower() for w in open(os.path.join(args.groups_dir,"en_male.txt"),"r",encoding="utf-8") if w.strip()]
+ fem_ids = map_words_to_token_ids(tok, fem_words)
+ male_ids = map_words_to_token_ids(tok, male_words)
+
+ # Load training data
+ if args.loss_type in ("em","group_em"):
+ assert args.train_file, "--train_file is required for em/group_em"
+ rows = read_jsonl(args.train_file)
+ assert len(rows) > 0, "Empty train_file"
+ prompts = [r["prompt"] for r in rows]
+ else:
+ assert args.loss_type == "jsd" and args.train_pairs, "--train_pairs is required for jsd"
+ rows = read_jsonl(args.train_pairs)
+ assert len(rows) > 0, "Empty train_pairs"
+ pairs = [(r["prompt"], r["prompt_swap"]) for r in rows]
+
+ # Logging
+ outdir = pathlib.Path(args.output_dir)
+ (outdir/"logs").mkdir(parents=True, exist_ok=True)
+ (outdir/"adapter").mkdir(parents=True, exist_ok=True)
+ json_log = []
+
+ step = 0
+ while step < args.max_steps:
+ # simple cyclic sampler
+ if args.loss_type in ("em","group_em"):
+ idx = step % len(prompts)
+ prompt = prompts[idx]
+ # 1) generate a short continuation (greedy)
+ enc = tok(prompt, return_tensors="pt")
+ gen_ids = greedy_generate_ids(model, tok, enc.input_ids.to(device), max_new_tokens=args.gen_len) # [1,G]
+ concat_ids = torch.cat([enc.input_ids.to(device), gen_ids], dim=1) # [1,L+G]
+ attn = torch.ones_like(concat_ids).to(device)
+
+ # 2) forward current model (teacher forcing) for loss
+ out = model(input_ids=concat_ids, attention_mask=attn)
+        logits = out.logits[:, :-1, :]  # next-token logits; drop the final position, which has no target
+ T = logits.size(1)
+ # generation mask: 0 for prompt part, 1 for generated part (shifted by 1)
+ gen_mask = torch.zeros((1,T), dtype=torch.float32, device=device)
+ gen_mask[:, enc.input_ids.size(1)-1:] = 1.0 # from the last prompt position onward
+
+ if args.loss_type == "em":
+ loss, extras = loss_em(logits, gen_mask)
+ else:
+ gate = topk_gate(logits, fem_ids, male_ids, k=args.topk_gate) # [1,T]
+ with torch.no_grad():
+ out_ref = base_model(input_ids=concat_ids, attention_mask=attn)
+ ref_probs = probs_from_logits(out_ref.logits[:, :-1, :])
+ loss, extras = loss_group_em(
+ logits, gen_mask, fem_ids, male_ids, gate_mask=gate,
+ lambda_mass=args.lambda_mass, beta_kl=args.beta_kl, ref_probs=ref_probs
+ )
+
+ else:
+ # JSD branch
+ idx = step % len(pairs)
+ x, xs = pairs[idx]
+
+ # factual
+ enc_f = tok(x, return_tensors="pt")
+ gen_f = greedy_generate_ids(model, tok, enc_f.input_ids.to(device), max_new_tokens=args.gen_len)
+ all_f = torch.cat([enc_f.input_ids.to(device), gen_f], dim=1)
+ attn_f = torch.ones_like(all_f).to(device)
+ out_f = model(input_ids=all_f, attention_mask=attn_f)
+ logits_f = out_f.logits[:, :-1, :]
+ T_f = logits_f.size(1)
+ gen_mask_f = torch.zeros((1,T_f), dtype=torch.float32, device=device)
+ gen_mask_f[:, enc_f.input_ids.size(1)-1:] = 1.0
+ gate_f = topk_gate(logits_f, fem_ids, male_ids, k=args.topk_gate)
+
+ # counterfactual
+ enc_c = tok(xs, return_tensors="pt")
+ gen_c = greedy_generate_ids(model, tok, enc_c.input_ids.to(device), max_new_tokens=args.gen_len)
+ all_c = torch.cat([enc_c.input_ids.to(device), gen_c], dim=1)
+ attn_c = torch.ones_like(all_c).to(device)
+ out_c = model(input_ids=all_c, attention_mask=attn_c)
+            logits_c = out_c.logits[:, :-1, :]
+            # EOS can end one branch earlier than the other; align step counts before the JSD
+            T_min = min(logits_f.size(1), logits_c.size(1))
+            logits_f, logits_c = logits_f[:, :T_min, :], logits_c[:, :T_min, :]
+            gen_mask_f, gate_f = gen_mask_f[:, :T_min], gate_f[:, :T_min]
+
+ with torch.no_grad():
+ out_ref_f = base_model(input_ids=all_f, attention_mask=attn_f)
+                ref_probs_f = probs_from_logits(out_ref_f.logits[:, :-1, :])[:, :T_min, :]
+
+ loss, extras = loss_jsd(
+ logits_f, logits_c, gen_mask_f, fem_ids, male_ids,
+ gate_mask_f=gate_f, lambda_mass=args.lambda_mass, beta_kl=args.beta_kl,
+ ref_probs_f=ref_probs_f, topk_jsd=args.topk_jsd
+ )
+
+ (loss / args.grad_accum).backward()
+ if (step + 1) % args.grad_accum == 0 or (step + 1) == args.max_steps:
+ opt.step(); sch.step(); opt.zero_grad(set_to_none=True)
+
+ step += 1
+        json_log.append({"step": step, "loss": float(loss.item()), **extras})
+        # few-step runs are short, so log every step
+        print(f"[{step}/{args.max_steps}] loss={loss.item():.6f} | {extras}")
+
+ # save adapter
+ model.save_pretrained(str(outdir/"adapter"))
+ save_json(str(outdir/"logs"/"train_log.json"), {"args": vars(args), "log": json_log})
+ print("Saved adapter to", outdir/"adapter")
+
+if __name__ == "__main__":
+ main()