diff --git a/train_grpo.sh b/train_grpo.sh index 6798f94..0b69431 100644 --- a/train_grpo.sh +++ b/train_grpo.sh @@ -3,8 +3,8 @@ export DATA_DIR='data/nq_search' WAND_PROJECT='Search-R1' -export BASE_MODEL='meta-llama/Llama-3.2-3B' -export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-em +# export BASE_MODEL='meta-llama/Llama-3.2-3B' +# export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-em # export BASE_MODEL='meta-llama/Llama-3.2-3B-Instruct' # export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-it-em # export BASE_MODEL='meta-llama/Llama-3.1-8B' @@ -12,8 +12,8 @@ export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-em # export BASE_MODEL='meta-llama/Llama-3.1-8B-Instruct' # export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.1-8b-it-em -# export BASE_MODEL='Qwen/Qwen2.5-3B' -# export EXPERIMENT_NAME=nq-search-r1-grpo-qwen2.5-3b-em +export BASE_MODEL='Qwen/Qwen2.5-3B' +export EXPERIMENT_NAME=nq-search-r1-grpo-qwen2.5-3b-em # export BASE_MODEL='Qwen/Qwen2.5-3B-Instruct' # export EXPERIMENT_NAME=nq-search-r1-grpo-qwen2.5-3b-it-em # export BASE_MODEL='Qwen/Qwen2.5-7B' diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 28e8b3d..ae19265 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -739,8 +739,9 @@ class RayPPOTrainer(object): output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output) final_gen_batch_output = final_gen_batch_output.union(output) - batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], - dtype=object) + # batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], + # dtype=object) + batch.non_tensor_batch['uid'] = batch.non_tensor_batch['index'].copy() # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)