Diffstat (limited to 'scripts/run_all_methods.py')
-rw-r--r--  scripts/run_all_methods.py  34
1 file changed, 22 insertions(+), 12 deletions(-)
diff --git a/scripts/run_all_methods.py b/scripts/run_all_methods.py
index 66b9cb7..b240645 100644
--- a/scripts/run_all_methods.py
+++ b/scripts/run_all_methods.py
@@ -84,10 +84,11 @@ def generate_greedy(wrapper, prompt, max_new_tokens=512, min_new_tokens=128):
 class MethodRunner:
     """Encapsulates running a single method across all examples."""

-    def __init__(self, wrapper, device, dense_retriever=None):
+    def __init__(self, wrapper, device, dense_retriever=None, uph_d=64):
         self.wrapper = wrapper
         self.device = device
         self.dense_retriever = dense_retriever
+        self.uph_d = uph_d

     def run(self, method_name, examples, support_sets, references, support_texts, N):
         dispatch = {
@@ -219,8 +220,10 @@ class MethodRunner:
         return per_user

     def _run_uph(self, examples, support_sets, references, support_texts, N):
+        d = self.uph_d
         H = self.wrapper.hidden_size
-        uncond = UnconditionalHead(H, d=64, alpha=0.1, basis_seed=42).to(self.device)
+        uncond = UnconditionalHead(H, d=d, alpha=0.1, basis_seed=42).to(self.device)
+        print(f"  UPH d={d}, params={d}, bytes={d*2}")
         lm_head_bias = None
         if hasattr(self.wrapper.model.lm_head, 'bias') and self.wrapper.model.lm_head.bias is not None:
             lm_head_bias = self.wrapper.model.lm_head.bias.data
@@ -238,7 +241,7 @@ class MethodRunner:
                 lm_head_weight=self.wrapper.lm_head_weight,
                 lm_head_bias=lm_head_bias,
                 head_module=uncond,
-                d=64, lr=0.05, steps=30, beta=0.05, lam=1e-4,
+                d=d, lr=0.05, steps=30, beta=0.05, lam=1e-4,
                 max_grad_norm=5.0, device=self.device,
             )
             prompt = build_query_prompt(ex['query_input'], ex['task'])
@@ -319,6 +322,7 @@ def main():
                         help='Comma-separated methods or "all"')
     parser.add_argument('--device', type=str, default='cuda:0')
     parser.add_argument('--K', type=int, default=4)
+    parser.add_argument('--d', type=int, default=64, help='UPH theta dimension')
     parser.add_argument('--output_dir', type=str, default='outputs/unified')

     args = parser.parse_args()
@@ -353,7 +357,7 @@ def main():
print(f"Loading model on {args.device}...")
wrapper = QwenWrapper('Qwen/Qwen2.5-1.5B-Instruct', device=args.device)
- runner = MethodRunner(wrapper, args.device)
+ runner = MethodRunner(wrapper, args.device, uph_d=args.d)
all_per_user = {}
for method in methods:
@@ -393,11 +397,14 @@ def main():
     # Save per-method data in separate directories
     # Structure: output_dir/task_setting_K{K}/{method}/per_user.json
+    # For d ablation: uph method saved as uph_d{d}
     exp_dir = os.path.join(args.output_dir, f"{task}_{setting}_K{K}")
     os.makedirs(exp_dir, exist_ok=True)

     for method in methods:
-        method_dir = os.path.join(exp_dir, method)
+        # Use uph_d{d} as directory name for d ablation
+        method_label = f"uph_d{args.d}" if method == 'uph' and args.d != 64 else method
+        method_dir = os.path.join(exp_dir, method_label)
         os.makedirs(method_dir, exist_ok=True)

         pu = all_per_user[method]
@@ -408,14 +415,17 @@ def main():
             'avg_len': float(np.mean([u['metrics']['length'] for u in pu])),
         }

+        save_meta = {
+            'per_user': pu,
+            'aggregate': agg_m,
+            'num_examples': N, 'task': task, 'setting': setting, 'K': K,
+            'method': method_label,
+            'decode_policy': 'greedy, min=128, max=512',
+        }
+        if method == 'uph':
+            save_meta['d'] = args.d
         with open(os.path.join(method_dir, 'per_user.json'), 'w') as f:
-            json.dump({
-                'per_user': pu,
-                'aggregate': agg_m,
-                'num_examples': N, 'task': task, 'setting': setting, 'K': K,
-                'method': method,
-                'decode_policy': 'greedy, min=128, max=512',
-            }, f, indent=2, default=str)
+            json.dump(save_meta, f, indent=2, default=str)

         print(f"  Saved: {method_dir}/per_user.json")
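For reference, a minimal sketch (not part of this commit) of how the files written above could be read back to compare a d ablation. The uph_d{d} directory labeling and the per_user.json keys ('aggregate', 'd') follow the diff; the experiment directory name and the d grid below are placeholder assumptions.

# Sketch: collect per-method aggregates across a hypothetical d ablation.
import json
import os

exp_dir = "outputs/unified/task_setting_K4"        # placeholder task/setting/K
for d in (16, 32, 64, 128):                        # hypothetical ablation grid
    label = "uph" if d == 64 else f"uph_d{d}"      # matches the naming in the diff
    path = os.path.join(exp_dir, label, "per_user.json")
    if not os.path.exists(path):
        continue
    with open(path) as f:
        meta = json.load(f)
    print(d, meta.get("d"), meta["aggregate"])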