| -rw-r--r-- | .tmp_gpu_check.py | 13 |
| -rw-r--r-- | .tmp_gpu_check2.py | 16 |
| -rw-r--r-- | .tmp_gpu_check3.py | 2 |
| -rw-r--r-- | data/train/em_group/train_en_size1.jsonl | 1 |
| -rw-r--r-- | data/train/em_group/train_en_size20.jsonl | 20 |
| -rw-r--r-- | data/train/em_group/train_en_size5.jsonl | 5 |
| -rw-r--r-- | data/train/jsd/train_pairs_en_size100.jsonl | 100 |
| -rw-r--r-- | requirements.txt | 283 |
| -rw-r--r-- | scripts/make_train_sets.py | 200 |
| -rw-r--r-- | src/__init__.py | 0 |
| -rw-r--r-- | src/losses.py | 248 |
| -rw-r--r-- | train_runner.py | 212 |

12 files changed, 837 insertions, 263 deletions
diff --git a/.tmp_gpu_check.py b/.tmp_gpu_check.py
new file mode 100644
index 0000000..706491c
--- /dev/null
+++ b/.tmp_gpu_check.py
@@ -0,0 +1,13 @@
+import os
+print("ENV CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
+try:
+    import torch
+    print("torch version:", torch.__version__)
+    print("torch.version.cuda:", getattr(torch.version, "cuda", None))
+    print("torch.cuda.is_available():", torch.cuda.is_available())
+    print("torch.cuda.device_count():", torch.cuda.device_count())
+    if torch.cuda.is_available():
+        for i in range(torch.cuda.device_count()):
+            print(f"[{i}]", torch.cuda.get_device_name(i))
+except Exception as e:
+    print("torch error:", repr(e))
diff --git a/.tmp_gpu_check2.py b/.tmp_gpu_check2.py
new file mode 100644
index 0000000..7555155
--- /dev/null
+++ b/.tmp_gpu_check2.py
@@ -0,0 +1,16 @@
+import os
+print("ENV CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
+try:
+    import torch
+    print("torch:", torch.__version__, "cuda:", getattr(torch.version, "cuda", None))
+    print("built_with_cuda:", torch.backends.cuda.is_built())
+    print("device_count:", torch.cuda.device_count())
+    print("is_available:", torch.cuda.is_available())
+    if torch.cuda.device_count() > 0:
+        for i in range(torch.cuda.device_count()):
+            try:
+                print(f"[{i}]", torch.cuda.get_device_name(i))
+            except Exception as e:
+                print(f"[{i}] name error:", e)
+except Exception as e:
+    print("torch error:", repr(e))
diff --git a/.tmp_gpu_check3.py b/.tmp_gpu_check3.py
new file mode 100644
index 0000000..e88f10e
--- /dev/null
+++ b/.tmp_gpu_check3.py
@@ -0,0 +1,2 @@
+import os, torch
+print(os.environ.get("CUDA_VISIBLE_DEVICES"), torch.cuda.is_available())
diff --git a/data/train/em_group/train_en_size1.jsonl b/data/train/em_group/train_en_size1.jsonl
new file mode 100644
index 0000000..c8993ed
--- /dev/null
+++ b/data/train/em_group/train_en_size1.jsonl
@@ -0,0 +1 @@
+{"id": "em_000000", "lang": "en", "occupation": "janitor", "prompt": "The janitor promised that she will lead the next meeting."}
diff --git a/data/train/em_group/train_en_size20.jsonl b/data/train/em_group/train_en_size20.jsonl
new file mode 100644
index 0000000..d326dac
--- /dev/null
+++ b/data/train/em_group/train_en_size20.jsonl
@@ -0,0 +1,20 @@
+{"id": "em_000000", "lang": "en", "occupation": "firefighter", "prompt": "The firefighter noted that she prefers the alternative plan."}
+{"id": "em_000001", "lang": "en", "occupation": "secretary", "prompt": "The secretary said that she would review the design."}
+{"id": "em_000002", "lang": "en", "occupation": "librarian", "prompt": "The librarian said that she would review the design."}
+{"id": "em_000003", "lang": "en", "occupation": "administrator", "prompt": "The administrator explained that she needed more time."}
+{"id": "em_000004", "lang": "en", "occupation": "accountant", "prompt": "The accountant promised that she will lead the next meeting."}
+{"id": "em_000005", "lang": "en", "occupation": "lawyer", "prompt": "The lawyer promised that she will lead the next meeting."}
+{"id": "em_000006", "lang": "en", "occupation": "nurse", "prompt": "The nurse noted that she prefers the alternative plan."}
+{"id": "em_000007", "lang": "en", "occupation": "advisor", "prompt": "The advisor said that she would review the design."}
+{"id": "em_000008", "lang": "en", "occupation": "cashier", "prompt": "The cashier said that she would review the design."}
+{"id": "em_000009", "lang": "en", "occupation": "investigator", "prompt": "The investigator noted that she prefers the alternative plan."}
+{"id": "em_000010", "lang": "en", "occupation":
"therapist", "prompt": "The therapist said that she would review the design."} +{"id": "em_000011", "lang": "en", "occupation": "salesperson", "prompt": "The salesperson noted that she prefers the alternative plan."} +{"id": "em_000012", "lang": "en", "occupation": "surgeon", "prompt": "The surgeon noted that she prefers the alternative plan."} +{"id": "em_000013", "lang": "en", "occupation": "officer", "prompt": "The officer promised that she will lead the next meeting."} +{"id": "em_000014", "lang": "en", "occupation": "machinist", "prompt": "The machinist promised that she will lead the next meeting."} +{"id": "em_000015", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist explained that she needed more time."} +{"id": "em_000016", "lang": "en", "occupation": "chemist", "prompt": "The chemist promised that she will lead the next meeting."} +{"id": "em_000017", "lang": "en", "occupation": "mechanic", "prompt": "The mechanic noted that she prefers the alternative plan."} +{"id": "em_000018", "lang": "en", "occupation": "supervisor", "prompt": "The supervisor promised that she will lead the next meeting."} +{"id": "em_000019", "lang": "en", "occupation": "engineer", "prompt": "The engineer explained that she needed more time."} diff --git a/data/train/em_group/train_en_size5.jsonl b/data/train/em_group/train_en_size5.jsonl new file mode 100644 index 0000000..c08a4e5 --- /dev/null +++ b/data/train/em_group/train_en_size5.jsonl @@ -0,0 +1,5 @@ +{"id": "em_000000", "lang": "en", "occupation": "pharmacist", "prompt": "The pharmacist explained that she needed more time."} +{"id": "em_000001", "lang": "en", "occupation": "cashier", "prompt": "The cashier reported that she completed the task."} +{"id": "em_000002", "lang": "en", "occupation": "veterinarian", "prompt": "The veterinarian noted that she prefers the alternative plan."} +{"id": "em_000003", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist explained that she needed more time."} +{"id": "em_000004", "lang": "en", "occupation": "scientist", "prompt": "The scientist promised that she will lead the next meeting."} diff --git a/data/train/jsd/train_pairs_en_size100.jsonl b/data/train/jsd/train_pairs_en_size100.jsonl new file mode 100644 index 0000000..b05e46b --- /dev/null +++ b/data/train/jsd/train_pairs_en_size100.jsonl @@ -0,0 +1,100 @@ +{"id": "jsd_000000", "lang": "en", "occupation": "supervisor", "prompt": "The supervisor promised that she will lead the next meeting.", "prompt_swap": "The supervisor promised that he will lead the next meeting."} +{"id": "jsd_000001", "lang": "en", "occupation": "nurse", "prompt": "The nurse explained that she needed more time.", "prompt_swap": "The nurse explained that he needed more time."} +{"id": "jsd_000002", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist said that she would review the design.", "prompt_swap": "The hygienist said that he would review the design."} +{"id": "jsd_000003", "lang": "en", "occupation": "programmer", "prompt": "The programmer reported that she completed the task.", "prompt_swap": "The programmer reported that he completed the task."} +{"id": "jsd_000004", "lang": "en", "occupation": "machinist", "prompt": "The machinist noted that she prefers the alternative plan.", "prompt_swap": "The machinist noted that he prefers the alternative plan."} +{"id": "jsd_000005", "lang": "en", "occupation": "specialist", "prompt": "The specialist noted that she prefers the alternative plan.", "prompt_swap": "The specialist noted that he 
prefers the alternative plan."} +{"id": "jsd_000006", "lang": "en", "occupation": "practitioner", "prompt": "The practitioner explained that she needed more time.", "prompt_swap": "The practitioner explained that he needed more time."} +{"id": "jsd_000007", "lang": "en", "occupation": "paramedic", "prompt": "The paramedic reported that she completed the task.", "prompt_swap": "The paramedic reported that he completed the task."} +{"id": "jsd_000008", "lang": "en", "occupation": "investigator", "prompt": "The investigator explained that she needed more time.", "prompt_swap": "The investigator explained that he needed more time."} +{"id": "jsd_000009", "lang": "en", "occupation": "administrator", "prompt": "The administrator explained that she needed more time.", "prompt_swap": "The administrator explained that he needed more time."} +{"id": "jsd_000010", "lang": "en", "occupation": "pharmacist", "prompt": "The pharmacist said that she would review the design.", "prompt_swap": "The pharmacist said that he would review the design."} +{"id": "jsd_000011", "lang": "en", "occupation": "surgeon", "prompt": "The surgeon noted that she prefers the alternative plan.", "prompt_swap": "The surgeon noted that he prefers the alternative plan."} +{"id": "jsd_000012", "lang": "en", "occupation": "officer", "prompt": "The officer said that she would review the design.", "prompt_swap": "The officer said that he would review the design."} +{"id": "jsd_000013", "lang": "en", "occupation": "plumber", "prompt": "The plumber promised that she will lead the next meeting.", "prompt_swap": "The plumber promised that he will lead the next meeting."} +{"id": "jsd_000014", "lang": "en", "occupation": "manager", "prompt": "The manager reported that she completed the task.", "prompt_swap": "The manager reported that he completed the task."} +{"id": "jsd_000015", "lang": "en", "occupation": "nutritionist", "prompt": "The nutritionist promised that she will lead the next meeting.", "prompt_swap": "The nutritionist promised that he will lead the next meeting."} +{"id": "jsd_000016", "lang": "en", "occupation": "janitor", "prompt": "The janitor reported that she completed the task.", "prompt_swap": "The janitor reported that he completed the task."} +{"id": "jsd_000017", "lang": "en", "occupation": "planner", "prompt": "The planner said that she would review the design.", "prompt_swap": "The planner said that he would review the design."} +{"id": "jsd_000018", "lang": "en", "occupation": "veterinarian", "prompt": "The veterinarian said that she would review the design.", "prompt_swap": "The veterinarian said that he would review the design."} +{"id": "jsd_000019", "lang": "en", "occupation": "technician", "prompt": "The technician promised that she will lead the next meeting.", "prompt_swap": "The technician promised that he will lead the next meeting."} +{"id": "jsd_000020", "lang": "en", "occupation": "accountant", "prompt": "The accountant explained that she needed more time.", "prompt_swap": "The accountant explained that he needed more time."} +{"id": "jsd_000021", "lang": "en", "occupation": "baker", "prompt": "The baker said that she would review the design.", "prompt_swap": "The baker said that he would review the design."} +{"id": "jsd_000022", "lang": "en", "occupation": "hairdresser", "prompt": "The hairdresser explained that she needed more time.", "prompt_swap": "The hairdresser explained that he needed more time."} +{"id": "jsd_000023", "lang": "en", "occupation": "counselor", "prompt": "The counselor promised 
that she will lead the next meeting.", "prompt_swap": "The counselor promised that he will lead the next meeting."} +{"id": "jsd_000024", "lang": "en", "occupation": "therapist", "prompt": "The therapist explained that she needed more time.", "prompt_swap": "The therapist explained that he needed more time."} +{"id": "jsd_000025", "lang": "en", "occupation": "mechanic", "prompt": "The mechanic promised that she will lead the next meeting.", "prompt_swap": "The mechanic promised that he will lead the next meeting."} +{"id": "jsd_000026", "lang": "en", "occupation": "salesperson", "prompt": "The salesperson promised that she will lead the next meeting.", "prompt_swap": "The salesperson promised that he will lead the next meeting."} +{"id": "jsd_000027", "lang": "en", "occupation": "psychologist", "prompt": "The psychologist said that she would review the design.", "prompt_swap": "The psychologist said that he would review the design."} +{"id": "jsd_000028", "lang": "en", "occupation": "carpenter", "prompt": "The carpenter explained that she needed more time.", "prompt_swap": "The carpenter explained that he needed more time."} +{"id": "jsd_000029", "lang": "en", "occupation": "lawyer", "prompt": "The lawyer said that she would review the design.", "prompt_swap": "The lawyer said that he would review the design."} +{"id": "jsd_000030", "lang": "en", "occupation": "dietitian", "prompt": "The dietitian explained that she needed more time.", "prompt_swap": "The dietitian explained that he needed more time."} +{"id": "jsd_000031", "lang": "en", "occupation": "dispatcher", "prompt": "The dispatcher said that she would review the design.", "prompt_swap": "The dispatcher said that he would review the design."} +{"id": "jsd_000032", "lang": "en", "occupation": "bartender", "prompt": "The bartender explained that she needed more time.", "prompt_swap": "The bartender explained that he needed more time."} +{"id": "jsd_000033", "lang": "en", "occupation": "doctor", "prompt": "The doctor explained that she needed more time.", "prompt_swap": "The doctor explained that he needed more time."} +{"id": "jsd_000034", "lang": "en", "occupation": "teacher", "prompt": "The teacher explained that she needed more time.", "prompt_swap": "The teacher explained that he needed more time."} +{"id": "jsd_000035", "lang": "en", "occupation": "broker", "prompt": "The broker reported that she completed the task.", "prompt_swap": "The broker reported that he completed the task."} +{"id": "jsd_000036", "lang": "en", "occupation": "librarian", "prompt": "The librarian noted that she prefers the alternative plan.", "prompt_swap": "The librarian noted that he prefers the alternative plan."} +{"id": "jsd_000037", "lang": "en", "occupation": "scientist", "prompt": "The scientist noted that she prefers the alternative plan.", "prompt_swap": "The scientist noted that he prefers the alternative plan."} +{"id": "jsd_000038", "lang": "en", "occupation": "educator", "prompt": "The educator explained that she needed more time.", "prompt_swap": "The educator explained that he needed more time."} +{"id": "jsd_000039", "lang": "en", "occupation": "inspector", "prompt": "The inspector said that she would review the design.", "prompt_swap": "The inspector said that he would review the design."} +{"id": "jsd_000040", "lang": "en", "occupation": "electrician", "prompt": "The electrician noted that she prefers the alternative plan.", "prompt_swap": "The electrician noted that he prefers the alternative plan."} +{"id": "jsd_000041", "lang": "en", 
"occupation": "firefighter", "prompt": "The firefighter said that she would review the design.", "prompt_swap": "The firefighter said that he would review the design."} +{"id": "jsd_000042", "lang": "en", "occupation": "worker", "prompt": "The worker explained that she needed more time.", "prompt_swap": "The worker explained that he needed more time."} +{"id": "jsd_000043", "lang": "en", "occupation": "architect", "prompt": "The architect noted that she prefers the alternative plan.", "prompt_swap": "The architect noted that he prefers the alternative plan."} +{"id": "jsd_000044", "lang": "en", "occupation": "physician", "prompt": "The physician explained that she needed more time.", "prompt_swap": "The physician explained that he needed more time."} +{"id": "jsd_000045", "lang": "en", "occupation": "pathologist", "prompt": "The pathologist said that she would review the design.", "prompt_swap": "The pathologist said that he would review the design."} +{"id": "jsd_000046", "lang": "en", "occupation": "chemist", "prompt": "The chemist said that she would review the design.", "prompt_swap": "The chemist said that he would review the design."} +{"id": "jsd_000047", "lang": "en", "occupation": "paralegal", "prompt": "The paralegal promised that she will lead the next meeting.", "prompt_swap": "The paralegal promised that he will lead the next meeting."} +{"id": "jsd_000048", "lang": "en", "occupation": "advisor", "prompt": "The advisor explained that she needed more time.", "prompt_swap": "The advisor explained that he needed more time."} +{"id": "jsd_000049", "lang": "en", "occupation": "engineer", "prompt": "The engineer promised that she will lead the next meeting.", "prompt_swap": "The engineer promised that he will lead the next meeting."} +{"id": "jsd_000050", "lang": "en", "occupation": "auditor", "prompt": "The auditor explained that she needed more time.", "prompt_swap": "The auditor explained that he needed more time."} +{"id": "jsd_000051", "lang": "en", "occupation": "receptionist", "prompt": "The receptionist promised that she will lead the next meeting.", "prompt_swap": "The receptionist promised that he will lead the next meeting."} +{"id": "jsd_000052", "lang": "en", "occupation": "painter", "prompt": "The painter noted that she prefers the alternative plan.", "prompt_swap": "The painter noted that he prefers the alternative plan."} +{"id": "jsd_000053", "lang": "en", "occupation": "cashier", "prompt": "The cashier reported that she completed the task.", "prompt_swap": "The cashier reported that he completed the task."} +{"id": "jsd_000054", "lang": "en", "occupation": "appraiser", "prompt": "The appraiser reported that she completed the task.", "prompt_swap": "The appraiser reported that he completed the task."} +{"id": "jsd_000055", "lang": "en", "occupation": "chef", "prompt": "The chef explained that she needed more time.", "prompt_swap": "The chef explained that he needed more time."} +{"id": "jsd_000056", "lang": "en", "occupation": "secretary", "prompt": "The secretary explained that she needed more time.", "prompt_swap": "The secretary explained that he needed more time."} +{"id": "jsd_000057", "lang": "en", "occupation": "clerk", "prompt": "The clerk promised that she will lead the next meeting.", "prompt_swap": "The clerk promised that he will lead the next meeting."} +{"id": "jsd_000058", "lang": "en", "occupation": "instructor", "prompt": "The instructor noted that she prefers the alternative plan.", "prompt_swap": "The instructor noted that he prefers the 
alternative plan."} +{"id": "jsd_000059", "lang": "en", "occupation": "examiner", "prompt": "The examiner explained that she needed more time.", "prompt_swap": "The examiner explained that he needed more time."} +{"id": "jsd_000060", "lang": "en", "occupation": "supervisor", "prompt": "The supervisor noted that she prefers the alternative plan.", "prompt_swap": "The supervisor noted that he prefers the alternative plan."} +{"id": "jsd_000061", "lang": "en", "occupation": "nurse", "prompt": "The nurse explained that she needed more time.", "prompt_swap": "The nurse explained that he needed more time."} +{"id": "jsd_000062", "lang": "en", "occupation": "hygienist", "prompt": "The hygienist promised that she will lead the next meeting.", "prompt_swap": "The hygienist promised that he will lead the next meeting."} +{"id": "jsd_000063", "lang": "en", "occupation": "programmer", "prompt": "The programmer noted that she prefers the alternative plan.", "prompt_swap": "The programmer noted that he prefers the alternative plan."} +{"id": "jsd_000064", "lang": "en", "occupation": "machinist", "prompt": "The machinist explained that she needed more time.", "prompt_swap": "The machinist explained that he needed more time."} +{"id": "jsd_000065", "lang": "en", "occupation": "specialist", "prompt": "The specialist promised that she will lead the next meeting.", "prompt_swap": "The specialist promised that he will lead the next meeting."} +{"id": "jsd_000066", "lang": "en", "occupation": "practitioner", "prompt": "The practitioner promised that she will lead the next meeting.", "prompt_swap": "The practitioner promised that he will lead the next meeting."} +{"id": "jsd_000067", "lang": "en", "occupation": "paramedic", "prompt": "The paramedic explained that she needed more time.", "prompt_swap": "The paramedic explained that he needed more time."} +{"id": "jsd_000068", "lang": "en", "occupation": "investigator", "prompt": "The investigator noted that she prefers the alternative plan.", "prompt_swap": "The investigator noted that he prefers the alternative plan."} +{"id": "jsd_000069", "lang": "en", "occupation": "administrator", "prompt": "The administrator reported that she completed the task.", "prompt_swap": "The administrator reported that he completed the task."} +{"id": "jsd_000070", "lang": "en", "occupation": "pharmacist", "prompt": "The pharmacist promised that she will lead the next meeting.", "prompt_swap": "The pharmacist promised that he will lead the next meeting."} +{"id": "jsd_000071", "lang": "en", "occupation": "surgeon", "prompt": "The surgeon reported that she completed the task.", "prompt_swap": "The surgeon reported that he completed the task."} +{"id": "jsd_000072", "lang": "en", "occupation": "officer", "prompt": "The officer said that she would review the design.", "prompt_swap": "The officer said that he would review the design."} +{"id": "jsd_000073", "lang": "en", "occupation": "plumber", "prompt": "The plumber explained that she needed more time.", "prompt_swap": "The plumber explained that he needed more time."} +{"id": "jsd_000074", "lang": "en", "occupation": "manager", "prompt": "The manager explained that she needed more time.", "prompt_swap": "The manager explained that he needed more time."} +{"id": "jsd_000075", "lang": "en", "occupation": "nutritionist", "prompt": "The nutritionist said that she would review the design.", "prompt_swap": "The nutritionist said that he would review the design."} +{"id": "jsd_000076", "lang": "en", "occupation": "janitor", "prompt": "The 
janitor explained that she needed more time.", "prompt_swap": "The janitor explained that he needed more time."} +{"id": "jsd_000077", "lang": "en", "occupation": "planner", "prompt": "The planner said that she would review the design.", "prompt_swap": "The planner said that he would review the design."} +{"id": "jsd_000078", "lang": "en", "occupation": "veterinarian", "prompt": "The veterinarian promised that she will lead the next meeting.", "prompt_swap": "The veterinarian promised that he will lead the next meeting."} +{"id": "jsd_000079", "lang": "en", "occupation": "technician", "prompt": "The technician promised that she will lead the next meeting.", "prompt_swap": "The technician promised that he will lead the next meeting."} +{"id": "jsd_000080", "lang": "en", "occupation": "accountant", "prompt": "The accountant explained that she needed more time.", "prompt_swap": "The accountant explained that he needed more time."} +{"id": "jsd_000081", "lang": "en", "occupation": "baker", "prompt": "The baker noted that she prefers the alternative plan.", "prompt_swap": "The baker noted that he prefers the alternative plan."} +{"id": "jsd_000082", "lang": "en", "occupation": "hairdresser", "prompt": "The hairdresser noted that she prefers the alternative plan.", "prompt_swap": "The hairdresser noted that he prefers the alternative plan."} +{"id": "jsd_000083", "lang": "en", "occupation": "counselor", "prompt": "The counselor explained that she needed more time.", "prompt_swap": "The counselor explained that he needed more time."} +{"id": "jsd_000084", "lang": "en", "occupation": "therapist", "prompt": "The therapist reported that she completed the task.", "prompt_swap": "The therapist reported that he completed the task."} +{"id": "jsd_000085", "lang": "en", "occupation": "mechanic", "prompt": "The mechanic noted that she prefers the alternative plan.", "prompt_swap": "The mechanic noted that he prefers the alternative plan."} +{"id": "jsd_000086", "lang": "en", "occupation": "salesperson", "prompt": "The salesperson reported that she completed the task.", "prompt_swap": "The salesperson reported that he completed the task."} +{"id": "jsd_000087", "lang": "en", "occupation": "psychologist", "prompt": "The psychologist reported that she completed the task.", "prompt_swap": "The psychologist reported that he completed the task."} +{"id": "jsd_000088", "lang": "en", "occupation": "carpenter", "prompt": "The carpenter noted that she prefers the alternative plan.", "prompt_swap": "The carpenter noted that he prefers the alternative plan."} +{"id": "jsd_000089", "lang": "en", "occupation": "lawyer", "prompt": "The lawyer said that she would review the design.", "prompt_swap": "The lawyer said that he would review the design."} +{"id": "jsd_000090", "lang": "en", "occupation": "dietitian", "prompt": "The dietitian reported that she completed the task.", "prompt_swap": "The dietitian reported that he completed the task."} +{"id": "jsd_000091", "lang": "en", "occupation": "dispatcher", "prompt": "The dispatcher promised that she will lead the next meeting.", "prompt_swap": "The dispatcher promised that he will lead the next meeting."} +{"id": "jsd_000092", "lang": "en", "occupation": "bartender", "prompt": "The bartender reported that she completed the task.", "prompt_swap": "The bartender reported that he completed the task."} +{"id": "jsd_000093", "lang": "en", "occupation": "doctor", "prompt": "The doctor explained that she needed more time.", "prompt_swap": "The doctor explained that he needed 
more time."} +{"id": "jsd_000094", "lang": "en", "occupation": "teacher", "prompt": "The teacher said that she would review the design.", "prompt_swap": "The teacher said that he would review the design."} +{"id": "jsd_000095", "lang": "en", "occupation": "broker", "prompt": "The broker promised that she will lead the next meeting.", "prompt_swap": "The broker promised that he will lead the next meeting."} +{"id": "jsd_000096", "lang": "en", "occupation": "librarian", "prompt": "The librarian explained that she needed more time.", "prompt_swap": "The librarian explained that he needed more time."} +{"id": "jsd_000097", "lang": "en", "occupation": "scientist", "prompt": "The scientist said that she would review the design.", "prompt_swap": "The scientist said that he would review the design."} +{"id": "jsd_000098", "lang": "en", "occupation": "educator", "prompt": "The educator noted that she prefers the alternative plan.", "prompt_swap": "The educator noted that he prefers the alternative plan."} +{"id": "jsd_000099", "lang": "en", "occupation": "inspector", "prompt": "The inspector promised that she will lead the next meeting.", "prompt_swap": "The inspector promised that he will lead the next meeting."} diff --git a/requirements.txt b/requirements.txt index fca3779..06637df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,274 +1,31 @@ -absl-py==2.1.0 accelerate==0.33.0 -aiofiles==23.2.1 -annotated-types==0.6.0 -anyio==4.4.0 -argon2-cffi==23.1.0 -argon2-cffi-bindings==21.2.0 -arrow==1.3.0 -asttokens==2.4.1 -astunparse==1.6.3 -async-lru==2.0.4 -attrs==23.2.0 -audioread==3.0.1 -babel==2.16.0 -beautifulsoup4==4.12.3 -bitsandbytes==0.43.3 -bleach==6.1.0 -blis==0.7.11 -cachetools==5.3.2 -catalogue==2.0.10 -certifi==2024.2.2 -cffi==1.16.0 -charset-normalizer==3.3.2 -click==8.1.7 -cloudpathlib==0.16.0 -cmake==3.28.1 -comm==0.2.1 -confection==0.1.4 -contourpy==1.2.0 -cycler==0.12.1 -cymem==2.0.8 -Cython==3.0.8 datasets==2.21.0 -debugpy==1.8.1 -decorator==5.1.1 -defusedxml==0.7.1 -dill==0.3.8 -diskcache==5.6.3 -distro==1.9.0 -dm-tree==0.1.8 -docstring_parser==0.16 -einops==0.7.0 -exceptiongroup==1.2.0 -execnet==2.0.2 -executing==2.0.1 -expecttest==0.1.3 -fastapi==0.112.2 -fastjsonschema==2.19.1 -ffmpy==0.4.0 -filelock==3.13.1 -fire==0.6.0 -fonttools==4.48.1 -fqdn==1.5.1 -gast==0.5.4 -gguf==0.9.1 -google-auth==2.27.0 -google-auth-oauthlib==0.4.6 -gradio==4.42.0 -gradio_client==1.3.0 -grpcio==1.60.1 -h11==0.14.0 -hjson==3.1.0 -httpcore==1.0.5 -httptools==0.6.1 -httpx==0.27.2 -huggingface-hub==0.24.6 -hypothesis==5.35.1 -idna==3.6 -importlib_resources==6.4.4 -iniconfig==2.0.0 -intel-openmp==2021.4.0 -interegular==0.3.3 -ipykernel==6.29.2 -ipython==8.21.0 -ipython-genutils==0.2.0 -isoduration==20.11.0 -jedi==0.19.1 -jieba==0.42.1 -Jinja2==3.1.3 -jiter==0.5.0 -joblib==1.3.2 -json5==0.9.14 -jsonpointer==3.0.0 -jsonschema==4.21.1 -jsonschema-specifications==2023.12.1 -jupyter-events==0.10.0 -jupyter-lsp==2.2.5 -jupyter_client==8.6.0 -jupyter_core==5.7.1 -jupyter_server==2.14.2 -jupyter_server_terminals==0.5.3 -jupyterlab==4.1.6 -jupyterlab_pygments==0.3.0 -jupyterlab_server==2.27.3 -jupytext==1.16.1 -kiwisolver==1.4.5 -langcodes==3.3.0 -lark==1.2.2 -lazy_loader==0.3 -librosa==0.10.1 -lm-format-enforcer==0.10.6 -Markdown==3.5.2 -markdown-it-py==3.0.0 -matplotlib==3.8.2 -matplotlib-inline==0.1.6 -mdit-py-plugins==0.4.0 -mdurl==0.1.2 -mistral_common==1.3.4 -mistune==3.0.2 -mkl==2021.1.1 -mkl-devel==2021.1.1 -mkl-include==2021.1.1 -mock==5.1.0 -mpmath==1.3.0 -msgpack==1.0.7 -msgspec==0.18.6 
-multiprocess==0.70.16 -murmurhash==1.0.10 -nbclient==0.9.0 -nbconvert==7.16.0 -nbformat==5.9.2 -nest-asyncio==1.6.0 -networkx==2.6.3 -ninja==1.11.1.1 -nltk==3.9.1 -notebook==6.4.10 -notebook_shim==0.2.4 numpy==1.24.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==9.1.0.70 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-dali-cuda120==1.34.0 -nvidia-ml-py==12.560.30 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.6.68 -nvidia-nvtx-cu12==12.1.105 -nvidia-pyindex==1.0.9 -nvitop==1.5.1 -oauthlib==3.2.2 -openai==1.43.0 -optree==0.10.0 -orjson==3.10.7 -overrides==7.7.0 -packaging==23.2 pandas==2.2.2 -pandocfilters==1.5.1 -parso==0.8.3 -partial-json-parser==0.2.1.1.post4 peft==0.12.0 -pexpect==4.9.0 -platformdirs==4.2.0 -pluggy==1.4.0 -pooch==1.8.0 -preshed==3.0.9 -prettytable==3.9.0 -prometheus-client==0.19.0 -prometheus-fastapi-instrumentator==7.0.0 -prompt-toolkit==3.0.43 -protobuf==4.24.4 -ptyprocess==0.7.0 -pure-eval==0.2.2 -py-cpuinfo==9.0.0 -pyarrow==17.0.0 -pyasn1==0.5.1 -pyasn1-modules==0.3.0 -pybind11==2.11.1 -pybind11-global==2.11.1 -pycountry==24.6.1 -pycparser==2.21 -pydantic==2.8.2 -pydantic_core==2.20.1 -pydub==0.25.1 -Pygments==2.17.2 -PyJWT==2.8.0 -pyparsing==3.1.1 -pytest==8.0.0 -pytest-flakefinder==1.1.0 -pytest-rerunfailures==13.0 -pytest-shard==0.1.2 -pytest-xdist==3.5.0 -python-dateutil==2.8.2 -python-dotenv==1.0.1 -python-hostlist==1.23.0 -python-json-logger==2.0.7 -python-multipart==0.0.9 -pytorch-quantization==2.1.2 -PyYAML==6.0.1 -pyzmq==25.1.2 -ray==2.35.0 -referencing==0.33.0 +psutil regex==2023.12.25 -requests==2.32.3 -requests-oauthlib==1.3.1 -rfc3339-validator==0.1.4 -rfc3986-validator==0.1.1 -rouge-chinese==1.0.3 -rpds-py==0.17.1 -rsa==4.9 -ruff==0.6.3 -safetensors==0.4.4 -semantic-version==2.10.0 -Send2Trash==1.8.2 -sentencepiece==0.2.0 -shellingham==1.5.4 -shtab==1.7.1 -six==1.16.0 -smart-open==6.4.0 -sniffio==1.3.1 -sortedcontainers==2.4.0 -soundfile==0.12.1 -soupsieve==2.5 -soxr==0.3.7 -spacy==3.7.2 -spacy-legacy==3.0.12 -spacy-loggers==1.0.5 -sphinx_glpi_theme==0.6 -srsly==2.4.8 -sse-starlette==2.1.3 -stack-data==0.6.3 -starlette==0.38.4 +scipy sympy==1.12 -tabulate==0.9.0 -tbb==2021.11.0 -tensorboard==2.9.0 -tensorboard-data-server==0.6.1 -tensorboard-plugin-wit==1.8.1 -termcolor==2.4.0 -terminado==0.18.0 -thinc==8.2.3 -threadpoolctl==3.2.0 tiktoken==0.7.0 -tinycss2==1.2.1 -tokenizers==0.19.1 -toml==0.10.2 -tomli==2.0.1 -tomlkit==0.12.0 torch==2.4.0 -torchvision==0.19.0 -tornado==6.4 tqdm==4.66.5 -traitlets==5.9.0 transformers==4.44.2 -triton==3.0.0 -trl==0.9.6 -typer==0.12.5 -types-dataclasses==0.6.6 -types-python-dateutil==2.9.0.20240821 -typing_extensions==4.12.2 -tyro==0.8.10 -tzdata==2024.1 -uri-template==1.3.0 -urllib3==2.2.2 -uvicorn==0.30.6 -uvloop==0.20.0 -vllm==0.6.0 -vllm-flash-attn==2.6.1 -wasabi==1.1.2 -watchfiles==0.24.0 -wcwidth==0.2.13 -weasel==0.3.4 -webcolors==24.8.0 -webencodings==0.5.1 -websocket-client==1.8.0 -websockets==12.0 -Werkzeug==3.0.1 -xdoctest==1.0.2 -xformers==0.0.27.post2 -xxhash==3.5.0 +vllm>=0.6.2 +wandb +openai==1.43.0 +huggingface-hub +loguru +multiprocess==0.70.16 +pebble +python-dateutil==2.8.2 +matplotlib==3.8.2 +sglang +timeout-decorator +func-timeout +antlr4-python3-runtime==4.11.1 +latex2sympy2 +word2number +pyarrow==17.0.0 +sentencepiece==0.2.0 + diff --git a/scripts/make_train_sets.py b/scripts/make_train_sets.py new 
file mode 100644
index 0000000..218ac81
--- /dev/null
+++ b/scripts/make_train_sets.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Generate training sets for:
+- EM / Group-EM: 1, 5, 20 prompts (English only)
+- JSD: 100 factual/counterfactual pairs (x, x_swap) with gender swap
+
+Data sources:
+  - assets/triggers/occupations_en.txt
+  - assets/groups/en_definitional_pairs.json (female<->male definitional pairs; supports multiple schemas)
+  - assets/groups/en_female.txt / en_male.txt (for pronouns; we mainly use 'she'/'he')
+
+Avoids overlap with eval CTF examples in data/bias/ctf/ctf_en.jsonl.
+
+Outputs:
+  data/train/em_group/train_en_size{1,5,20}.jsonl
+  data/train/jsd/train_pairs_en_size100.jsonl
+"""
+import json, random, pathlib, re, sys
+from typing import List, Tuple, Dict, Any
+
+ROOT = pathlib.Path(__file__).resolve().parents[1]
+ASSETS = ROOT / "assets"
+DATA = ROOT / "data"
+OUT_EM = DATA / "train" / "em_group"
+OUT_JSD = DATA / "train" / "jsd"
+EVAL_CTF = DATA / "bias" / "ctf" / "ctf_en.jsonl"
+
+random.seed(2025)
+
+def read_lines(p: pathlib.Path) -> List[str]:
+    return [x.strip() for x in p.read_text(encoding="utf-8").splitlines() if x.strip()]
+
+def read_jsonl(p: pathlib.Path) -> List[dict]:
+    if not p.exists():
+        return []
+    return [json.loads(x) for x in p.read_text(encoding="utf-8").splitlines() if x.strip()]
+
+def _normalize_pair(a: Any, b: Any) -> Tuple[str, str]:
+    return (str(a).strip().lower(), str(b).strip().lower())
+
+def load_pairs_json_any(p: pathlib.Path) -> List[Tuple[str, str]]:
+    """
+    Robustly load definitional gender pairs from various schemas seen in the wild:
+      - [ ["woman","man"], ["girl","boy"], ... ]
+      - { "definitional": [ ["woman","man"], ... ] }
+      - [ {"f":"woman","m":"man"}, ... ] or keys {"female":...,"male":...}
+      - { "pairs": [ ... same as above ... ] }
+    Returns a list of (female, male) lower-cased tuples.
+    """
+    if not p.exists():
+        return []
+    data = json.loads(p.read_text(encoding="utf-8"))
+    pairs: List[Tuple[str, str]] = []
+
+    def add_from_list(lst: List[Any]):
+        for item in lst:
+            if isinstance(item, list) and len(item) == 2:
+                a, b = item
+                pairs.append(_normalize_pair(a, b))
+            elif isinstance(item, dict):
+                # common key patterns
+                if "f" in item and "m" in item:
+                    pairs.append(_normalize_pair(item["f"], item["m"]))
+                elif "female" in item and "male" in item:
+                    pairs.append(_normalize_pair(item["female"], item["male"]))
+                elif "a" in item and "b" in item:
+                    pairs.append(_normalize_pair(item["a"], item["b"]))
+
+    if isinstance(data, list):
+        add_from_list(data)
+    elif isinstance(data, dict):
+        for k in ("definitional", "definitional_pairs", "pairs"):
+            if isinstance(data.get(k), list):
+                add_from_list(data[k])
+
+    # de-dup & sanity
+    out = []
+    seen = set()
+    for a, b in pairs:
+        if not a or not b or a == b:
+            continue
+        key = (a, b)
+        if key not in seen:
+            seen.add(key)
+            out.append(key)
+    return out
+
+def minimal_swap(text: str, pairs: List[Tuple[str, str]]) -> str:
+    """
+    Swap gendered tokens in a single boundary-safe, case-insensitive pass.
+    One combined regex is used so that a token already swapped to its
+    counterpart cannot be swapped back by a later substitution; longer
+    tokens are tried first to avoid partial shadowing.
+    """
+    if not pairs:
+        return text  # no-op; caller will fall back to the template-based swap
+    mapping: Dict[str, str] = {}
+    for a, b in pairs:
+        mapping[a] = b
+        mapping[b] = a
+    keys = sorted(mapping.keys(), key=len, reverse=True)
+    pat = re.compile(r"(?<!\w)(" + "|".join(re.escape(k) for k in keys) + r")(?!\w)",
+                     flags=re.IGNORECASE)
+    # replacements are emitted in lowercase, matching the original behavior
+    return pat.sub(lambda m: mapping[m.group(0).lower()], text)
+
+def sample_templates(occupations: List[str], n: int) -> List[Tuple[str, str, str]]:
+    """
+    Return list of tuples: (occupation, she_template, he_template)
+    """
+    SHE_FORMS = [
+        "The {occ} said that she would review the design.",
+        "The {occ} explained that she needed more time.",
+        "The {occ} promised that she will lead the next meeting.",
+        "The {occ} reported that she completed the task.",
+        "The {occ} noted that she prefers the alternative plan.",
+    ]
+    HE_FORMS = [
+        t.replace(" she ", " he ").replace("She ", "He ").replace(" she", " he")
+        for t in SHE_FORMS
+    ]
+    random.shuffle(occupations)
+    out = []
+    i = 0
+    while len(out) < n and i < 10 * n:
+        occ = occupations[i % len(occupations)]
+        idx = random.randrange(len(SHE_FORMS))
+        s_she = SHE_FORMS[idx].format(occ=occ)
+        s_he = HE_FORMS[idx].format(occ=occ)
+        out.append((occ, s_she, s_he))
+        i += 1
+    return out[:n]
+
+def main():
+    OUT_EM.mkdir(parents=True, exist_ok=True)
+    OUT_JSD.mkdir(parents=True, exist_ok=True)
+
+    # sources
+    occs = read_lines(ASSETS / "triggers" / "occupations_en.txt")
+    pairs_json = ASSETS / "groups" / "en_definitional_pairs.json"
+    pairs = load_pairs_json_any(pairs_json)
+    eval_ctf = read_jsonl(EVAL_CTF)
+    eval_x = set([r.get("x", "").strip() for r in eval_ctf] +
+                 [r.get("x_swap", "").strip() for r in eval_ctf])
+
+    # ---- EM / Group-EM prompts (she-variant prompts; labels are not needed)
+    for size in [1, 5, 20]:
+        triples = sample_templates(occs, size * 3)  # oversample for filtering
+        rows = []
+        for occ, s_she, s_he in triples:
+            if s_she in eval_x or s_he in eval_x:
+                continue
+            rows.append({"id": f"em_{len(rows):06d}", "lang": "en", "occupation": occ, "prompt": s_she})
+            if len(rows) >= size:
+                break
+        outp = OUT_EM / f"train_en_size{size}.jsonl"
+        outp.write_text("\n".join(json.dumps(r) for r in rows) + ("\n" if rows else ""), encoding="utf-8")
+        print("Wrote", outp, "N=", len(rows))
+
+    # ---- JSD pairs (x, x_swap)
+    size_jsd = 100
+    triples = sample_templates(occs, size_jsd * 4)  # oversample more to be safe
+    pairs_out = []
+    for occ, s_she, s_he in triples:
+        x = s_she
+        if x in eval_x:
+            continue
+
+        # try definitional-pair swap first
+        x_swap = minimal_swap(x, pairs)
+        # fallback: if no change, use our explicit he-template
+        if x.strip().lower() == x_swap.strip().lower():
+            x_swap = s_he
+
+        if x_swap in eval_x:
+            continue
+        if x.strip().lower() == x_swap.strip().lower():  # still identical? extremely unlikely
+            continue
+
+        pairs_out.append({
+            "id": f"jsd_{len(pairs_out):06d}",
+            "lang": "en",
+            "occupation": occ,
+            "prompt": x,
+            "prompt_swap": x_swap
+        })
+        if len(pairs_out) >= size_jsd:
+            break
+
+    outp2 = OUT_JSD / "train_pairs_en_size100.jsonl"
+    outp2.write_text("\n".join(json.dumps(r) for r in pairs_out) + ("\n" if pairs_out else ""), encoding="utf-8")
+    print("Wrote", outp2, "N=", len(pairs_out))
+
+    # quick diagnostics
+    if len(pairs_out) == 0:
+        print("[WARN] JSD pairs = 0. Diagnostics:")
+        print("  - occupations:", len(occs))
+        print("  - definitional pairs loaded:", len(pairs))
+        print("  - eval_x size:", len(eval_x))
+        print("  - Check assets/groups/en_definitional_pairs.json schema.")
+        sys.exit(2)
+
+if __name__ == "__main__":
+    main()
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/__init__.py
+ """ + B,T,V = logits.shape + topk = torch.topk(logits, k=min(k, V), dim=-1).indices # [B,T,k] + ids = torch.tensor(list(set(fem_ids) | set(male_ids)), device=logits.device, dtype=torch.long) + if ids.numel() == 0: + return torch.zeros(B,T, dtype=torch.float32, device=logits.device) + # Compare with broadcasting + match = (topk.unsqueeze(-1) == ids.view(1,1,1,-1)).any(dim=-1) # [B,T,k] -> [B,T] + return match.float() + +def group_masses(probs: torch.Tensor, fem_ids: List[int], male_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + """ + probs: [B,T,V] + Returns piF, piM of shape [B,T] + """ + if len(fem_ids) == 0 and len(male_ids) == 0: + return torch.zeros_like(probs[...,0]), torch.zeros_like(probs[...,0]) + idxF = torch.tensor(fem_ids, device=probs.device, dtype=torch.long) if len(fem_ids)>0 else None + idxM = torch.tensor(male_ids, device=probs.device, dtype=torch.long) if len(male_ids)>0 else None + piF = probs[..., idxF].sum(dim=-1) if idxF is not None else torch.zeros_like(probs[...,0]) + piM = probs[..., idxM].sum(dim=-1) if idxM is not None else torch.zeros_like(probs[...,0]) + return piF, piM + +def normalized_entropy(sub_probs: torch.Tensor) -> torch.Tensor: + """ + sub_probs: [*, K] + Return normalized entropy in [0,1]: H(p)/log(K) + """ + eps = 1e-12 + K = sub_probs.size(-1) + H = -(sub_probs.clamp_min(eps) * sub_probs.clamp_min(eps).log()).sum(dim=-1) + denom = torch.log(torch.tensor(float(K), device=sub_probs.device)) + return H / (denom + eps) + +def group_entropies(probs: torch.Tensor, fem_ids: List[int], male_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + """ + probs: [B,T,V] -> qF [B,T,|F|], qM [B,T,|M|] -> HF, HM in [0,1] + """ + eps = 1e-12 + idxF = torch.tensor(fem_ids, device=probs.device, dtype=torch.long) if len(fem_ids)>0 else None + idxM = torch.tensor(male_ids, device=probs.device, dtype=torch.long) if len(male_ids)>0 else None + + if idxF is None: + HF = torch.zeros(probs.shape[:2], device=probs.device) + else: + pF = probs[..., idxF] # [B,T,|F|] + piF = pF.sum(dim=-1, keepdim=True) + eps + qF = pF / piF + HF = normalized_entropy(qF) + + if idxM is None: + HM = torch.zeros(probs.shape[:2], device=probs.device) + else: + pM = probs[..., idxM] + piM = pM.sum(dim=-1, keepdim=True) + eps + qM = pM / piM + HM = normalized_entropy(qM) + + return HF, HM + +def reduce_steps(x: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor: + """ + x: [B,T], step_mask: [B,T] in {0,1} + Return mean over steps where mask==1 (avoid div by 0). + """ + w = step_mask + s = (x * w).sum() + d = w.sum().clamp_min(1.0) + return s / d + +# ---------------- EM ---------------- +def loss_em(logits: torch.Tensor, gen_mask: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + """ + Entropy minimization over generation steps (no gating). + logits: [B,T,V]; gen_mask: [B,T] 1 for generation steps (non-prompt) + """ + probs = probs_from_logits(logits) # [B,T,V] + eps = 1e-12 + Ht = -(probs.clamp_min(eps) * probs.clamp_min(eps).log()).sum(dim=-1) # [B,T] + L = reduce_steps(Ht, gen_mask) + return L, {"H_mean": float(reduce_steps(Ht, gen_mask).item())} + +# ------------- Group-EM ------------- +def loss_group_em( + logits: torch.Tensor, + gen_mask: torch.Tensor, + fem_ids: List[int], + male_ids: List[int], + gate_mask: Optional[torch.Tensor] = None, + lambda_mass: float = 0.0, + beta_kl: float = 0.0, + ref_probs: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Dict]: + """ + Group-EM loss with optional guards. 
+ - core: (H_F - H_M)^2 + - mass guard: (piF - piM)^2 + - stability: KL( p || pref ) + """ + probs = probs_from_logits(logits) # [B,T,V] + HF, HM = group_entropies(probs, fem_ids, male_ids) # [B,T], [B,T] + core = (HF - HM) ** 2 # [B,T] + + piF, piM = group_masses(probs, fem_ids, male_ids) + Lmass = (piF - piM) ** 2 # [B,T] + + if gate_mask is None: + step_mask = gen_mask + else: + step_mask = (gen_mask * gate_mask).float() + + L_core = reduce_steps(core, step_mask) + L_mass = reduce_steps(Lmass, step_mask) + + L_kl = torch.tensor(0.0, device=logits.device) + if beta_kl > 0.0 and ref_probs is not None: + eps = 1e-12 + p = probs.clamp_min(eps) + q = ref_probs.clamp_min(eps) + KL = (p * (p.log() - q.log())).sum(dim=-1) # [B,T] + L_kl = reduce_steps(KL, step_mask) + + loss = L_core + lambda_mass * L_mass + beta_kl * L_kl + extras = { + "L_core": float(L_core.item()), + "L_mass": float(L_mass.item()), + "L_kl": float(L_kl.item()) if isinstance(L_kl, torch.Tensor) else float(L_kl), + "piF_mean": float(reduce_steps(piF, step_mask).item()), + "piM_mean": float(reduce_steps(piM, step_mask).item()), + "HF_mean": float(reduce_steps(HF, step_mask).item()), + "HM_mean": float(reduce_steps(HM, step_mask).item()), + } + return loss, extras + +# --------------- JSD --------------- +def _jsd_full(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + m = 0.5 * (p + q) + return 0.5 * (p * (p.clamp_min(eps).log() - m.clamp_min(eps).log())).sum(dim=-1) + \ + 0.5 * (q * (q.clamp_min(eps).log() - m.clamp_min(eps).log())).sum(dim=-1) + +def _jsd_topk(p: torch.Tensor, q: torch.Tensor, K: int) -> torch.Tensor: + V = p.size(-1) + K = min(K, V) + idx_p = torch.topk(p, k=K, dim=-1).indices + idx_q = torch.topk(q, k=K, dim=-1).indices + idx = torch.cat([idx_p, idx_q], dim=-1).unique(dim=-1) # union + pK = p.gather(-1, idx); qK = q.gather(-1, idx) + mK = 0.5 * (pK + qK) + eps = 1e-12 + return 0.5 * (pK * (pK.clamp_min(eps).log() - mK.clamp_min(eps).log())).sum(dim=-1) + \ + 0.5 * (qK * (qK.clamp_min(eps).log() - mK.clamp_min(eps).log())).sum(dim=-1) + +def loss_jsd( + logits_f: torch.Tensor, # [B,T,V] + logits_c: torch.Tensor, # [B,T,V] + gen_mask: torch.Tensor, # [B,T] + fem_ids: List[int], + male_ids: List[int], + gate_mask_f: Optional[torch.Tensor] = None, + gate_mask_c: Optional[torch.Tensor] = None, + lambda_mass: float = 0.0, + beta_kl: float = 0.0, + ref_probs_f: Optional[torch.Tensor] = None, + topk_jsd: int = 0 +) -> Tuple[torch.Tensor, Dict]: + """ + JSD(p||q) averaged over steps with gating (factual and counterfactual separately gated). + Also includes mass parity on both branches and optional stability to base on factual branch. 
+ """ + p = probs_from_logits(logits_f) # [B,T,V] + q = probs_from_logits(logits_c) # [B,T,V] + + if topk_jsd and topk_jsd > 0: + J = _jsd_topk(p, q, K=topk_jsd) # [B,T] + else: + J = _jsd_full(p, q) # [B,T] + + # step mask: require gate on factual (and optionally also on counterfactual) + if gate_mask_f is None: + step_mask = gen_mask + else: + step_mask = (gen_mask * gate_mask_f).float() + + L_jsd = reduce_steps(J, step_mask) + + # mass parity on each branch + piF_f, piM_f = group_masses(p, fem_ids, male_ids) + piF_c, piM_c = group_masses(q, fem_ids, male_ids) + L_mass = reduce_steps((piF_f - piM_f)**2, step_mask) + reduce_steps((piF_c - piM_c)**2, step_mask) + + # stability to base (factual branch) + L_kl = torch.tensor(0.0, device=logits_f.device) + if beta_kl > 0.0 and ref_probs_f is not None: + eps = 1e-12 + p0 = ref_probs_f.clamp_min(eps) + L_kl = reduce_steps((p.clamp_min(eps) * (p.clamp_min(eps).log() - p0.log())).sum(dim=-1), step_mask) + + loss = L_jsd + lambda_mass * L_mass + beta_kl * L_kl + extras = { + "L_jsd": float(L_jsd.item()), + "L_mass": float(L_mass.item()), + "L_kl": float(L_kl.item()) if isinstance(L_kl, torch.Tensor) else float(L_kl), + "piF_f": float(reduce_steps(piF_f, step_mask).item()), + "piM_f": float(reduce_steps(piM_f, step_mask).item()), + "piF_c": float(reduce_steps(piF_c, step_mask).item()), + "piM_c": float(reduce_steps(piM_c, step_mask).item()), + } + return loss, extras diff --git a/train_runner.py b/train_runner.py new file mode 100644 index 0000000..f7c8c25 --- /dev/null +++ b/train_runner.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Few-step LoRA training runner for EM / Group-EM / JSD. +- Generates a short greedy continuation for each prompt (T=0), then teacher-forces + on [prompt + continuation] to compute stepwise logits and apply losses. +- Gating: Top-K membership of {F ∪ M} at each step (k configurable). +- Guards: mass parity (lambda) and stability KL to a frozen base model (beta). +Outputs: + adapter weights + a JSON log of loss components. +This script is single-GPU; to parallelize, launch multiple processes with different CUDA_VISIBLE_DEVICES. 
+""" +import os, json, math, time, random, pathlib, argparse +from typing import List, Dict, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoTokenizer, AutoModelForCausalLM, get_cosine_schedule_with_warmup +from peft import LoraConfig, get_peft_model +from src.losses import ( + map_words_to_token_ids, probs_from_logits, topk_gate, + loss_em, loss_group_em, loss_jsd +) + +def set_seed(s: int): + random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s) + +def read_jsonl(p: str) -> List[Dict]: + return [json.loads(x) for x in open(p, "r", encoding="utf-8") if x.strip()] + +def save_json(p: str, obj): + pathlib.Path(p).parent.mkdir(parents=True, exist_ok=True) + open(p, "w", encoding="utf-8").write(json.dumps(obj, indent=2)) + +@torch.no_grad() +def greedy_generate_ids(model, tok, prompt_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor: + """ + prompt_ids: [1, L] + return gen_ids: [1, G] (without BOS/EOS trimming) + """ + out = model.generate( + input_ids=prompt_ids, + attention_mask=torch.ones_like(prompt_ids), + do_sample=False, temperature=0.0, top_p=1.0, + max_new_tokens=max_new_tokens, + eos_token_id=tok.eos_token_id + ) + # decode original prompt length to slice + gen_ids = out[:, prompt_ids.size(1):] # [1, G] + return gen_ids + +def build_lora(model) -> nn.Module: + lconf = LoraConfig( + r=8, lora_alpha=16, lora_dropout=0.0, + bias="none", task_type="CAUSAL_LM", + target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"] # Qwen-style + ) + return get_peft_model(model, lconf) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--model", type=str, default="Qwen/Qwen2.5-7B-Instruct") + ap.add_argument("--loss_type", type=str, choices=["em","group_em","jsd"], required=True) + ap.add_argument("--train_file", type=str, help="for em/group_em: JSONL with {'prompt':...}") + ap.add_argument("--train_pairs", type=str, help="for jsd: JSONL with {'prompt','prompt_swap'}") + ap.add_argument("--groups_dir", type=str, default="assets/groups") + ap.add_argument("--output_dir", type=str, required=True) + + # optimization + ap.add_argument("--max_steps", type=int, default=10) + ap.add_argument("--learning_rate", type=float, default=2e-5) + ap.add_argument("--warmup_steps", type=int, default=0) + ap.add_argument("--weight_decay", type=float, default=0.0) + ap.add_argument("--grad_accum", type=int, default=32) + ap.add_argument("--gen_len", type=int, default=64) + ap.add_argument("--seed", type=int, default=2025) + + # guards / gating + ap.add_argument("--lambda_mass", type=float, default=0.0) + ap.add_argument("--beta_kl", type=float, default=0.0) + ap.add_argument("--topk_gate", type=int, default=20) + ap.add_argument("--topk_jsd", type=int, default=0) # 0 => full vocab JSD + + # dtype / device + ap.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16","float16","float32"]) + ap.add_argument("--device", type=str, default="cuda") + + args = ap.parse_args() + set_seed(args.seed) + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype] + + # Load tokenizer/model + LoRA (trainable) and base (frozen for KL) + tok = AutoTokenizer.from_pretrained(args.model, use_fast=True) + base_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype).to(device) + base_model.eval() + model = AutoModelForCausalLM.from_pretrained(args.model, 
torch_dtype=dtype).to(device) + model = build_lora(model) + model.train() + + # Build optimizer/scheduler + opt = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay) + sch = get_cosine_schedule_with_warmup(opt, num_warmup_steps=args.warmup_steps, num_training_steps=args.max_steps) + + # Build gender token id sets (for group and jsd; EM ignores) + fem_words = [w.strip().lower() for w in open(os.path.join(args.groups_dir,"en_female.txt"),"r",encoding="utf-8")] + male_words = [w.strip().lower() for w in open(os.path.join(args.groups_dir,"en_male.txt"),"r",encoding="utf-8")] + fem_ids = map_words_to_token_ids(tok, fem_words) + male_ids = map_words_to_token_ids(tok, male_words) + + # Load training data + if args.loss_type in ("em","group_em"): + assert args.train_file, "--train_file is required for em/group_em" + rows = read_jsonl(args.train_file) + assert len(rows) > 0, "Empty train_file" + prompts = [r["prompt"] for r in rows] + else: + assert args.loss_type == "jsd" and args.train_pairs, "--train_pairs is required for jsd" + rows = read_jsonl(args.train_pairs) + assert len(rows) > 0, "Empty train_pairs" + pairs = [(r["prompt"], r["prompt_swap"]) for r in rows] + + # Logging + outdir = pathlib.Path(args.output_dir) + (outdir/"logs").mkdir(parents=True, exist_ok=True) + (outdir/"adapter").mkdir(parents=True, exist_ok=True) + json_log = [] + + step = 0 + while step < args.max_steps: + # simple cyclic sampler + if args.loss_type in ("em","group_em"): + idx = step % len(prompts) + prompt = prompts[idx] + # 1) generate a short continuation (greedy) + enc = tok(prompt, return_tensors="pt") + gen_ids = greedy_generate_ids(model, tok, enc.input_ids.to(device), max_new_tokens=args.gen_len) # [1,G] + concat_ids = torch.cat([enc.input_ids.to(device), gen_ids], dim=1) # [1,L+G] + attn = torch.ones_like(concat_ids).to(device) + + # 2) forward current model (teacher forcing) for loss + out = model(input_ids=concat_ids, attention_mask=attn) + logits = out.logits[:, :-1, :] # predict next for each pos except last targetless + T = logits.size(1) + # generation mask: 0 for prompt part, 1 for generated part (shifted by 1) + gen_mask = torch.zeros((1,T), dtype=torch.float32, device=device) + gen_mask[:, enc.input_ids.size(1)-1:] = 1.0 # from the last prompt position onward + + if args.loss_type == "em": + loss, extras = loss_em(logits, gen_mask) + else: + gate = topk_gate(logits, fem_ids, male_ids, k=args.topk_gate) # [1,T] + with torch.no_grad(): + out_ref = base_model(input_ids=concat_ids, attention_mask=attn) + ref_probs = probs_from_logits(out_ref.logits[:, :-1, :]) + loss, extras = loss_group_em( + logits, gen_mask, fem_ids, male_ids, gate_mask=gate, + lambda_mass=args.lambda_mass, beta_kl=args.beta_kl, ref_probs=ref_probs + ) + + else: + # JSD branch + idx = step % len(pairs) + x, xs = pairs[idx] + + # factual + enc_f = tok(x, return_tensors="pt") + gen_f = greedy_generate_ids(model, tok, enc_f.input_ids.to(device), max_new_tokens=args.gen_len) + all_f = torch.cat([enc_f.input_ids.to(device), gen_f], dim=1) + attn_f = torch.ones_like(all_f).to(device) + out_f = model(input_ids=all_f, attention_mask=attn_f) + logits_f = out_f.logits[:, :-1, :] + T_f = logits_f.size(1) + gen_mask_f = torch.zeros((1,T_f), dtype=torch.float32, device=device) + gen_mask_f[:, enc_f.input_ids.size(1)-1:] = 1.0 + gate_f = topk_gate(logits_f, fem_ids, male_ids, k=args.topk_gate) + + # counterfactual + enc_c = tok(xs, return_tensors="pt") + gen_c = greedy_generate_ids(model, tok, 
enc_c.input_ids.to(device), max_new_tokens=args.gen_len) + all_c = torch.cat([enc_c.input_ids.to(device), gen_c], dim=1) + attn_c = torch.ones_like(all_c).to(device) + out_c = model(input_ids=all_c, attention_mask=attn_c) + logits_c = out_c.logits[:, :-1, :] + + with torch.no_grad(): + out_ref_f = base_model(input_ids=all_f, attention_mask=attn_f) + ref_probs_f = probs_from_logits(out_ref_f.logits[:, :-1, :]) + + loss, extras = loss_jsd( + logits_f, logits_c, gen_mask_f, fem_ids, male_ids, + gate_mask_f=gate_f, lambda_mass=args.lambda_mass, beta_kl=args.beta_kl, + ref_probs_f=ref_probs_f, topk_jsd=args.topk_jsd + ) + + (loss / args.grad_accum).backward() + if (step + 1) % args.grad_accum == 0 or (step + 1) == args.max_steps: + opt.step(); sch.step(); opt.zero_grad(set_to_none=True) + + step += 1 + json_log.append({"step": step, "loss": float(loss.item()), **extras}) + if step % 1 == 0: + print(f"[{step}/{args.max_steps}] loss={loss.item():.6f} | {extras}") + + # save adapter + model.save_pretrained(str(outdir/"adapter")) + save_json(str(outdir/"logs"/"train_log.json"), {"args": vars(args), "log": json_log}) + print("Saved adapter to", outdir/"adapter") + +if __name__ == "__main__": + main() |
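A final illustrative check (not in the diff) of the mask arithmetic the runner relies on: after dropping the last logit position, position t scores token t+1, so switching the mask on from `prompt_len - 1` selects exactly the G generated tokens:

```python
import torch

L, G = 4, 3                        # prompt length, generated length
T = (L + G) - 1                    # number of next-token predictions after logits[:, :-1, :]
gen_mask = torch.zeros(1, T)
gen_mask[:, L - 1:] = 1.0          # positions L-1 .. L+G-2 score tokens L .. L+G-1
print(gen_mask, int(gen_mask.sum()) == G)  # tensor([[0., 0., 0., 1., 1., 1.]]) True
```

An off-by-one here would either leak the last prompt token into the loss or silently drop the first generated token, so it is worth verifying once against toy lengths like these.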
