Diffstat (limited to 'scripts/evaluate_checkpoints.py')
-rw-r--r--  scripts/evaluate_checkpoints.py  205
1 file changed, 205 insertions, 0 deletions
diff --git a/scripts/evaluate_checkpoints.py b/scripts/evaluate_checkpoints.py
new file mode 100644
index 0000000..cb5c993
--- /dev/null
+++ b/scripts/evaluate_checkpoints.py
@@ -0,0 +1,205 @@
+import json
+import os
+import glob
+import torch
+import matplotlib.pyplot as plt
+import pandas as pd
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
+from torch.utils.data import Dataset, DataLoader
+
+# --- Configuration ---
+BASE_MODEL_NAME = "Qwen/Qwen3-0.6B" # Or local path models/Qwen3-0.6B
+CHECKPOINT_DIR = "saves/qwen3-0.6b-full-sft-h200"
+TEST_FILE = "data/test_llama_factory.json"
+RESULTS_FILE = "evaluation_results.csv"
+PLOT_FILE = "evaluation_plot.png"
+
+# H200 Optimization
+BATCH_SIZE = 128  # An H200 can handle large batches for a 0.6B model
+USE_FLASH_ATTN = False
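+# USE_FLASH_ATTN requires the flash-attn package; when True it sets attn_implementation="flash_attention_2"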
+
+# Load System Prompt
+with open("fine_tuning_prompt_template.txt", "r", encoding="utf-8") as f:
+ SYSTEM_PROMPT = f.read()
+
+class EvalDataset(Dataset):
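+    """Thin Dataset wrapper so a DataLoader can batch the raw list of test samples."""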
+ def __init__(self, data):
+ self.data = data
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+def load_test_data():
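+    """Load the test set: a list of {"input": ..., "output": ...} records."""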
+ with open(TEST_FILE, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+def batch_generate(model, tokenizer, batch_data, device="cuda"):
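+    """Build chat prompts for a batch, run greedy decoding, and return the decoded completions."""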
+ prompts = []
+ for item in batch_data:
+ messages = [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": item["input"]}
+ ]
+ text = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ prompts.append(text)
+
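+    # Left-pad the batch so every prompt ends where generation begins (required for decoder-only models)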
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left").to(device)
+
+ with torch.no_grad():
+ generated_ids = model.generate(
+ **inputs,
+ max_new_tokens=256,
+ do_sample=False,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ )
+
+ # Slice only generated tokens
+ input_len = inputs.input_ids.shape[1]
+ gen_tokens = generated_ids[:, input_len:]
+ responses = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
+ return responses
+
+def evaluate_single_model(model_path, test_data, device="cuda"):
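+    """Load a model and tokenizer, run batched generation over the test set, and return a metrics dict (None if loading fails)."""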
+ print(f"Loading model: {model_path}...")
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="left")
+ # Ensure pad token is set for batch generation
+ if tokenizer.pad_token_id is None:
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+
+ kwargs = {"device_map": device, "torch_dtype": torch.bfloat16, "trust_remote_code": True}
+ if USE_FLASH_ATTN:
+ kwargs["attn_implementation"] = "flash_attention_2"
+
+ model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
+ except Exception as e:
+ print(f"Failed to load {model_path}: {e}")
+ return None
+
+ model.eval()
+
+ dataset = EvalDataset(test_data)
+ dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4) # Use workers for data loading
+
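+    # The task is scored as binary classification: does the output JSON contain a non-empty "preferences" list?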
+ y_true_has_pref = []
+ y_pred_has_pref = []
+ json_valid_count = 0
+
+ print(f"Evaluating on {len(test_data)} samples (Batch Size: {BATCH_SIZE})...")
+
+ for batch in tqdm(dataloader):
+        # The default collate_fn turns the list of dicts into a dict of lists:
+        # {"input": [...], "output": [...]}
+ inputs = batch["input"]
+ outputs = batch["output"]
+
+ # Ground Truth
+ for gt_str in outputs:
+ try:
+ gt_json = json.loads(gt_str)
+ gt_has = len(gt_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, TypeError, AttributeError):
+                gt_has = False
+ y_true_has_pref.append(gt_has)
+
+ # Prediction
+        # batch_generate expects a list of dicts with an "input" key, so rebuild that structure here
+ batch_items = [{"input": inp} for inp in inputs]
+ responses = batch_generate(model, tokenizer, batch_items, device)
+
+ for pred_str in responses:
+ pred_has = False
+ try:
+ pred_json = json.loads(pred_str)
+ json_valid_count += 1
+ pred_has = len(pred_json.get("preferences", [])) > 0
+            except (json.JSONDecodeError, TypeError, AttributeError):
+                pass
+ y_pred_has_pref.append(pred_has)
+
+ # Metrics
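+    # Scores use the binary "has preferences" label; outputs that fail to parse as JSON count as negative predictions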
+ metrics = {
+ "json_validity": json_valid_count / len(test_data),
+ "accuracy": accuracy_score(y_true_has_pref, y_pred_has_pref),
+ "precision": precision_score(y_true_has_pref, y_pred_has_pref, zero_division=0),
+ "recall": recall_score(y_true_has_pref, y_pred_has_pref, zero_division=0),
+ "f1": f1_score(y_true_has_pref, y_pred_has_pref, zero_division=0)
+ }
+
+ del model
+ del tokenizer
+ torch.cuda.empty_cache()
+
+ return metrics
+
+def main():
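+    """Evaluate the base model and the last checkpoint, save metrics to CSV, and plot them."""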
+ test_data = load_test_data()
+ results = []
+
+ # 1. Evaluate Base Model
+ print("\n--- Evaluating Base Model ---")
+ base_metrics = evaluate_single_model(BASE_MODEL_NAME, test_data)
+ if base_metrics:
+ base_metrics["step"] = 0
+ base_metrics["model"] = "Base"
+ results.append(base_metrics)
+ print(f"Base: {base_metrics}")
+
+ # 2. Evaluate Checkpoints
+ checkpoints = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, "checkpoint-*")), key=lambda x: int(x.split("-")[-1]))
+ print(f"\nFound {len(checkpoints)} checkpoints.")
+
+    # As requested, compare only the base model against the last checkpoint
+ if checkpoints:
+ checkpoints = [checkpoints[-1]]
+ print(f"Selecting only the last checkpoint: {checkpoints[0]}")
+
+ for ckpt in checkpoints:
+ step = int(ckpt.split("-")[-1])
+ print(f"\n--- Evaluating Checkpoint {step} ---")
+ metrics = evaluate_single_model(ckpt, test_data)
+ if metrics:
+ metrics["step"] = step
+ metrics["model"] = f"Ckpt-{step}"
+ results.append(metrics)
+ print(f"Step {step}: {metrics}")
+
+ # 3. Save & Plot
+ if not results:
+ print("No results generated.")
+ return
+
+ df = pd.DataFrame(results)
+ df = df.sort_values("step")
+ df.to_csv(RESULTS_FILE, index=False)
+ print(f"\nResults saved to {RESULTS_FILE}")
+ print(df)
+
+ plt.figure(figsize=(10, 6))
+ plt.plot(df["step"], df["f1"], marker='o', label="F1 Score")
+ plt.plot(df["step"], df["precision"], marker='s', label="Precision")
+ plt.plot(df["step"], df["recall"], marker='^', label="Recall")
+ plt.plot(df["step"], df["json_validity"], marker='x', linestyle='--', label="JSON Validity")
+
+ plt.title("Preference Extractor Training Progress")
+ plt.xlabel("Training Steps")
+ plt.ylabel("Score")
+ plt.legend()
+ plt.grid(True)
+ plt.savefig(PLOT_FILE)
+ print(f"Plot saved to {PLOT_FILE}")
+
+if __name__ == "__main__":
+ main()