import glob
import json
import os

import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
# You can manually set the checkpoint path here if glob fails or is slow.
# Example: "saves/qwen3-0.6b-full-sft-h200/checkpoint-4358"
CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200"
TEST_FILE = "data/test_llama_factory.json"
BATCH_SIZE = 128
USE_FLASH_ATTN = False

# Load the system prompt used during fine-tuning.
with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()


class EvalDataset(Dataset):
    """Thin wrapper so the test set can be batched by a DataLoader."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def load_test_data():
    with open(TEST_FILE, "r", encoding="utf-8") as f:
        return json.load(f)


def batch_generate(model, tokenizer, batch_data, device="cuda"):
    """Generate one response per item in batch_data using greedy decoding."""
    prompts = []
    for item in batch_data:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": item["input"]},
        ]
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        prompts.append(text)

    # Left padding keeps the generated continuation contiguous at the end of each row.
    inputs = tokenizer(
        prompts, return_tensors="pt", padding=True, padding_side="left"
    ).to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Strip the prompt tokens and decode only the newly generated part.
    input_len = inputs.input_ids.shape[1]
    gen_tokens = generated_ids[:, input_len:]
    responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
    return responses


def evaluate_ckpt():
    # 1. Find the latest checkpoint
    checkpoints = sorted(
        glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")),
        key=lambda x: int(x.split("-")[-1]),
    )
    if not checkpoints:
        print(f"No checkpoints found in {CHECKPOINT_DIR}")
        return
    latest_ckpt = checkpoints[-1]
    print(f"\nTarget Checkpoint: {latest_ckpt}")

    device = "cuda"
    print(f"Loading model (Batch Size: {BATCH_SIZE})...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            latest_ckpt, trust_remote_code=True, padding_side="left"
        )
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        kwargs = {
            "device_map": device,
            "torch_dtype": torch.bfloat16,
            "trust_remote_code": True,
        }
        if USE_FLASH_ATTN:
            kwargs["attn_implementation"] = "flash_attention_2"

        model = AutoModelForCausalLM.from_pretrained(latest_ckpt, **kwargs)
        model.eval()
    except Exception as e:
        print(f"CRITICAL ERROR loading model: {e}")
        return

    # 2. Prepare Data
    test_data = load_test_data()
    dataset = EvalDataset(test_data)
    # Keep num_workers low to avoid hangs if the system is under load.
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    y_true_has_pref = []
    y_pred_has_pref = []
    json_valid_count = 0

    print(f"Evaluating {len(test_data)} samples...")

    # 3. Inference Loop
    for batch in tqdm(dataloader):
        inputs = batch["input"]
        outputs = batch["output"]

        # Ground truth: an example is positive if its reference JSON contains
        # at least one entry under "preferences".
        for gt_str in outputs:
            try:
                gt_json = json.loads(gt_str)
                gt_has = len(gt_json.get("preferences", [])) > 0
            except (json.JSONDecodeError, AttributeError):
                gt_has = False
            y_true_has_pref.append(gt_has)

        # Prediction: same rule applied to the model output; invalid JSON counts
        # as "no preferences".
        batch_items = [{"input": inp} for inp in inputs]
        responses = batch_generate(model, tokenizer, batch_items, device)

        for pred_str in responses:
            pred_has = False
            try:
                pred_json = json.loads(pred_str)
                json_valid_count += 1
                pred_has = len(pred_json.get("preferences", [])) > 0
            except (json.JSONDecodeError, AttributeError):
                pass
            y_pred_has_pref.append(pred_has)

    # 4. Metrics
    print("\n" + "=" * 40)
    print(f"RESULTS for {latest_ckpt}")
    print("=" * 40)
    print(f"JSON Validity: {json_valid_count / len(test_data):.4f}")
    print(f"Accuracy:      {accuracy_score(y_true_has_pref, y_pred_has_pref):.4f}")
    print(f"Precision:     {precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
    print(f"Recall:        {recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
    print(f"F1 Score:      {f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0):.4f}")
    print("=" * 40)


if __name__ == "__main__":
    evaluate_ckpt()
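
# ---------------------------------------------------------------------------
# Optional smoke test (a minimal sketch, not part of the evaluation flow above;
# "example user message" is a placeholder input). It assumes the same
# checkpoint layout and simply runs batch_generate() on one prompt, which can
# help when debugging the chat template or padding settings before a full run:
#
#   ckpt = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")),
#                 key=lambda x: int(x.split("-")[-1]))[-1]
#   tok = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True, padding_side="left")
#   if tok.pad_token_id is None:
#       tok.pad_token_id = tok.eos_token_id
#   mdl = AutoModelForCausalLM.from_pretrained(
#       ckpt, device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True
#   ).eval()
#   print(batch_generate(mdl, tok, [{"input": "example user message"}])[0])
# ---------------------------------------------------------------------------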