Diffstat (limited to 'scripts/eval_single_ckpt.py')
-rw-r--r--  scripts/eval_single_ckpt.py  145
1 file changed, 145 insertions, 0 deletions
diff --git a/scripts/eval_single_ckpt.py b/scripts/eval_single_ckpt.py
new file mode 100644
index 0000000..8597907
--- /dev/null
+++ b/scripts/eval_single_ckpt.py
@@ -0,0 +1,145 @@
+import json
+import os
+import torch
+import glob
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+from torch.utils.data import Dataset, DataLoader
+
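+# Usage: python scripts/eval_single_ckpt.py
+# Evaluates the newest checkpoint on a binary "has preferences" task and reports
+# JSON validity plus classification metrics.
+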
+# --- Configuration ---
+# The newest checkpoint-* directory under CHECKPOINT_DIR is evaluated.
+# If globbing fails or is slow, hard-code the checkpoint path manually,
+# e.g. "saves/qwen3-0.6b-full-sft-h200/checkpoint-4358"
+CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200"
+TEST_FILE = "data/test_llama_factory.json"
+BATCH_SIZE = 128
+USE_FLASH_ATTN = False
+
+# Load System Prompt
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+    SYSTEM_PROMPT = f.read()
+
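+# Thin Dataset wrapper; the default DataLoader collate turns a batch of
+# {"input", "output"} dicts into a dict of lists, which the inference loop relies on.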
+class EvalDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+    def __len__(self):
+        return len(self.data)
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+def load_test_data():
+    with open(TEST_FILE, "r", encoding="utf-8") as f:
+        return json.load(f)
+
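+# Build chat-formatted prompts for a whole batch and decode greedily (do_sample=False).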
+def batch_generate(model, tokenizer, batch_data, device="cuda"):
+    prompts = []
+    for item in batch_data:
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": item["input"]}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        prompts.append(text)
+
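+    # Left padding right-aligns the prompts so every generated sequence starts at the same index.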
+    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device)
+
+    with torch.no_grad():
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=256,
+            do_sample=False,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
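+    # Keep only the newly generated tokens; the prompt occupies the first input_len positions.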
+    input_len = inputs.input_ids.shape[1]
+    gen_tokens = generated_ids[:, input_len:]
+    responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
+    return responses
+
+def evaluate_ckpt():
+    # 1. Find the latest checkpoint
+    checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1]))
+    if not checkpoints:
+        print(f"No checkpoints found in {CHECKPOINT_DIR}")
+        return
+
+    latest_ckpt = checkpoints[-1]
+    print(f"\nTarget Checkpoint: {latest_ckpt}")
+
+    device = "cuda"
+    print(f"Loading model (Batch Size: {BATCH_SIZE})...")
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(latest_ckpt, trust_remote_code=True, padding_side="left")
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
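+        # bfloat16 keeps memory usage modest; flash_attention_2 requires the flash-attn package to be installed.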
+        kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True}
+        if USE_FLASH_ATTN:
+            kwargs["attn_implementation"] = "flash_attention_2"
+
+        model = AutoModelForCausalLM.from_pretrained(latest_ckpt, **kwargs)
+        model.eval()
+    except Exception as e:
+        print(f"CRITICAL ERROR loading model: {e}")
+        return
+
+    # 2. Prepare Data
+    test_data = load_test_data()
+    dataset = EvalDataset(test_data)
+    # Keep num_workers low to avoid DataLoader hangs when the system is under load
+    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
+
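+    # Binary label per sample: does the JSON output contain at least one entry under "preferences"?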
+    y_true_has_pref = []
+    y_pred_has_pref = []
+    json_valid_count = 0
+
+    print(f"Evaluating {len(test_data)} samples...")
+
+    # 3. Inference Loop
+    for batch in tqdm(dataloader):
+        inputs = batch["input"]
+        outputs = batch["output"]
+
+        # Ground Truth
+        for gt_str in outputs:
+            try:
+                gt_json = json.loads(gt_str)
+                gt_has = len(gt_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, AttributeError):
+                gt_has = False
+            y_true_has_pref.append(gt_has)
+
+        # Prediction
+        batch_items = [{"input": inp} for inp in inputs]
+        responses = batch_generate(model, tokenizer, batch_items, device)
+
+        for pred_str in responses:
+            pred_has = False
+            try:
+                pred_json = json.loads(pred_str)
+                json_valid_count += 1
+                pred_has = len(pred_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, AttributeError):
+                pass
+            y_pred_has_pref.append(pred_has)
+
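+    # Precision/recall/F1 treat "has at least one preference" as the positive class.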
+    # 4. Metrics
+    print("\n" + "="*40)
+    print(f"RESULTS for {latest_ckpt}")
+    print("="*40)
+    print(f"JSON Validity: {json_valid_count / len(test_data):.4f}")
+    print(f"Accuracy:      {accuracy_score(y_true_has_pref, y_pred_has_pref):.4f}")
+    print(f"Precision:     {precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
+    print(f"Recall:        {recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
+    print(f"F1 Score:      {f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
+    print("="*40)
+
+if __name__ == "__main__":
+    evaluate_ckpt()
+