author     YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
committer  YurenHao0426 <blackhao0426@gmail.com>  2025-12-17 04:29:37 -0600
commit     e43b3f8aa36c198b95c1e46bea2eaf3893b13dc3 (patch)
tree       6ce8a00d2f8b9ebd83c894a27ea01ac50cfb2ff5  /scripts/analyze_learning_trend.py
Initial commit (clean history)  (HEAD -> main)
Diffstat (limited to 'scripts/analyze_learning_trend.py')
-rw-r--r--  scripts/analyze_learning_trend.py  521
1 file changed, 521 insertions(+), 0 deletions(-)
diff --git a/scripts/analyze_learning_trend.py b/scripts/analyze_learning_trend.py
new file mode 100644
index 0000000..9ab4699
--- /dev/null
+++ b/scripts/analyze_learning_trend.py
@@ -0,0 +1,521 @@
+#!/usr/bin/env python3
+"""
+Analyze Learning Trend: Correlation and z_u Norm over Sessions
+
+This script shows that:
+1. User vector norms (||z_u||) grow over sessions (learning is happening)
+2. Correlation between learned and ground-truth similarity increases over sessions
+
+Usage:
+ python scripts/analyze_learning_trend.py \
+ --logs data/logs/pilot_v4_full-greedy_*.jsonl
+"""
+
+import argparse
+import json
+import os
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from scipy.stats import spearmanr
+
+
+# =============================================================================
+# Persona Definitions (ground truth)
+# =============================================================================
+
+@dataclass
+class StylePrefs:
+ require_short: bool = False
+ max_chars: int = 300
+ require_bullets: bool = False
+ lang: str = "en"
+
+
+PERSONAS = {
+ "user_A_short_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="en"),
+ "user_B_short_no_bullets_en": StylePrefs(require_short=True, max_chars=200, require_bullets=False, lang="en"),
+ "user_C_long_bullets_en": StylePrefs(require_short=False, max_chars=800, require_bullets=True, lang="en"),
+ "user_D_short_bullets_zh": StylePrefs(require_short=True, max_chars=200, require_bullets=True, lang="zh"),
+ "user_E_long_no_bullets_zh": StylePrefs(require_short=False, max_chars=800, require_bullets=False, lang="zh"),
+ "user_F_extreme_short_en": StylePrefs(require_short=True, max_chars=100, require_bullets=True, lang="en"),
+}
+
+
+# =============================================================================
+# Data Loading
+# =============================================================================
+
+def load_logs(filepath: str) -> List[dict]:
+ """Load turn logs from JSONL file."""
+ logs = []
+ with open(filepath, "r") as f:
+ for line in f:
+ if line.strip():
+ logs.append(json.loads(line))
+ return logs
+
+
+def extract_z_norms_by_session(logs: List[dict]) -> Dict[str, Dict[int, Tuple[float, float]]]:
+ """
+ Extract z_long_norm and z_short_norm at the end of each session for each user.
+
+ Returns:
+ {user_id: {session_id: (z_long_norm, z_short_norm)}}
+ """
+ user_session_norms = defaultdict(dict)
+
+ # Group by user and session, take the last turn's z_norm
+ user_session_turns = defaultdict(lambda: defaultdict(list))
+ for log in logs:
+ user_id = log["user_id"]
+ session_id = log["session_id"]
+ user_session_turns[user_id][session_id].append(log)
+
+ for user_id, sessions in user_session_turns.items():
+ for session_id, turns in sessions.items():
+ # Get the last turn of this session
+ last_turn = max(turns, key=lambda x: x["turn_id"])
+ z_long = last_turn.get("z_long_norm_after", 0.0)
+ z_short = last_turn.get("z_short_norm_after", 0.0)
+ user_session_norms[user_id][session_id] = (z_long, z_short)
+
+ return dict(user_session_norms)
+
+
+# =============================================================================
+# Similarity Computation
+# =============================================================================
+
+def cosine_similarity(v1: np.ndarray, v2: np.ndarray) -> float:
+ """Compute cosine similarity."""
+ norm1 = np.linalg.norm(v1)
+ norm2 = np.linalg.norm(v2)
+ if norm1 < 1e-10 or norm2 < 1e-10:
+ return 0.0
+ return float(np.dot(v1, v2) / (norm1 * norm2))
+
+
+def compute_ground_truth_similarity_matrix(user_order: List[str]) -> np.ndarray:
+ """Compute ground truth similarity based on preference overlap."""
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
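+    # Similarity = fraction of matching categorical prefs (require_short, require_bullets, lang);
+    # max_chars itself is not compared. E.g. user_A vs user_B match on short and lang -> 2/3.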
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ if u1 not in PERSONAS or u2 not in PERSONAS:
+ sim_matrix[i, j] = 0.0 if i != j else 1.0
+ continue
+
+ p1 = PERSONAS[u1]
+ p2 = PERSONAS[u2]
+
+ matches = 0
+ if p1.require_short == p2.require_short:
+ matches += 1
+ if p1.require_bullets == p2.require_bullets:
+ matches += 1
+ if p1.lang == p2.lang:
+ matches += 1
+
+ sim_matrix[i, j] = matches / 3.0
+
+ return sim_matrix
+
+
+def compute_spearman_correlation(learned: np.ndarray, ground_truth: np.ndarray) -> float:
+ """Compute Spearman correlation between similarity matrices."""
+ from scipy.stats import spearmanr
+
+ n = learned.shape[0]
+ learned_flat = []
+ gt_flat = []
+
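+    # Rank-correlate only the upper-triangle (i < j) pairs; the diagonal carries no information.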
+ for i in range(n):
+ for j in range(i + 1, n):
+ learned_flat.append(learned[i, j])
+ gt_flat.append(ground_truth[i, j])
+
+ if len(learned_flat) < 2:
+ return 0.0
+
+ # Handle case where all values are the same
+ if np.std(learned_flat) < 1e-10:
+ return 0.0
+
+ corr, _ = spearmanr(learned_flat, gt_flat)
+ return float(corr) if not np.isnan(corr) else 0.0
+
+
+def load_final_z_vectors(user_store_path: str) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
+ """Load final z_u vectors from saved user store."""
+ try:
+ data = np.load(user_store_path, allow_pickle=True)
+ user_vectors = {}
+
+ # UserTensorStore saves in format: {uid}_long, {uid}_short
+ user_ids = set()
+ for key in data.files:
+ if key.endswith("_long"):
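+                # Drop the trailing "_long" suffix (5 chars) to recover the user id.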
+ uid = key[:-5]
+ user_ids.add(uid)
+
+ for uid in user_ids:
+ long_key = f"{uid}_long"
+ short_key = f"{uid}_short"
+ if long_key in data.files and short_key in data.files:
+ user_vectors[uid] = (data[long_key], data[short_key])
+
+ return user_vectors
+ except Exception as e:
+ print(f"[Warning] Could not load user store: {e}")
+ return {}
+
+
+def get_z_vectors_at_session(
+ logs: List[dict],
+ user_order: List[str],
+ up_to_session: int,
+ final_z_vectors: Dict[str, Tuple[np.ndarray, np.ndarray]]
+) -> Dict[str, np.ndarray]:
+ """
+ Estimate z_u vectors at a given session checkpoint.
+
+ Method: Use the DIRECTION of the final z_u, scaled by the z_norm at session s.
+ This assumes z_u direction is relatively stable but magnitude grows.
+
+ z_u(s) ≈ (z_final / ||z_final||) * ||z(s)||
+ """
+ user_vectors = {}
+
+ for user_id in user_order:
+ # Get z_norm at the end of this session
+ user_turns = [l for l in logs if l["user_id"] == user_id and l["session_id"] <= up_to_session]
+
+ if not user_turns:
+ user_vectors[user_id] = np.zeros(512) # 256 + 256
+ continue
+
+ # Get the last turn's z_norm at this session
+ last_turn = max(user_turns, key=lambda x: (x["session_id"], x["turn_id"]))
+ z_long_norm_s = last_turn.get("z_long_norm_after", 0.0)
+ z_short_norm_s = last_turn.get("z_short_norm_after", 0.0)
+
+ # Get final z vectors (direction)
+ if user_id in final_z_vectors:
+ z_long_final, z_short_final = final_z_vectors[user_id]
+
+ # Compute unit vectors (direction)
+ z_long_final_norm = np.linalg.norm(z_long_final)
+ z_short_final_norm = np.linalg.norm(z_short_final)
+
+ if z_long_final_norm > 1e-10:
+ z_long_unit = z_long_final / z_long_final_norm
+ else:
+ z_long_unit = np.zeros_like(z_long_final)
+
+ if z_short_final_norm > 1e-10:
+ z_short_unit = z_short_final / z_short_final_norm
+ else:
+ z_short_unit = np.zeros_like(z_short_final)
+
+ # Scale by the norm at this session
+ z_long_s = z_long_unit * z_long_norm_s
+ z_short_s = z_short_unit * z_short_norm_s
+
+ # Concatenate
+ user_vectors[user_id] = np.concatenate([z_long_s, z_short_s])
+ else:
+ user_vectors[user_id] = np.zeros(512)
+
+ return user_vectors
+
+
+def compute_similarity_at_session(
+ logs: List[dict],
+ user_order: List[str],
+ up_to_session: int,
+    final_z_vectors: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]] = None
+) -> np.ndarray:
+ """Compute learned similarity matrix at a given session using actual z vectors."""
+ if final_z_vectors:
+ user_vectors = get_z_vectors_at_session(logs, user_order, up_to_session, final_z_vectors)
+ else:
+        # Fall back to approximating z_u from logged violation patterns (less accurate).
+ user_vectors = simulate_z_vectors_at_session_fallback(logs, user_order, up_to_session)
+
+ n = len(user_order)
+ sim_matrix = np.zeros((n, n))
+
+ for i, u1 in enumerate(user_order):
+ for j, u2 in enumerate(user_order):
+ v1 = user_vectors.get(u1, np.zeros(512))
+ v2 = user_vectors.get(u2, np.zeros(512))
+ sim_matrix[i, j] = cosine_similarity(v1, v2)
+
+ return sim_matrix
+
+
+def simulate_z_vectors_at_session_fallback(
+ logs: List[dict],
+ user_order: List[str],
+    up_to_session: int
+) -> Dict[str, np.ndarray]:
+ """Fallback: simulate z_u based on violation patterns (less accurate)."""
+ user_vectors = {}
+
+ for user_id in user_order:
+ user_turns = [l for l in logs if l["user_id"] == user_id and l["session_id"] <= up_to_session]
+
+ if not user_turns:
+            user_vectors[user_id] = np.zeros(10)  # keep the same 10-dim shape as the feature vectors below
+ continue
+
+ last_turn = max(user_turns, key=lambda x: (x["session_id"], x["turn_id"]))
+ z_long_norm = last_turn.get("z_long_norm_after", 0.0)
+ z_short_norm = last_turn.get("z_short_norm_after", 0.0)
+
+ violation_counts = defaultdict(int)
+ for turn in user_turns:
+ for v in turn.get("violations", []):
+ violation_counts[v] += 1
+
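+        # Hand-crafted 10-dim feature vector: violation counts plus scaled norms; remaining slots stay zero.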
+ feature_dim = 10
+ features = np.zeros(feature_dim)
+ features[0] = violation_counts.get("too_long", 0)
+ features[1] = violation_counts.get("no_bullets", 0)
+ features[2] = violation_counts.get("has_bullets", 0)
+ features[3] = violation_counts.get("wrong_lang", 0)
+ features[4] = z_long_norm * 100
+ features[5] = z_short_norm * 100
+
+ norm = np.linalg.norm(features)
+ if norm > 1e-10:
+ features = features / norm
+
+ user_vectors[user_id] = features
+
+ return user_vectors
+
+
+
+# =============================================================================
+# Main Analysis
+# =============================================================================
+
+def analyze_learning_trend(logs_path: str, output_dir: str = "data/analysis",
+ user_store_path: str = "data/users/user_store_pilot_v4_full-greedy.npz"):
+ """Analyze correlation and z_u norm trends over sessions."""
+ os.makedirs(output_dir, exist_ok=True)
+
+ print("=" * 70)
+ print("LEARNING TREND ANALYSIS")
+ print("=" * 70)
+
+ # Load logs
+ print(f"\n[1] Loading logs from: {logs_path}")
+ logs = load_logs(logs_path)
+ print(f" Loaded {len(logs)} turns")
+
+ # Get user order
+ user_order = [u for u in PERSONAS.keys() if any(l["user_id"] == u for l in logs)]
+ print(f" Users: {user_order}")
+
+ # Get max session
+ max_session = max(l["session_id"] for l in logs)
+ print(f" Sessions: 1 to {max_session}")
+
+ # Extract z_norms by session
+ print("\n[2] Extracting z_u norms by session...")
+ z_norms_by_session = extract_z_norms_by_session(logs)
+
+ # Load final z vectors from user store
+ print(f"\n[2.5] Loading final z vectors from: {user_store_path}")
+ final_z_vectors = load_final_z_vectors(user_store_path)
+ if final_z_vectors:
+ print(f" Loaded final z vectors for {len(final_z_vectors)} users")
+ else:
+ print(" [Warning] No final z vectors found, using fallback method")
+
+ # Compute ground truth similarity (constant)
+ gt_sim = compute_ground_truth_similarity_matrix(user_order)
+
+ # Compute CUMULATIVE correlation and avg z_norm
+ # At session N, we use all data from session 1 to N
+ print("\n[3] Computing CUMULATIVE correlation trend (S1→S1-2→S1-3→...→S1-N)...")
+ sessions = list(range(1, max_session + 1))
+ correlations = []
+ avg_z_norms = []
+
+ for s in sessions:
+ # Compute similarity using z_u at end of session s (cumulative learning)
+ learned_sim = compute_similarity_at_session(logs, user_order, s, final_z_vectors)
+ corr = compute_spearman_correlation(learned_sim, gt_sim)
+ correlations.append(corr)
+
+ # Compute average z_norm at the END of session s (this is already cumulative)
+ z_norms = []
+ for user_id in user_order:
+ if user_id in z_norms_by_session and s in z_norms_by_session[user_id]:
+ zl, zs = z_norms_by_session[user_id][s]
+ z_norms.append(np.sqrt(zl**2 + zs**2)) # Combined norm
+
+ avg_z = np.mean(z_norms) if z_norms else 0.0
+ avg_z_norms.append(avg_z)
+
+ # Print results
+ print("\n[4] Results:")
+ print("-" * 60)
+ print(f"{'Session':<10} {'Correlation':<15} {'Avg ||z_u||':<15}")
+ print("-" * 60)
+ for s, corr, z_norm in zip(sessions, correlations, avg_z_norms):
+ print(f"{s:<10} {corr:<15.4f} {z_norm:<15.6f}")
+
+ # Summary statistics
+ print("\n[5] Trend Summary:")
+ print("-" * 60)
+
+ # Linear regression for correlation trend
+ from scipy.stats import linregress
+ slope_corr, intercept_corr, r_corr, p_corr, _ = linregress(sessions, correlations)
+ print(f" Correlation trend: slope={slope_corr:.4f}, R²={r_corr**2:.4f}, p={p_corr:.4f}")
+
+ # Linear regression for z_norm trend
+ slope_z, intercept_z, r_z, p_z, _ = linregress(sessions, avg_z_norms)
+ print(f" ||z_u|| trend: slope={slope_z:.6f}, R²={r_z**2:.4f}, p={p_z:.4f}")
+
+ # Correlation between the two trends
+ trend_corr, _ = spearmanr(correlations, avg_z_norms) if len(correlations) > 2 else (0, 1)
+ print(f" Correlation between trends: {trend_corr:.4f}")
+
+ # Save data
+ results = {
+ "sessions": np.array(sessions),
+ "correlations": np.array(correlations),
+ "avg_z_norms": np.array(avg_z_norms),
+ "slope_corr": slope_corr,
+ "slope_z": slope_z,
+ "trend_corr": trend_corr,
+ }
+ results_path = os.path.join(output_dir, "learning_trend_results.npz")
+ np.savez(results_path, **results)
+ print(f"\n[Results] Saved to: {results_path}")
+
+ # Plot
+ print("\n[6] Generating plots...")
+ plot_learning_trend(sessions, correlations, avg_z_norms, output_dir)
+
+ print("\n" + "=" * 70)
+ print("ANALYSIS COMPLETE")
+ print("=" * 70)
+
+ return results
+
+
+def plot_learning_trend(sessions, correlations, avg_z_norms, output_dir):
+ """Generate plots for learning trend."""
+ try:
+        import matplotlib
+        matplotlib.use('Agg')  # Select the non-interactive backend before importing pyplot
+        import matplotlib.pyplot as plt
+ except ImportError:
+ print("[Warning] matplotlib not available, skipping plots")
+ # Save as text instead
+ with open(os.path.join(output_dir, "learning_trend.txt"), "w") as f:
+ f.write("Session,Correlation,Avg_Z_Norm\n")
+ for s, c, z in zip(sessions, correlations, avg_z_norms):
+ f.write(f"{s},{c:.4f},{z:.6f}\n")
+ print(f"[Data] Saved to: {os.path.join(output_dir, 'learning_trend.txt')}")
+ return
+
+ fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+
+ # Plot 1: Correlation vs Session
+ ax1 = axes[0]
+ ax1.plot(sessions, correlations, 'o-', color='#2ecc71', linewidth=2, markersize=8)
+ ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ # Add trend line
+ from scipy.stats import linregress
+ slope, intercept, _, _, _ = linregress(sessions, correlations)
+ trend_line = [slope * s + intercept for s in sessions]
+ ax1.plot(sessions, trend_line, '--', color='#27ae60', alpha=0.7, label=f'Trend (slope={slope:.3f})')
+
+ ax1.set_xlabel('Sessions (Cumulative: 1→N)', fontsize=12)
+ ax1.set_ylabel('Spearman Correlation', fontsize=12)
+ ax1.set_title('Learned vs Ground-Truth Similarity\nCorrelation with Cumulative Data', fontsize=14)
+ ax1.set_xticks(sessions)
+ ax1.legend()
+ ax1.grid(True, alpha=0.3)
+ ax1.set_ylim(-0.5, 1.0)
+
+ # Plot 2: z_u norm vs Session
+ ax2 = axes[1]
+ ax2.plot(sessions, avg_z_norms, 's-', color='#3498db', linewidth=2, markersize=8)
+
+ # Add trend line
+ slope_z, intercept_z, _, _, _ = linregress(sessions, avg_z_norms)
+ trend_line_z = [slope_z * s + intercept_z for s in sessions]
+ ax2.plot(sessions, trend_line_z, '--', color='#2980b9', alpha=0.7, label=f'Trend (slope={slope_z:.5f})')
+
+ ax2.set_xlabel('Session (End of)', fontsize=12)
+ ax2.set_ylabel('Average ||z_u||', fontsize=12)
+ ax2.set_title('User Vector Norm\n(Cumulative Learning)', fontsize=14)
+ ax2.set_xticks(sessions)
+ ax2.legend()
+ ax2.grid(True, alpha=0.3)
+
+ plt.tight_layout()
+
+ output_path = os.path.join(output_dir, "learning_trend.png")
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
+ print(f"[Plot] Saved to: {output_path}")
+
+ # Also save as PDF for paper
+ pdf_path = os.path.join(output_dir, "learning_trend.pdf")
+ plt.savefig(pdf_path, bbox_inches='tight')
+ print(f"[Plot] Saved to: {pdf_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Analyze Learning Trend")
+ parser.add_argument("--logs", type=str, required=True, help="Path to log file")
+ parser.add_argument("--user-store", type=str, default="data/users/user_store_pilot_v4_full-greedy.npz",
+ help="Path to user store with final z vectors")
+ parser.add_argument("--output-dir", type=str, default="data/analysis", help="Output directory")
+ args = parser.parse_args()
+
+ analyze_learning_trend(args.logs, args.output_dir, args.user_store)
+
+
+if __name__ == "__main__":
+ main()
+