diff options
Diffstat (limited to 'collaborativeagents/training/grpo_verl')
6 files changed, 1221 insertions, 0 deletions
diff --git a/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/config.yaml b/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/config.yaml new file mode 100644 index 0000000..7af183b --- /dev/null +++ b/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/config.yaml @@ -0,0 +1,626 @@ +actor_rollout_ref: + actor: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-06 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + warmup_style: null + override_optimizer_config: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + _target_: verl.workers.config.FSDPActorConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + ppo_mini_batch_size: 8 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: false + ppo_max_token_len_per_gpu: 16384 + clip_ratio: 0.2 + clip_ratio_low: 0.2 + clip_ratio_high: 0.2 + tau_pos: 1.0 + tau_neg: 1.05 + freeze_vision_tower: false + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + loss_mode: vanilla + clip_cov_ratio: 0.0002 + clip_cov_lb: 1.0 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + ppo_kl_coef: 0.1 + clip_ratio_c: 3.0 + loss_agg_mode: token-mean + loss_scale_factor: null + entropy_coeff: 0 + calculate_entropy: false + use_kl_loss: true + use_torch_compile: true + kl_loss_coef: 0.003 + kl_loss_type: low_var_kl + ppo_epochs: 1 + shuffle: false + data_loader_seed: 42 + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + use_fused_kernels: ${oc.select:actor_rollout_ref.model.use_fused_kernels,false} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + grad_clip: 1.0 + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} + ref: + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: ${actor_rollout_ref.actor.strategy} + use_torch_compile: ${oc.select:actor_rollout_ref.actor.use_torch_compile,true} + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + router_replay: + _target_: verl.workers.config.RouterReplayConfig + mode: disabled + record_file: null + replay_file: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: true + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: bfloat16 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: true + strategy: fsdp + dtype: bfloat16 + _target_: verl.workers.config.FSDPActorConfig + ulysses_sequence_parallel_size: ${oc.select:actor_rollout_ref.actor.ulysses_sequence_parallel_size,1} + entropy_from_logits_with_chunking: false + entropy_checkpointing: false + rollout: + _target_: verl.workers.config.RolloutConfig + name: vllm + mode: async + temperature: 0.9 + top_k: -1 + top_p: 0.9 + prompt_length: ${oc.select:data.max_prompt_length,512} + response_length: ${oc.select:data.max_response_length,512} + dtype: bfloat16 + gpu_memory_utilization: 0.5 + ignore_eos: false + enforce_eager: false + cudagraph_capture_sizes: null + free_cache_engine: true + tensor_model_parallel_size: 1 + data_parallel_size: 1 + expert_parallel_size: 1 + pipeline_model_parallel_size: 1 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + enable_chunked_prefill: true + enable_prefix_caching: true + logprobs_mode: processed_logprobs + load_format: dummy + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + log_prob_max_token_len_per_gpu: ${oc.select:actor_rollout_ref.actor.ppo_max_token_len_per_gpu,16384} + disable_log_stats: true + do_sample: true + 'n': 8 + over_sample_rate: 0 + multi_stage_wake_up: false + engine_kwargs: + vllm: {} + sglang: {} + val_kwargs: + _target_: verl.workers.config.SamplingConfig + top_k: -1 + top_p: 1.0 + temperature: 0 + 'n': 1 + do_sample: false + multi_turn: + _target_: verl.workers.config.MultiTurnConfig + enable: false + max_assistant_turns: null + tool_config_path: null + max_user_turns: null + max_parallel_calls: 1 + max_tool_response_length: 256 + tool_response_truncate_side: middle + interaction_config_path: null + use_inference_chat_template: false + tokenization_sanity_check_mode: strict + format: hermes + num_repeat_rollouts: null + calculate_log_probs: false + agent: + _target_: verl.workers.config.AgentLoopConfig + num_workers: 8 + default_agent_loop: single_turn_agent + agent_loop_config_path: null + custom_async_server: + _target_: verl.workers.config.CustomAsyncServerConfig + path: null + name: null + update_weights_bucket_megabytes: 512 + trace: + _target_: verl.workers.config.TraceConfig + backend: null + token2text: false + max_samples_per_step_per_worker: null + skip_rollout: false + skip_dump_dir: /tmp/rollout_dump + skip_tokenizer_init: true + enable_rollout_routing_replay: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: ${oc.select:actor_rollout_ref.actor.profiler.enable,false} + all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false} + ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]} + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + prometheus: + _target_: verl.workers.config.PrometheusConfig + enable: false + port: 9090 + file: /tmp/ray/session_latest/metrics/prometheus/prometheus.yml + served_model_name: ${oc.select:actor_rollout_ref.model.path,null} + quantization: null + quantization_config_file: null + layered_summon: false + model: + _target_: verl.workers.config.HFModelConfig + path: /work/nvme/bfqt/yurenh2/sft_checkpoints/checkpoint-200 + hf_config_path: null + tokenizer_path: null + use_shm: false + trust_remote_code: false + custom_chat_template: null + external_lib: null + override_config: {} + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: true + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + exclude_modules: null + lora_adapter_path: null + use_liger: false + use_fused_kernels: false + fused_kernel_options: + impl_backend: torch + tiled_mlp: + enabled: false + num_shards: 4 + hybrid_engine: true + nccl_timeout: 600 +data: + tokenizer: null + use_shm: false + train_files: /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet + val_files: /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet + train_max_samples: -1 + val_max_samples: -1 + prompt_key: prompt + reward_fn_key: data_source + max_prompt_length: 2048 + max_response_length: 1024 + train_batch_size: 64 + val_batch_size: null + tool_config_path: ${oc.select:actor_rollout_ref.rollout.multi_turn.tool_config_path, + null} + return_raw_input_ids: false + return_raw_chat: true + return_full_prompt: false + shuffle: true + seed: null + dataloader_num_workers: 8 + image_patch_size: 14 + validation_shuffle: false + filter_overlong_prompts: true + filter_overlong_prompts_workers: 1 + truncation: error + image_key: images + video_key: videos + trust_remote_code: false + custom_cls: + path: null + name: null + return_multi_modal_inputs: true + sampler: + class_path: null + class_name: null + datagen: + path: null + name: null + apply_chat_template_kwargs: {} +reward_manager: + _target_: verl.trainer.config.config.RewardManagerConfig + source: register + name: ${oc.select:reward_model.reward_manager,naive} + module: + _target_: verl.trainer.config.config.ModuleConfig + path: null + name: custom_reward_manager +critic: + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + optimizer: AdamW + optimizer_impl: torch.optim + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + total_training_steps: -1 + weight_decay: 0.01 + lr_warmup_steps: -1 + betas: + - 0.9 + - 0.999 + clip_grad: 1.0 + min_lr_ratio: 0.0 + num_cycles: 0.5 + lr_scheduler_type: constant + warmup_style: null + override_optimizer_config: null + model: + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + optimizer_offload: false + offload_policy: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + model_dtype: fp32 + use_orig_params: false + seed: 42 + full_determinism: false + ulysses_sequence_parallel_size: 1 + entropy_from_logits_with_chunking: false + use_torch_compile: true + entropy_checkpointing: false + forward_only: false + strategy: fsdp + dtype: bfloat16 + path: ~/models/deepseek-llm-7b-chat + tokenizer_path: ${oc.select:actor_rollout_ref.model.path,"~/models/deepseek-llm-7b-chat"} + override_config: {} + external_lib: ${oc.select:actor_rollout_ref.model.external_lib,null} + trust_remote_code: ${oc.select:actor_rollout_ref.model.trust_remote_code,false} + _target_: verl.workers.config.FSDPCriticModelCfg + use_shm: false + enable_gradient_checkpointing: true + enable_activation_offload: false + use_remove_padding: false + lora_rank: 0 + lora_alpha: 16 + target_modules: all-linear + tiled_mlp: + enabled: false + num_shards: 4 + _target_: verl.workers.config.FSDPCriticConfig + rollout_n: ${oc.select:actor_rollout_ref.rollout.n,1} + strategy: fsdp + enable: null + ppo_mini_batch_size: ${oc.select:actor_rollout_ref.actor.ppo_mini_batch_size,256} + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size,null} + use_dynamic_bsz: ${oc.select:actor_rollout_ref.actor.use_dynamic_bsz,false} + ppo_max_token_len_per_gpu: 32768 + forward_max_token_len_per_gpu: ${.ppo_max_token_len_per_gpu} + ppo_epochs: ${oc.select:actor_rollout_ref.actor.ppo_epochs,1} + shuffle: ${oc.select:actor_rollout_ref.actor.shuffle,false} + data_loader_seed: 42 + cliprange_value: 0.5 + loss_agg_mode: ${oc.select:actor_rollout_ref.actor.loss_agg_mode,token-mean} + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + save_contents: + - model + - optimizer + - extra + load_contents: ${.save_contents} + async_save: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete} + npu: + _target_: verl.utils.profiler.config.NPUToolConfig + contents: [] + level: level0 + analysis: true + discrete: false + torch: + _target_: verl.utils.profiler.config.TorchProfilerToolConfig + step_start: 0 + step_end: null + torch_memory: + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + trace_alloc_max_entries: ${oc.select:global_profiler.global_tool_config.torch_memory.trace_alloc_max_entries,100000} + stack_depth: ${oc.select:global_profiler.global_tool_config.torch_memory.stack_depth,32} + forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null} + forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null} + ulysses_sequence_parallel_size: 1 + grad_clip: 1.0 +reward_model: + enable: false + enable_resource_pool: false + n_gpus_per_node: 8 + nnodes: 0 + strategy: fsdp + model: + input_tokenizer: ${actor_rollout_ref.model.path} + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + external_lib: ${actor_rollout_ref.model.external_lib} + trust_remote_code: false + override_config: {} + use_shm: false + use_remove_padding: false + use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels} + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + wrap_policy: + min_num_params: 0 + param_offload: false + reshard_after_forward: true + fsdp_size: -1 + forward_prefetch: false + micro_batch_size: null + micro_batch_size_per_gpu: null + max_length: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} + forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu} + reward_manager: naive + reward_loop_source: register + reward_loop_module_path: null + reward_loop_class_name: null + launch_reward_fn_async: false + sandbox_fusion: + url: null + max_concurrent: 64 + memory_limit_mb: 1024 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: ${oc.select:global_profiler.tool,null} + enable: false + all_ranks: false + ranks: [] + save_path: ${oc.select:global_profiler.save_path,null} + tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null} + ulysses_sequence_parallel_size: 1 + use_reward_loop: true + num_workers: 1 + rollout: + _target_: verl.workers.config.RolloutConfig + name: ??? + dtype: bfloat16 + gpu_memory_utilization: 0.5 + enforce_eager: true + cudagraph_capture_sizes: null + free_cache_engine: true + data_parallel_size: 1 + expert_parallel_size: 1 + tensor_model_parallel_size: 2 + max_num_batched_tokens: 8192 + max_model_len: null + max_num_seqs: 1024 + load_format: auto + engine_kwargs: {} + limit_images: null + enable_chunked_prefill: true + enable_prefix_caching: true + disable_log_stats: true + skip_tokenizer_init: false + prompt_length: 2048 + response_length: 2048 +algorithm: + rollout_correction: + rollout_is: null + rollout_is_threshold: 2.0 + rollout_rs: null + rollout_rs_threshold: null + rollout_rs_threshold_lower: null + rollout_token_veto_threshold: null + bypass_mode: false + loss_type: ppo_clip + rollout_is_batch_normalize: false + _target_: verl.trainer.config.AlgoConfig + gamma: 1.0 + lam: 1.0 + adv_estimator: grpo + norm_adv_by_std_in_grpo: true + use_kl_in_reward: false + kl_penalty: kl + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + type: fixed + kl_coef: 0.001 + horizon: 10000 + target_kl: 0.1 + use_pf_ppo: false + pf_ppo: + reweight_method: pow + weight_pow: 2.0 +custom_reward_function: + path: /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/verl_reward_functions.py + name: compute_score +trainer: + balance_batch: true + total_epochs: 1 + total_training_steps: null + project_name: collaborative-agent-reflection-grpo + experiment_name: llama3.1-8b-grpo + logger: + - console + log_val_generations: 0 + rollout_data_dir: null + validation_data_dir: null + nnodes: 1 + n_gpus_per_node: 2 + save_freq: 50 + esi_redundant_time: 0 + resume_mode: auto + resume_from_path: null + val_before_train: false + val_only: false + test_freq: 100 + critic_warmup: 0 + default_hdfs_dir: null + del_local_ckpt_after_load: false + default_local_dir: /scratch/bfqt/yurenh2/grpo_outputs + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + ray_wait_register_center_timeout: 300 + device: cuda + use_legacy_worker_impl: auto +global_profiler: + _target_: verl.utils.profiler.ProfilerConfig + tool: null + steps: null + profile_continuous_steps: false + save_path: outputs/profile + global_tool_config: + nsys: + _target_: verl.utils.profiler.config.NsightToolConfig + discrete: false + controller_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + worker_nsight_options: + trace: cuda,nvtx,cublas,ucx + cuda-memory-usage: 'true' + cuda-graph-trace: graph + capture-range: cudaProfilerApi + capture-range-end: null + kill: none + torch_memory: + trace_alloc_max_entries: 100000 + stack_depth: 32 + context: all + stacks: all + kw_args: {} +transfer_queue: + enable: false +ray_kwargs: + ray_init: + num_cpus: null + timeline_json_file: null diff --git a/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/hydra.yaml b/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/hydra.yaml new file mode 100644 index 0000000..8e4c4ec --- /dev/null +++ b/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/hydra.yaml @@ -0,0 +1,214 @@ +hydra: + run: + dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} + launcher: + _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher + sweeper: + _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper + max_batch_size: null + params: null + help: + app_name: ${hydra.job.name} + header: '${hydra.help.app_name} is powered by Hydra. + + ' + footer: 'Powered by Hydra (https://hydra.cc) + + Use --hydra-help to view Hydra specific help + + ' + template: '${hydra.help.header} + + == Configuration groups == + + Compose your configuration from those groups (group=option) + + + $APP_CONFIG_GROUPS + + + == Config == + + Override anything in the config (foo.bar=value) + + + $CONFIG + + + ${hydra.help.footer} + + ' + hydra_help: + template: 'Hydra (${hydra.runtime.version}) + + See https://hydra.cc for more info. + + + == Flags == + + $FLAGS_HELP + + + == Configuration groups == + + Compose your configuration from those groups (For example, append hydra/job_logging=disabled + to command line) + + + $HYDRA_CONFIG_GROUPS + + + Use ''--cfg hydra'' to Show the Hydra config. + + ' + hydra_help: ??? + hydra_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][HYDRA] %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + root: + level: INFO + handlers: + - console + loggers: + logging_example: + level: DEBUG + disable_existing_loggers: false + job_logging: + version: 1 + formatters: + simple: + format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' + handlers: + console: + class: logging.StreamHandler + formatter: simple + stream: ext://sys.stdout + file: + class: logging.FileHandler + formatter: simple + filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log + root: + level: INFO + handlers: + - console + - file + disable_existing_loggers: false + env: {} + mode: RUN + searchpath: [] + callbacks: {} + output_subdir: .hydra + overrides: + hydra: + - hydra.mode=RUN + task: + - algorithm.adv_estimator=grpo + - data.train_files=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet + - data.val_files=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet + - data.train_batch_size=64 + - data.max_prompt_length=2048 + - data.max_response_length=1024 + - data.filter_overlong_prompts=True + - data.truncation=error + - data.prompt_key=prompt + - data.reward_fn_key=data_source + - actor_rollout_ref.model.path=/work/nvme/bfqt/yurenh2/sft_checkpoints/checkpoint-200 + - actor_rollout_ref.actor.optim.lr=1e-6 + - actor_rollout_ref.model.use_remove_padding=True + - actor_rollout_ref.actor.ppo_mini_batch_size=8 + - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 + - actor_rollout_ref.actor.use_kl_loss=True + - actor_rollout_ref.actor.kl_loss_coef=0.003 + - actor_rollout_ref.actor.kl_loss_type=low_var_kl + - actor_rollout_ref.actor.entropy_coeff=0 + - actor_rollout_ref.model.enable_gradient_checkpointing=True + - actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 + - actor_rollout_ref.actor.fsdp_config.param_offload=False + - actor_rollout_ref.actor.fsdp_config.optimizer_offload=False + - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 + - actor_rollout_ref.rollout.tensor_model_parallel_size=1 + - actor_rollout_ref.rollout.name=vllm + - actor_rollout_ref.rollout.gpu_memory_utilization=0.5 + - actor_rollout_ref.rollout.n=8 + - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 + - actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 + - actor_rollout_ref.ref.fsdp_config.param_offload=True + - actor_rollout_ref.rollout.temperature=0.9 + - actor_rollout_ref.rollout.top_p=0.9 + - custom_reward_function.path=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/verl_reward_functions.py + - custom_reward_function.name=compute_score + - algorithm.use_kl_in_reward=False + - trainer.critic_warmup=0 + - trainer.val_before_train=False + - trainer.logger=["console"] + - trainer.project_name=collaborative-agent-reflection-grpo + - trainer.experiment_name=llama3.1-8b-grpo + - trainer.n_gpus_per_node=2 + - trainer.nnodes=1 + - trainer.save_freq=50 + - trainer.test_freq=100 + - trainer.total_epochs=1 + - trainer.default_local_dir=/scratch/bfqt/yurenh2/grpo_outputs + job: + name: main_ppo + chdir: null + override_dirname: actor_rollout_ref.actor.entropy_coeff=0,actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16,actor_rollout_ref.actor.fsdp_config.optimizer_offload=False,actor_rollout_ref.actor.fsdp_config.param_offload=False,actor_rollout_ref.actor.kl_loss_coef=0.003,actor_rollout_ref.actor.kl_loss_type=low_var_kl,actor_rollout_ref.actor.optim.lr=1e-6,actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4,actor_rollout_ref.actor.ppo_mini_batch_size=8,actor_rollout_ref.actor.use_kl_loss=True,actor_rollout_ref.model.enable_gradient_checkpointing=True,actor_rollout_ref.model.path=/work/nvme/bfqt/yurenh2/sft_checkpoints/checkpoint-200,actor_rollout_ref.model.use_remove_padding=True,actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16,actor_rollout_ref.ref.fsdp_config.param_offload=True,actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4,actor_rollout_ref.rollout.gpu_memory_utilization=0.5,actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4,actor_rollout_ref.rollout.n=8,actor_rollout_ref.rollout.name=vllm,actor_rollout_ref.rollout.temperature=0.9,actor_rollout_ref.rollout.tensor_model_parallel_size=1,actor_rollout_ref.rollout.top_p=0.9,algorithm.adv_estimator=grpo,algorithm.use_kl_in_reward=False,custom_reward_function.name=compute_score,custom_reward_function.path=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/verl_reward_functions.py,data.filter_overlong_prompts=True,data.max_prompt_length=2048,data.max_response_length=1024,data.prompt_key=prompt,data.reward_fn_key=data_source,data.train_batch_size=64,data.train_files=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet,data.truncation=error,data.val_files=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet,trainer.critic_warmup=0,trainer.default_local_dir=/scratch/bfqt/yurenh2/grpo_outputs,trainer.experiment_name=llama3.1-8b-grpo,trainer.logger=["console"],trainer.n_gpus_per_node=2,trainer.nnodes=1,trainer.project_name=collaborative-agent-reflection-grpo,trainer.save_freq=50,trainer.test_freq=100,trainer.total_epochs=1,trainer.val_before_train=False + id: ??? + num: ??? + config_name: ppo_trainer + env_set: {} + env_copy: [] + config: + override_dirname: + kv_sep: '=' + item_sep: ',' + exclude_keys: [] + runtime: + version: 1.3.2 + version_base: '1.3' + cwd: /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl + config_sources: + - path: hydra.conf + schema: pkg + provider: hydra + - path: /u/yurenh2/miniforge3/envs/eval/lib/python3.11/site-packages/verl/trainer/config + schema: file + provider: main + - path: '' + schema: structured + provider: schema + output_dir: /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42 + choices: + algorithm@algorithm.rollout_correction: rollout_correction + reward_model: dp_reward_loop + critic: dp_critic + critic/../engine@critic.model.fsdp_config: fsdp + critic/../optim@critic.optim: fsdp + model@actor_rollout_ref.model: hf_model + rollout@actor_rollout_ref.rollout: rollout + ref@actor_rollout_ref.ref: dp_ref + ref/../engine@actor_rollout_ref.ref.fsdp_config: fsdp + data: legacy_data + actor@actor_rollout_ref.actor: dp_actor + actor/../engine@actor_rollout_ref.actor.fsdp_config: fsdp + actor/../optim@actor_rollout_ref.actor.optim: fsdp + hydra/env: default + hydra/callbacks: null + hydra/job_logging: default + hydra/hydra_logging: default + hydra/hydra_help: default + hydra/help: default + hydra/sweeper: basic + hydra/launcher: basic + hydra/output: default + verbose: false diff --git a/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/overrides.yaml b/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/overrides.yaml new file mode 100644 index 0000000..8c6f184 --- /dev/null +++ b/collaborativeagents/training/grpo_verl/outputs/2026-01-11/03-50-42/.hydra/overrides.yaml @@ -0,0 +1,47 @@ +- algorithm.adv_estimator=grpo +- data.train_files=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet +- data.val_files=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet +- data.train_batch_size=64 +- data.max_prompt_length=2048 +- data.max_response_length=1024 +- data.filter_overlong_prompts=True +- data.truncation=error +- data.prompt_key=prompt +- data.reward_fn_key=data_source +- actor_rollout_ref.model.path=/work/nvme/bfqt/yurenh2/sft_checkpoints/checkpoint-200 +- actor_rollout_ref.actor.optim.lr=1e-6 +- actor_rollout_ref.model.use_remove_padding=True +- actor_rollout_ref.actor.ppo_mini_batch_size=8 +- actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 +- actor_rollout_ref.actor.use_kl_loss=True +- actor_rollout_ref.actor.kl_loss_coef=0.003 +- actor_rollout_ref.actor.kl_loss_type=low_var_kl +- actor_rollout_ref.actor.entropy_coeff=0 +- actor_rollout_ref.model.enable_gradient_checkpointing=True +- actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 +- actor_rollout_ref.actor.fsdp_config.param_offload=False +- actor_rollout_ref.actor.fsdp_config.optimizer_offload=False +- actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 +- actor_rollout_ref.rollout.tensor_model_parallel_size=1 +- actor_rollout_ref.rollout.name=vllm +- actor_rollout_ref.rollout.gpu_memory_utilization=0.5 +- actor_rollout_ref.rollout.n=8 +- actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 +- actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 +- actor_rollout_ref.ref.fsdp_config.param_offload=True +- actor_rollout_ref.rollout.temperature=0.9 +- actor_rollout_ref.rollout.top_p=0.9 +- custom_reward_function.path=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/verl_reward_functions.py +- custom_reward_function.name=compute_score +- algorithm.use_kl_in_reward=False +- trainer.critic_warmup=0 +- trainer.val_before_train=False +- trainer.logger=["console"] +- trainer.project_name=collaborative-agent-reflection-grpo +- trainer.experiment_name=llama3.1-8b-grpo +- trainer.n_gpus_per_node=2 +- trainer.nnodes=1 +- trainer.save_freq=50 +- trainer.test_freq=100 +- trainer.total_epochs=1 +- trainer.default_local_dir=/scratch/bfqt/yurenh2/grpo_outputs diff --git a/collaborativeagents/training/grpo_verl/run_grpo.sbatch b/collaborativeagents/training/grpo_verl/run_grpo.sbatch new file mode 100644 index 0000000..e22b221 --- /dev/null +++ b/collaborativeagents/training/grpo_verl/run_grpo.sbatch @@ -0,0 +1,111 @@ +#!/bin/bash +#SBATCH --job-name=grpo_train +#SBATCH --account=bfqt-delta-gpu +#SBATCH --partition=gpuH200x8 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=32 +#SBATCH --gres=gpu:4 +#SBATCH --mem=256G +#SBATCH --time=12:00:00 +#SBATCH --output=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/grpo_%j.out +#SBATCH --error=/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/grpo_%j.err + +echo "=== GRPO Training with VERL ===" +date +nvidia-smi --query-gpu=index,name,memory.total --format=csv + +cd /projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl +source /u/yurenh2/miniforge3/etc/profile.d/conda.sh +conda activate eval + +export HF_HOME=/projects/bfqt/users/yurenh2/hf_cache/huggingface +export PYTHONPATH="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model:$PYTHONPATH" +export WANDB_MODE=offline + +# Paths +TRAIN_DATA="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet" +MODEL_PATH="/work/nvme/bfqt/yurenh2/sft_checkpoints/checkpoint-200" +REWARD_FN="/projects/bfqt/users/yurenh2/ml-projects/personalization-user-model/collaborativeagents/training/grpo_verl/verl_reward_functions.py" +OUTPUT_DIR="/scratch/bfqt/yurenh2/grpo_outputs" + +mkdir -p $OUTPUT_DIR + +# Start 70B judge model for reward evaluation on GPUs 0,1 +echo "Starting 70B judge model on GPUs 0,1..." +CUDA_VISIBLE_DEVICES=0,1 python -m vllm.entrypoints.openai.api_server \ + --model meta-llama/Llama-3.1-70B-Instruct \ + --port 8004 --tensor-parallel-size 2 --gpu-memory-utilization 0.85 \ + --max-model-len 4096 --dtype bfloat16 --download-dir $HF_HOME & + +# Wait for judge model +echo "Waiting for judge model..." +for i in {1..200}; do + if curl -s http://localhost:8004/health > /dev/null 2>&1; then + echo "Judge model ready after $((i*5)) seconds" + break + fi + sleep 5 +done + +echo "" +echo "Starting GRPO training..." +echo "Model: $MODEL_PATH" +echo "Data: $TRAIN_DATA" +echo "Output: $OUTPUT_DIR" + +# GRPO training with VERL +CUDA_VISIBLE_DEVICES=2,3 python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$TRAIN_DATA" \ + data.val_files="$TRAIN_DATA" \ + data.train_batch_size=64 \ + data.max_prompt_length=2048 \ + data.max_response_length=1024 \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.prompt_key=prompt \ + data.reward_fn_key=data_source \ + actor_rollout_ref.model.path=$MODEL_PATH \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=8 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.003 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.temperature=0.9 \ + actor_rollout_ref.rollout.top_p=0.9 \ + custom_reward_function.path=$REWARD_FN \ + custom_reward_function.name=compute_score \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.val_before_train=False \ + trainer.logger='["console"]' \ + trainer.project_name='collaborative-agent-reflection-grpo' \ + trainer.experiment_name='llama3.1-8b-grpo' \ + trainer.n_gpus_per_node=2 \ + trainer.nnodes=1 \ + trainer.save_freq=50 \ + trainer.test_freq=100 \ + trainer.total_epochs=1 \ + trainer.default_local_dir=$OUTPUT_DIR + +pkill -f "vllm.entrypoints" 2>/dev/null || true + +echo "" +echo "GRPO Training complete!" +date diff --git a/collaborativeagents/training/grpo_verl/run_verl_grpo.sh b/collaborativeagents/training/grpo_verl/run_verl_grpo.sh new file mode 100644 index 0000000..ede35ab --- /dev/null +++ b/collaborativeagents/training/grpo_verl/run_verl_grpo.sh @@ -0,0 +1,63 @@ +#!/bin/bash +export PYTHONPATH="/shared/storage-01/users/mehri2/verl:$PYTHONPATH" +set -x +HYDRA_FULL_ERROR=1 + +train_data="/shared/storage-01/users/mehri2/mem/collaborativeagents/training/grpo_verl/data/session_level_reflection_grpo_train.parquet" +model_path="/shared/storage-01/users/mehri2/LLaMA-Factory/saves/llama-3.1-8b-instruct/full/sft_session_level_reflection/checkpoint-628" +reward_fn_path="/shared/storage-01/users/mehri2/mem/collaborativeagents/training/grpo_verl/verl_reward_functions.py" + +max_prompt_length=2048 +max_response_length=1024 +train_batch_size=8 +n_generations=8 +# Effective batch size is 64 + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files="$train_data" \ + data.val_files="$train_data" \ + data.train_batch_size=$train_batch_size \ + data.max_prompt_length=$max_prompt_length \ + data.max_response_length=$max_response_length \ + data.filter_overlong_prompts=True \ + data.truncation='error' \ + data.prompt_key=prompt \ + data.reward_fn_key=data_source \ + actor_rollout_ref.model.path=$model_path \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=8 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.003 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=$n_generations \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.ref.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + actor_rollout_ref.rollout.temperature=0.9 \ + actor_rollout_ref.rollout.top_p=0.9 \ + custom_reward_function.path=$reward_fn_path \ + custom_reward_function.name=compute_score \ + algorithm.use_kl_in_reward=False \ + trainer.critic_warmup=0 \ + trainer.val_before_train=False \ + trainer.logger='["console","wandb"]' \ + trainer.project_name='collaborative-agent-reflection-grpo' \ + trainer.experiment_name='llama3.1-8b-verl-grpo-v3' \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=50 \ + trainer.test_freq=100 \ + trainer.total_epochs=1 \ + trainer.default_local_dir=/shared/storage-01/users/mehri2/mem/collaborativeagents/training/grpo_verl/results/v3 $@
\ No newline at end of file diff --git a/collaborativeagents/training/grpo_verl/verl_reward_functions.py b/collaborativeagents/training/grpo_verl/verl_reward_functions.py new file mode 100644 index 0000000..fa38b8f --- /dev/null +++ b/collaborativeagents/training/grpo_verl/verl_reward_functions.py @@ -0,0 +1,160 @@ +""" +Custom reward functions for VERL GRPO training +Compatible with VERL's reward function signature +""" + +import openai +from json_repair import repair_json +import concurrent.futures +import time + +# Initialize judge client +client = openai.OpenAI(base_url="http://localhost:8004/v1", api_key="EMPTY") + +# Global tracker for reflection scores +reflection_scores_tracker = { + "scores": [], + "batch_count": 0 +} + + +def extract_json_answer(text: str) -> str: + """Extract agent_notes from JSON response""" + try: + answer = repair_json(text, return_objects=True) + answer = answer["agent_notes"] + except Exception as e: + print(f"Error extracting JSON answer: {e}") + return "" + return answer + + +def ask_judge(prompt, system_prompt=None, max_retries=3): + """Ask the judge model for evaluation""" + if system_prompt: + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": prompt} + ] + else: + messages = [{"role": "user", "content": prompt}] + + for attempt in range(max_retries): + try: + chat_completion = client.chat.completions.create( + model="meta-llama/Llama-3.1-70B-Instruct", + messages=messages, + max_tokens=2048 + ) + return chat_completion.choices[0].message.content.strip() + except Exception as e: + if attempt < max_retries - 1: + time.sleep(1) + else: + raise e + + +def evaluate_reflection(completion, gold_response, responses_that_enforce_preferences): + """Evaluate a single reflection completion""" + + completion_text = extract_json_answer(completion) + if completion_text == "": + print(f"Poorly formatted completion: {completion}") + return 0 + + user_messages_where_they_enforce_preferences = "" + for i, response in enumerate(responses_that_enforce_preferences): + user_messages_where_they_enforce_preferences += f"User message #{i+1}: {response}\n" + + reflection_evaluation_prompt = f"""You are an expert evaluator analyzing a conversational agent's reflection of a conversation, where they analyze the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations. + +Throughout the conversation, the user explicitly enforces their preferences whenever necessary. The agent analyzes the conversation to identify the user's preferences and create actionable notes to help them satisfy these preferences in future conversations. + +# Your Task: +Evaluate whether the agent's reflection succesfully captures the user's preferences and provides actionable notes to help them satisfy these preferences in future conversations. + +# Agent's Reflection: +{completion_text} + +# User Messages Where They Enforce Their Preferences: +{user_messages_where_they_enforce_preferences} + +# Gold Reflection: +Here is a gold reflection for the same conversation. Use this as a reference to evaluate the agent's reflection. +{gold_response} + +# Evaluation Criteria: +Assess the reflection on four dimensions: +- **Coverage (Completeness):** Does the agent's reflection capture all of the user's preferences? +- **Actionability (Quality):** Does the agent's reflection provide actionable notes and details that help the agent satisfy these preferences in future conversations? +- **Accuracy (No Hallucination):** Are all points grounded in actual user statements? Does the reflection avoid inventing preferences or misrepresenting user statements? +- **Clarity:** Is the reflection well-organized and clearly formatted? Does the reflection avoid redundancy, with each preference stated once without repetitive or overlapping notes? + +You will output a score from 0-3, where: +- 0: Does not effectively capture user preferences: gaps in converage, or significant hallucinations +- 1: Captures some preferences with limited actionable notes, may hallucinate some preferences +- 2: Captures most preferences with actionable notes, may have some slight hallucinations +- 3: Comprehensively captures all preferences with highly actionable notes and no hallucinations + +# Output Format: +{{ + "reasoning": # Brief explanation of your decision + "reflection_score": # 0-3 +}} + +Output a properly formatted JSON response, as specified by the Output Format.""" + + reflection_score_response = ask_judge(reflection_evaluation_prompt) + reflection_score = repair_json(reflection_score_response, return_objects=True)["reflection_score"] + + print(f"Reflection Score: {reflection_score}") + return reflection_score + + +def soft_format_reward(solution_str): + """Check if the completion has JSON format with required fields""" + reward = 0.0 + try: + parsed_json = repair_json(solution_str, return_objects=True) + if "agent_notes" in parsed_json and "user_preferences_reasoning" in parsed_json: + reward = 0.5 + except Exception: + pass + + print(f"Soft Format Reward: {reward}") + return reward + + +# VERL reward function signature: (data_source, solution_str, ground_truth, extra_info) +def compute_score(data_source, solution_str, ground_truth, extra_info=None): + """ + Main reward function for VERL (named 'compute_score' for default VERL compatibility). + This matches the signature expected by VERL's reward managers. + + Args: + data_source: Source identifier for the data + solution_str: The model's generated completion + ground_truth: Not used directly, passed in extra_info + extra_info: Dictionary containing 'gold_response' and 'responses_that_enforce_preferences' + + Returns: + float: Combined reward score + """ + if extra_info is None: + print("Warning: extra_info is None") + return 0.0 + + gold_response = extra_info.get('gold_response', '') + responses_that_enforce_preferences = extra_info.get('responses_that_enforce_preferences', []) + + # Soft format reward + format_reward = soft_format_reward(solution_str) + # Reflection quality reward + reflection_score = evaluate_reflection(solution_str, gold_response, responses_that_enforce_preferences) + + total_reward = format_reward + reflection_score + + reflection_scores_tracker["scores"].append(reflection_score) + print(f"Total Reward: {total_reward} (Format: {format_reward}, Reflection: {reflection_score})") + + return total_reward
\ No newline at end of file |
