diff --git a/search_r1/llm_agent/generation.py b/search_r1/llm_agent/generation.py
index fa98a05..3ea9abe 100644
--- a/search_r1/llm_agent/generation.py
+++ b/search_r1/llm_agent/generation.py
@@ -98,7 +98,7 @@ class LLMGenerationManager:
 
         return next_obs_ids
 
-    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
+    def _update_rolling_state(self, rollings: DataProto, cur_responses: torch.Tensor,
                             next_obs_ids: torch.Tensor) -> Dict:
         """Update rolling state with new responses and observations."""
         # Concatenate and handle padding
@@ -115,33 +115,64 @@ class LLMGenerationManager:
         # Cut to appropriate length
         effective_len = new_attention_mask.sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
-
-        return DataProto.from_dict({
+
+        new_rollings = DataProto.from_dict({
             'input_ids': new_input_ids[:, -max_len:],
             'position_ids': new_position_ids[:, -max_len:],
             'attention_mask': new_attention_mask[:, -max_len:]
        })
+        new_rollings.meta_info.update(rollings.meta_info)
+
+        return new_rollings
+
+    def _info_masked_concatenate_with_padding(self,
+                prompt: torch.Tensor,
+                prompt_with_mask: torch.Tensor,
+                response: torch.Tensor,
+                info: torch.Tensor = None,
+                pad_to_left: bool = True
+            ) -> torch.Tensor:
+        """Concatenate tensors and handle padding. Additionally, create a mask (info_mask) to cover the information block if it exists."""
+        pad_id = self.tokenizer.pad_token_id
+        tensors = [prompt, response]
+        tensors_with_mask = [prompt_with_mask, response]
+        if info is not None:
+            tensors.append(info)
+            info_mask = torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device)  # information mask
+            tensors_with_mask.append(info_mask)
+
+        concatenated = torch.cat(tensors, dim=1)
+        concatenated_with_info = torch.cat(tensors_with_mask, dim=1)
+        mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
+        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
+        padded_tensor = concatenated.gather(1, sorted_indices)
+        padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)
+
+        return padded_tensor, padded_tensor_with_info
 
     def _update_right_side(self, right_side: Dict,
                           cur_responses: torch.Tensor,
                           next_obs_ids: torch.Tensor = None) -> Dict:
         """Update right side state."""
         if next_obs_ids != None:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-                next_obs_ids
-            ], pad_to_left=False)
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                    right_side['responses'],
+                    right_side['responses_with_info_mask'],
+                    cur_responses,
+                    next_obs_ids,
+                    pad_to_left=False
+                )
         else:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-            ], pad_to_left=False)
-
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                    right_side['responses'],
+                    right_side['responses_with_info_mask'],
+                    cur_responses,
+                    pad_to_left=False
+                )
         effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
 
-        return {'responses': responses[:, :max_len]}
+        return {'responses': responses[:, :max_len], 'responses_with_info_mask': responses_with_info_mask[:, :max_len]}
 
     def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
         """
@@ -194,7 +225,7 @@ class LLMGenerationManager:
         """Run main LLM generation loop."""
 
         original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
-        original_right_side = {'responses': initial_input_ids[:, []]}
+        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
 
         active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
         active_num_list = [active_mask.sum().item()]
@@ -295,6 +326,10 @@ class LLMGenerationManager:
             self.tensor_fn.create_attention_mask(left_side['input_ids']),
             self.tensor_fn.create_attention_mask(final_output['responses'])
         ], dim=1)
+        final_output['info_mask'] = torch.cat([
+            self.tensor_fn.create_attention_mask(left_side['input_ids']),
+            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
+        ], dim=1)
 
         final_output['position_ids'] = self.tensor_fn.create_position_ids(
             final_output['attention_mask']
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index af6cb18..0df2128 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
     response_length = responses.size(1)
     token_level_scores = data.batch['token_level_scores']
     batch_size = data.batch.batch_size[0]
-    attention_mask = data.batch['attention_mask']
+    attention_mask = data.batch['info_mask']
     response_mask = attention_mask[:, -response_length:]
 
     # compute kl between ref_policy and current policy
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
 
 def _compute_response_info(batch):
     response_length = batch.batch['responses'].shape[-1]
 
-    prompt_mask = batch.batch['attention_mask'][:, :-response_length]
-    response_mask = batch.batch['attention_mask'][:, -response_length:]
+    prompt_mask = batch.batch['info_mask'][:, :-response_length]
+    response_mask = batch.batch['info_mask'][:, -response_length:]
 
     prompt_length = prompt_mask.sum(-1).float()
     response_length = response_mask.sum(-1).float()  # (batch_size,)
@@ -867,50 +867,8 @@ class RayPPOTrainer(object):
                         response_length = batch.batch['responses'].shape[-1]
                         response_mask = batch.batch['attention_mask'][:, -response_length:]
 
-                        # Initialize state mask
-                        state_mask = torch.ones_like(response_mask)
-
-                        responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
-
-                        for i, response in enumerate(responses):
-                            # Find all pairs of start and end marker positions
-                            start_marker = self.config.algorithm.state_masking.start_state_marker
-                            end_marker = self.config.algorithm.state_masking.end_state_marker
-
-                            # Get all start and end positions
-                            start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
-                            end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
-
-                            # Convert character positions to token positions
-                            for start, end in zip(start_positions, end_positions):
-                                prefix_to_start = response[:start]
-                                state_section = response[start:end]
-
-                                start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
-                                state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
-
-                                start_token_pos = len(start_tokens)
-                                end_token_pos = start_token_pos + len(state_tokens)
-
-                                state_mask[i, start_token_pos:end_token_pos] = 0
-
-                        loss_mask = state_mask * response_mask
+                        loss_mask = batch.batch['info_mask'][:, -response_length:]
                         batch.batch['loss_mask'] = loss_mask
-
-                        # # Debug print
-                        # print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
-                        # response_ids = batch.batch['responses'][0]
-                        # unmasked_ids = response_ids[loss_mask[0] == 0]
-                        # print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
-
-                        # masked_ids = response_ids[loss_mask[0] == 1]
-                        # print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
-
-                        # masked_ids = response_ids[response_mask[0] == 1]
-                        # print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
-
-                        # masked_ids = response_ids[response_mask[0] == 0]
-                        # print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))
 
                         metrics.update({
                             'state_tokens/total': loss_mask.sum().item(),
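
The padding logic added in `_info_masked_concatenate_with_padding` hinges on a stable argsort over the pad mask: sorting the boolean "is pad" values pushes all padding to one side while keeping real tokens in their original relative order, and applying the same permutation to the masked twin keeps the two views aligned position by position. A minimal standalone sketch of the `pad_to_left=False` branch, with a hypothetical `pad_id` of 0 and toy token ids:

```python
import torch

pad_id = 0  # hypothetical pad token id; the real code reads tokenizer.pad_token_id

responses     = torch.tensor([[7, 8, 0, 0]])  # accumulated responses, right-padded
cur_responses = torch.tensor([[5, 6]])        # newly generated tokens
info          = torch.tensor([[9, 9]])        # retrieved tokens appended this turn

# The masked twin replaces the info block with pad ids before concatenation.
info_masked = torch.full_like(info, pad_id)

concatenated           = torch.cat([responses, cur_responses, info], dim=1)
concatenated_with_info = torch.cat([responses, cur_responses, info_masked], dim=1)

# pad_to_left=False branch: pad positions sort to the right; the stable argsort
# preserves the relative order of real tokens, and reusing the same indices
# applies an identical permutation to both views.
mask = concatenated == pad_id
sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)

print(concatenated.gather(1, sorted_indices))            # tensor([[7, 8, 5, 6, 9, 9, 0, 0]])
print(concatenated_with_info.gather(1, sorted_indices))  # tensor([[7, 8, 5, 6, 0, 0, 0, 0]])
```

In the masked view the retrieved tokens end up indistinguishable from padding, which is why `create_attention_mask` over `responses_with_info_mask` later produces zeros in `info_mask` exactly where retrieved text sits.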
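On the trainer side, the removed code rebuilt the mask after the fact by decoding every response and regex-matching the state markers back to token positions; the diff replaces that with a direct slice of the `info_mask` built at generation time. A minimal sketch (toy shapes and values, not verl's actual loss code) of how such a mask screens retrieved tokens out of a token-level objective:

```python
import torch

# Toy token-level losses over a response region of length 6 for 2 rollouts.
token_level_loss = torch.randn(2, 6)

# Response slice of info_mask: 1 = model-generated token, 0 = retrieved
# information (or padding). Values here are illustrative only.
loss_mask = torch.tensor([
    [1, 1, 1, 1, 0, 0],   # trailing retrieved passage masked out
    [1, 1, 0, 0, 1, 1],   # an information block mid-rollout
]).float()

# Mean loss over model-generated tokens only: retrieved tokens contribute
# neither to the objective nor to its gradient.
masked_loss = (token_level_loss * loss_mask).sum() / loss_mask.sum()
```

Because the same `info_mask` also feeds `apply_kl_penalty` and `_compute_response_info` in this diff, the KL penalty and the length metrics are computed over model-generated tokens only as well.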