fix grpo id bug
This commit is contained in:
@@ -3,8 +3,8 @@ export DATA_DIR='data/nq_search'
|
||||
|
||||
WAND_PROJECT='Search-R1'
|
||||
|
||||
export BASE_MODEL='meta-llama/Llama-3.2-3B'
|
||||
export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-em
|
||||
# export BASE_MODEL='meta-llama/Llama-3.2-3B'
|
||||
# export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-em
|
||||
# export BASE_MODEL='meta-llama/Llama-3.2-3B-Instruct'
|
||||
# export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-it-em
|
||||
# export BASE_MODEL='meta-llama/Llama-3.1-8B'
|
||||
@@ -12,8 +12,8 @@ export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.2-3b-em
|
||||
# export BASE_MODEL='meta-llama/Llama-3.1-8B-Instruct'
|
||||
# export EXPERIMENT_NAME=nq-search-r1-grpo-llama3.1-8b-it-em
|
||||
|
||||
# export BASE_MODEL='Qwen/Qwen2.5-3B'
|
||||
# export EXPERIMENT_NAME=nq-search-r1-grpo-qwen2.5-3b-em
|
||||
export BASE_MODEL='Qwen/Qwen2.5-3B'
|
||||
export EXPERIMENT_NAME=nq-search-r1-grpo-qwen2.5-3b-em
|
||||
# export BASE_MODEL='Qwen/Qwen2.5-3B-Instruct'
|
||||
# export EXPERIMENT_NAME=nq-search-r1-grpo-qwen2.5-3b-it-em
|
||||
# export BASE_MODEL='Qwen/Qwen2.5-7B'
|
||||
|
||||
@@ -739,8 +739,9 @@ class RayPPOTrainer(object):
|
||||
output = self.actor_rollout_wg.compute_log_prob(final_gen_batch_output)
|
||||
final_gen_batch_output = final_gen_batch_output.union(output)
|
||||
|
||||
batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
|
||||
dtype=object)
|
||||
# batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
|
||||
# dtype=object)
|
||||
batch.non_tensor_batch['uid'] = batch.non_tensor_batch['index'].copy()
|
||||
|
||||
# repeat to align with repeated responses in rollout
|
||||
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
||||
|
||||
Reference in New Issue
Block a user