From fc6d57ffb8d5ddb5820fcc00b5491a585c259ebc Mon Sep 17 00:00:00 2001
From: Yuren Hao
Date: Thu, 4 Sep 2025 22:16:22 -0500
Subject: Initial commit

---
 logits_shift.py | 132 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 132 insertions(+)
 create mode 100644 logits_shift.py

(limited to 'logits_shift.py')

diff --git a/logits_shift.py b/logits_shift.py
new file mode 100644
index 0000000..28b7809
--- /dev/null
+++ b/logits_shift.py
@@ -0,0 +1,132 @@
+import os, gc, random
+import numpy as np
+import pandas as pd
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+from scipy.stats import skew
+
+PARQUET_PATH = "data/train/one_shot_rlvr/dsr_sub.parquet"
+NUM_SAMPLES = 20
+MAX_NEW_TOKENS = 500
+TEMPERATURE = 0.6
+PAD_TOKEN_ID = 151643
+
+model_paths = {
+    "A": "/volume/ailab4sci/models/Qwen2.5-Math-7B",  # base
+    "B": "/volume/ailab4sci/ztgao/em/checkpoints/qwen25math7b/t05_lr2e-05_bsz64_seed15/step_10",  # em
+    "C": "/volume/ailab4sci/ztgao/One-Shot-RLVR/checkpoints/verl_few_shot/Qwen2.5-Math-7B-origin-dsr_sub/global_step_460/actor",  # rl
+    "D": "/volume/ailab4sci/ztgao/One-Shot-RLVR/checkpoints/verl_few_shot/Qwen2.5-Math-7B-em_step10/global_step_460/actor"  # emrl
+}
+
+torch.manual_seed(42)
+
+df = pd.read_parquet(PARQUET_PATH).head(NUM_SAMPLES)
+prompts = df["prompt"].tolist()
+messages_list = [list(arr) for arr in prompts]
+logits_dict = {m: [] for m in model_paths}
+
+for model_name, model_path in model_paths.items():
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
+    )
+    model.eval()
+
+    for msg in tqdm(messages_list, desc=f"{model_name} samples"):
+        text = tokenizer.apply_chat_template(msg, tokenize=False)
+        inputs = tokenizer(text, return_tensors="pt")
+        input_ids = inputs["input_ids"].to(model.device)
+        attention_mask = torch.ones_like(input_ids).to(model.device)
+
+        with torch.no_grad():
+            output = model.generate(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                max_new_tokens=MAX_NEW_TOKENS,
+                do_sample=True,
+                temperature=TEMPERATURE,
+                return_dict_in_generate=True,
+                output_scores=True,
+                pad_token_id=PAD_TOKEN_ID,
+            )
+
+        logits_steps = [score[0].cpu().numpy() for score in output.scores]
+        logits_dict[model_name].append(logits_steps)
+
+    del model
+    gc.collect()
+    torch.cuda.empty_cache()
+
+min_len = min(
+    len(step_list) for model in logits_dict.values() for step_list in model
+)
+
+for model in logits_dict:
+    logits_dict[model] = [steps[:min_len] for steps in logits_dict[model]]
+
+def flatten_model_logits(model_logits):
+    flat = np.concatenate([np.stack(step_list).reshape(-1) for step_list in model_logits])
+    flat = np.nan_to_num(flat, nan=0.0, posinf=0.0, neginf=0.0)
+    return flat
+
+flat_A = flatten_model_logits(logits_dict["A"])
+flat_B = flatten_model_logits(logits_dict["B"])
+flat_C = flatten_model_logits(logits_dict["C"])
+flat_D = flatten_model_logits(logits_dict["D"])
+
+def stable_softmax(x):
+    x = x - np.max(x)
+    exp_x = np.exp(x)
+    return exp_x / exp_x.sum()
+
+prob_A = stable_softmax(flat_A)
+prob_B = stable_softmax(flat_B)
+prob_C = stable_softmax(flat_C)
+prob_D = stable_softmax(flat_D)
+
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 8))
+axes = axes.flatten()
+colors = ["#6BAED6", "#9ECAE1", "#C6DBEF", "#08306B"]
+
+all_logits = []
+flattened_logits_dict = {}
+
+for name, logits_list in logits_dict.items():
+    logits_all = torch.tensor(np.stack(logits_list))
+    logits_all = torch.nan_to_num(logits_all, nan=0.0, posinf=0.0, neginf=0.0)
+    flattened = logits_all[logits_all != 0].flatten().cpu().numpy()
+
+    if flattened.size == 0:
+        continue
+
+    flattened_logits_dict[name] = flattened
+    all_logits.append(flattened)
+
+combined = np.concatenate(all_logits)
+global_min = combined.min()
+global_max = combined.max()
+
+for i, (name, flat_logits) in enumerate(flattened_logits_dict.items()):
+    ax = axes[i]
+    ax.hist(flat_logits, bins=50, color=colors[i], alpha=0.8)
+    ax.set_xlim(global_min, global_max)
+    names = ["Base", "EM", "RL", "EM-RL"]
+    ax.set_title(f"{names[i]} Logits")
+    ax.set_xlabel("Logit Value")
+    ax.set_ylabel("Frequency")
+    mu = flat_logits.mean()
+    sigma = flat_logits.std()
+    skewness = skew(flat_logits)
+    stats_text = f"μ={mu:.2f}\nσ={sigma:.2f}\nskew={skewness:.2f}"
+    ax.text(0.985, 0.97, stats_text, transform=ax.transAxes, fontsize=10, va='top', ha='right',
+            bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.7))
+
+for j in range(len(flattened_logits_dict), 4):
+    axes[j].axis("off")
+
+plt.tight_layout(rect=[0, 0.03, 1, 0.95])
+plt.savefig("logits.pdf", format="pdf", bbox_inches="tight")
+plt.show()
\ No newline at end of file
-- 
cgit v1.2.3