diff options
Diffstat (limited to 'scripts/eval_single_ckpt.py')
| -rw-r--r-- | scripts/eval_single_ckpt.py | 145 |
1 files changed, 145 insertions, 0 deletions
diff --git a/scripts/eval_single_ckpt.py b/scripts/eval_single_ckpt.py new file mode 100644 index 0000000..8597907 --- /dev/null +++ b/scripts/eval_single_ckpt.py @@ -0,0 +1,145 @@ +import json +import os +import torch +import glob +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score +from torch.utils.data import Dataset, DataLoader + +# --- Configuration --- +# You can manually set the checkpoint path here if glob fails or is slow +# Example: "saves/qwen3-0.6b-full-sft-h200/checkpoint-4358" +CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200" +TEST_FILE = "data/test_llama_factory.json" +BATCH_SIZE = 128 +USE_FLASH_ATTN = False + +# Load System Prompt +with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f: + SYSTEM_PROMPT = f.read() + +class EvalDataset(Dataset): + def __init__(self, data): + self.data = data + def __len__(self): + return len(self.data) + def __getitem__(self, idx): + return self.data[idx] + +def load_test_data(): + with open(TEST_FILE, "r", encoding="utf-8") as f: + return json.load(f) + +def batch_generate(model, tokenizer, batch_data, device="cuda"): + prompts = [] + for item in batch_data: + messages = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": item["input"]} + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + prompts.append(text) + + inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device) + + with torch.no_grad(): + generated_ids = model.generate( + **inputs, + max_new_tokens=256, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + input_len = inputs.input_ids.shape[1] + gen_tokens = generated_ids[:, input_len:] + responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True) + return responses + +def evaluate_ckpt(): + # 1. Find the latest checkpoint + checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1])) + if not checkpoints: + print(f"No checkpoints found in {CHECKPOINT_DIR}") + return + + latest_ckpt = checkpoints[-1] + print(f"\nTarget Checkpoint: {latest_ckpt}") + + device = "cuda" + print(f"Loading model (Batch Size: {BATCH_SIZE})...") + + try: + tokenizer = AutoTokenizer.from_pretrained(latest_ckpt, trust_remote_code=True, padding_side="left") + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True} + if USE_FLASH_ATTN: + kwargs["attn_implementation"] = "flash_attention_2" + + model = AutoModelForCausalLM.from_pretrained(latest_ckpt, **kwargs) + model.eval() + except Exception as e: + print(f"CRITICAL ERROR loading model: {e}") + return + + # 2. Prepare Data + test_data = load_test_data() + dataset = EvalDataset(test_data) + # Reduce num_workers to avoid hang if system is stressed + dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2) + + y_true_has_pref = [] + y_pred_has_pref = [] + json_valid_count = 0 + + print(f"Evaluating {len(test_data)} samples...") + + # 3. Inference Loop + for batch in tqdm(dataloader): + inputs = batch["input"] + outputs = batch["output"] + + # Ground Truth + for gt_str in outputs: + try: + gt_json = json.loads(gt_str) + gt_has = len(gt_json.get("preferences", [])) > 0 + except: + gt_has = False + y_true_has_pref.append(gt_has) + + # Prediction + batch_items = [{"input": inp} for inp in inputs] + responses = batch_generate(model, tokenizer, batch_items, device) + + for pred_str in responses: + pred_has = False + try: + pred_json = json.loads(pred_str) + json_valid_count += 1 + pred_has = len(pred_json.get("preferences", [])) > 0 + except: + pass + y_pred_has_pref.append(pred_has) + + # 4. Metrics + print("\n" + "="*40) + print(f"RESULTS for {latest_ckpt}") + print("="*40) + print(f"JSON Validity: {json_valid_count / len(test_data):.4f}") + print(f"Accuracy: {accuracy_score(y_true_has_pref, y_pred_has_pref):.4f}") + print(f"Precision: {precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}") + print(f"Recall: {recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}") + print(f"F1 Score: {f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}") + print("="*40) + +if __name__ == "__main__": + evaluate_ckpt() + |
