diff options
Diffstat (limited to 'scripts/evaluate_checkpoints.py')
| -rw-r--r-- | scripts/evaluate_checkpoints.py | 205 |
1 files changed, 205 insertions, 0 deletions
diff --git a/scripts/evaluate_checkpoints.py b/scripts/evaluate_checkpoints.py
new file mode 100644
index 0000000..cb5c993
--- /dev/null
+++ b/scripts/evaluate_checkpoints.py
@@ -0,0 +1,205 @@
+import json
+import os
+import glob
+import torch
+import matplotlib.pyplot as plt
+import pandas as pd
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+from torch.utils.data import Dataset, DataLoader
+
+# --- Configuration ---
+BASE_MODEL_NAME = "Qwen/Qwen3-0.6B"  # Or local path models/Qwen3-0.6B
+CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200"
+TEST_FILE = "data/test_llama_factory.json"
+RESULTS_FILE = "evaluation_results.csv"
+PLOT_FILE = "evaluation_plot.png"
+
+# H200 Optimization
+BATCH_SIZE = 128  # H200 can handle massive batches for 0.6B model
+USE_FLASH_ATTN = False
+
+# Load System Prompt
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+    SYSTEM_PROMPT = f.read()
+
+class EvalDataset(Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+def load_test_data():
+    with open(TEST_FILE, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+def batch_generate(model, tokenizer, batch_data, device="cuda"):
+    prompts = []
+    for item in batch_data:
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": item["input"]}
+        ]
+        text = tokenizer.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        prompts.append(text)
+
+    inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device)
+
+    with torch.no_grad():
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=256,
+            do_sample=False,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+
+    # Slice only generated tokens
+    input_len = inputs.input_ids.shape[1]
+    gen_tokens = generated_ids[:, input_len:]
+    responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
+    return responses
+
+def evaluate_single_model(model_path, test_data, device="cuda"):
+    print(f"Loading model: {model_path}...")
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
+        # Ensure pad token is set for batch generation
+        if tokenizer.pad_token_id is None:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True}
+        if USE_FLASH_ATTN:
+            kwargs["attn_implementation"] = "flash_attention_2"
+
+        model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+    except Exception as e:
+        print(f"Failed to load {model_path}: {e}")
+        return None
+
+    model.eval()
+
+    dataset = EvalDataset(test_data)
+    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)  # Use workers for data loading
+
+    y_true_has_pref = []
+    y_pred_has_pref = []
+    json_valid_count = 0
+
+    print(f"Evaluating on {len(test_data)} samples (Batch Size: {BATCH_SIZE})...")
+
+    for batch in tqdm(dataloader):
+        # batch is a dict of lists because default collate
+        # we need to reconstruct list of dicts or just access lists
+        # DataLoader collates list of dicts into dict of lists: {"input": [...], "output": [...]}
+        inputs = batch["input"]
+        outputs = batch["output"]
+
+        # Ground Truth
+        for gt_str in outputs:
+            try:
+                gt_json = json.loads(gt_str)
+                gt_has = len(gt_json.get("preferences", [])) > 0
+            except:
+                gt_has = False
+            y_true_has_pref.append(gt_has)
+
+        # Prediction
+        # batch_data structure required by batch_generate needs to be list of dicts with "input" key
+        # Reconstruct for helper function
+        batch_items = [{"input": inp} for inp in inputs]
+        responses = batch_generate(model, tokenizer, batch_items, device)
+
+        for pred_str in responses:
+            pred_has = False
+            try:
+                pred_json = json.loads(pred_str)
+                json_valid_count += 1
+                pred_has = len(pred_json.get("preferences", [])) > 0
+            except:
+                pass
+            y_pred_has_pref.append(pred_has)
+
+    # Metrics
+    metrics = {
+        "json_validity": json_valid_count / len(test_data),
+        "accuracy": accuracy_score(y_true_has_pref, y_pred_has_pref),
+        "precision": precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0),
+        "recall": recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0),
+        "f1": f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0)
+    }
+
+    del model
+    del tokenizer
+    torch.cuda.empty_cache()
+
+    return metrics
+
+def main():
+    test_data = load_test_data()
+    results = []
+
+    # 1. Evaluate Base Model
+    print("\n--- Evaluating Base Model ---")
+    base_metrics = evaluate_single_model(BASE_MODEL_NAME, test_data)
+    if base_metrics:
+        base_metrics["step"] = 0
+        base_metrics["model"] = "Base"
+        results.append(base_metrics)
+        print(f"Base: {base_metrics}")
+
+    # 2. Evaluate Checkpoints
+    checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1]))
+    print(f"\nFound {len(checkpoints)} checkpoints.")
+
+    # Filter to only Base + Last Checkpoint (User Request)
+    if checkpoints:
+        checkpoints = [checkpoints[-1]]
+        print(f"Selecting only the last checkpoint: {checkpoints[0]}")
+
+    for ckpt in checkpoints:
+        step = int(ckpt.split("-")[-1])
+        print(f"\n--- Evaluating Checkpoint {step} ---")
+        metrics = evaluate_single_model(ckpt, test_data)
+        if metrics:
+            metrics["step"] = step
+            metrics["model"] = f"Ckpt-{step}"
+            results.append(metrics)
+            print(f"Step {step}: {metrics}")
+
+    # 3. Save & Plot
+    if not results:
+        print("No results generated.")
+        return
+
+    df = pd.DataFrame(results)
+    df = df.sort_values("step")
+    df.to_csv(RESULTS_FILE, index=False)
+    print(f"\nResults saved to {RESULTS_FILE}")
+    print(df)
+
+    plt.figure(figsize=(10, 6))
+    plt.plot(df["step"], df["f1"], marker='o', label="F1 Score")
+    plt.plot(df["step"], df["precision"], marker='s', label="Precision")
+    plt.plot(df["step"], df["recall"], marker='^', label="Recall")
+    plt.plot(df["step"], df["json_validity"], marker='x', linestyle='--', label="JSON Validity")
+
+    plt.title("Preference Extractor Training Progress")
+    plt.xlabel("Training Steps")
+    plt.ylabel("Score")
+    plt.legend()
+    plt.grid(True)
+    plt.savefig(PLOT_FILE)
+    print(f"Plot saved to {PLOT_FILE}")
+
+if __name__ == "__main__":
+    main()
