Fix bugs related to loss mask, meta info, and response length

1. Construct the loss mask immediately after the observation is obtained, so that converting the transformed text back into tokens cannot misalign the mask (see the sketch after this list).
2. Propagate meta info when rebuilding the rolling state, so that the test batch can actually apply do_sample.
3. Remove the recording of info tokens for response length.
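
The essence of fix 1, as a minimal PyTorch sketch (the tensor values and pad_id below are illustrative, not taken from this repo): the mask for an observation block is built as a pad-filled copy of the observation token tensor at the moment the observation is obtained, and the real sequence and its masked twin are carried forward in parallel, so a later text transformation plus re-encoding can no longer knock the mask out of alignment.

import torch

pad_id = 0  # illustrative pad token id

# Suppose these ids come straight from the tokenizer (no re-encoding step):
response_ids = torch.tensor([[11, 12, 13]])   # model-generated tokens
obs_ids = torch.tensor([[21, 22, 23, 24]])    # retrieved observation tokens

# Build the info mask the moment the observation is obtained: a pad-filled
# copy of the very same tensor, so token/mask alignment cannot drift.
obs_info_mask = torch.full_like(obs_ids, pad_id)

# Carry the real sequence and its masked twin forward in parallel.
responses = torch.cat([response_ids, obs_ids], dim=1)
responses_with_info_mask = torch.cat([response_ids, obs_info_mask], dim=1)

# The loss mask keeps response tokens and drops observation tokens.
loss_mask = responses_with_info_mask != pad_id
print(loss_mask)  # tensor([[ True,  True,  True, False, False, False, False]])

Fix 2 corresponds to the new_rollings.meta_info.update(rollings.meta_info) call in the diff below: the old code returned a fresh DataProto, so meta info such as the do_sample flag set on the test batch did not survive the rolling-state update.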
xiaobo-yang
2025-03-14 13:44:42 +08:00
parent 118c6e7361
commit 32719b5119
2 changed files with 54 additions and 61 deletions


@@ -98,7 +98,7 @@ class LLMGenerationManager:
         return next_obs_ids
-    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
+    def _update_rolling_state(self, rollings: DataProto, cur_responses: torch.Tensor,
                               next_obs_ids: torch.Tensor) -> Dict:
         """Update rolling state with new responses and observations."""
         # Concatenate and handle padding
@@ -115,33 +115,64 @@ class LLMGenerationManager:
         # Cut to appropriate length
         effective_len = new_attention_mask.sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
-        return DataProto.from_dict({
+        new_rollings = DataProto.from_dict({
             'input_ids': new_input_ids[:, -max_len:],
             'position_ids': new_position_ids[:, -max_len:],
             'attention_mask': new_attention_mask[:, -max_len:]
         })
+        new_rollings.meta_info.update(rollings.meta_info)
+        return new_rollings
+    def _info_masked_concatenate_with_padding(self,
+                prompt: torch.Tensor,
+                prompt_with_mask: torch.Tensor,
+                response: torch.Tensor,
+                info: torch.Tensor = None,
+                pad_to_left: bool = True
+            ) -> torch.Tensor:
+        """Concatenate tensors and handle padding. Additionally, create a mask (info_mask) to cover the information block if it exists."""
+        pad_id = self.tokenizer.pad_token_id
+        tensors = [prompt, response]
+        tensors_with_mask = [prompt_with_mask, response]
+        if info is not None:
+            tensors.append(info)
+            info_mask = torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device)  # information mask
+            tensors_with_mask.append(info_mask)
+        concatenated = torch.cat(tensors, dim=1)
+        concatenated_with_info = torch.cat(tensors_with_mask, dim=1)
+        mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
+        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
+        padded_tensor = concatenated.gather(1, sorted_indices)
+        padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)
+        return padded_tensor, padded_tensor_with_info
     def _update_right_side(self, right_side: Dict,
                            cur_responses: torch.Tensor,
                            next_obs_ids: torch.Tensor = None) -> Dict:
         """Update right side state."""
         if next_obs_ids != None:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-                next_obs_ids
-            ], pad_to_left=False)
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                    right_side['responses'],
+                    right_side['responses_with_info_mask'],
+                    cur_responses,
+                    next_obs_ids,
+                    pad_to_left=False
+                )
         else:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-            ], pad_to_left=False)
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                    right_side['responses'],
+                    right_side['responses_with_info_mask'],
+                    cur_responses,
+                    pad_to_left=False
+                )
         effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
-        return {'responses': responses[:, :max_len]}
+        return {'responses': responses[:, :max_len], 'responses_with_info_mask': responses_with_info_mask[:, :max_len]}
     def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
         """
@@ -194,7 +225,7 @@ class LLMGenerationManager:
         """Run main LLM generation loop."""
         original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
-        original_right_side = {'responses': initial_input_ids[:, []]}
+        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
         active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
         active_num_list = [active_mask.sum().item()]
@@ -295,6 +326,10 @@ class LLMGenerationManager:
             self.tensor_fn.create_attention_mask(left_side['input_ids']),
             self.tensor_fn.create_attention_mask(final_output['responses'])
         ], dim=1)
+        final_output['info_mask'] = torch.cat([
+            self.tensor_fn.create_attention_mask(left_side['input_ids']),
+            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
+        ], dim=1)
         final_output['position_ids'] = self.tensor_fn.create_position_ids(
             final_output['attention_mask']