Fix bugs related to loss mask, meta info, and response length
1. Construct the loss mask immediately after obtaining the observation, so the masked positions are fixed at tokenization time; converting transformed text back to tokens afterwards can misalign the encoding.
2. Carry meta info forward when updating the rolling state, so that the test batch can apply do_sample.
3. Exclude observation (info) tokens when recording response length.
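Background for point 1: mapping character offsets in decoded text back to token positions assumes that re-encoding a decoded prefix reproduces the original token boundaries, which byte-level BPE does not guarantee. A minimal sketch of the failure mode, assuming the gpt2 tokenizer is downloadable (any HuggingFace tokenizer works the same way):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    ids = tok.encode("search <information>some result</information> done")
    # Re-encoding a decoded prefix need not reproduce the original tokenization,
    # so len(tok.encode(decoded_prefix)) is not a reliable token offset.
    reencoded = tok.encode(tok.decode(ids[:4]))
    print(ids[:4], reencoded)  # the two id lists may differ

Building the mask directly from the observation ids, as the diff below does, sidesteps this round trip entirely.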
@@ -98,7 +98,7 @@ class LLMGenerationManager:
         return next_obs_ids
 
-    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
+    def _update_rolling_state(self, rollings: DataProto, cur_responses: torch.Tensor,
                               next_obs_ids: torch.Tensor) -> Dict:
         """Update rolling state with new responses and observations."""
         # Concatenate and handle padding
@@ -115,33 +115,64 @@ class LLMGenerationManager:
         # Cut to appropriate length
         effective_len = new_attention_mask.sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
 
-        return DataProto.from_dict({
+        new_rollings = DataProto.from_dict({
             'input_ids': new_input_ids[:, -max_len:],
             'position_ids': new_position_ids[:, -max_len:],
             'attention_mask': new_attention_mask[:, -max_len:]
         })
+        new_rollings.meta_info.update(rollings.meta_info)
+
+        return new_rollings
+
+    def _info_masked_concatenate_with_padding(self,
+                                              prompt: torch.Tensor,
+                                              prompt_with_mask: torch.Tensor,
+                                              response: torch.Tensor,
+                                              info: torch.Tensor = None,
+                                              pad_to_left: bool = True
+                                              ) -> torch.Tensor:
+        """Concatenate tensors and handle padding. Additionally, create a mask (info_mask) to cover the information block if it exists."""
+        pad_id = self.tokenizer.pad_token_id
+        tensors = [prompt, response]
+        tensors_with_mask = [prompt_with_mask, response]
+        if info is not None:
+            tensors.append(info)
+            info_mask = torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device)  # information mask
+            tensors_with_mask.append(info_mask)
+
+        concatenated = torch.cat(tensors, dim=1)
+        concatenated_with_info = torch.cat(tensors_with_mask, dim=1)
+        mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
+        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
+        padded_tensor = concatenated.gather(1, sorted_indices)
+        padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)
+
+        return padded_tensor, padded_tensor_with_info
 
     def _update_right_side(self, right_side: Dict,
                            cur_responses: torch.Tensor,
                            next_obs_ids: torch.Tensor = None) -> Dict:
         """Update right side state."""
         if next_obs_ids != None:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-                next_obs_ids
-            ], pad_to_left=False)
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                right_side['responses'],
+                right_side['responses_with_info_mask'],
+                cur_responses,
+                next_obs_ids,
+                pad_to_left=False
+            )
         else:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-            ], pad_to_left=False)
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                right_side['responses'],
+                right_side['responses_with_info_mask'],
+                cur_responses,
+                pad_to_left=False
+            )
         effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
 
-        return {'responses': responses[:, :max_len]}
+        return {'responses': responses[:, :max_len], 'responses_with_info_mask': responses_with_info_mask[:, :max_len]}
 
     def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
         """
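A note on the stable-argsort trick in _info_masked_concatenate_with_padding: both the real ids and their info-masked copy are gathered with indices computed from the real ids alone, so the two tensors stay position-aligned while padding is re-packed. A standalone toy run with made-up ids, using the pad_to_left=True branch:

    import torch

    pad_id = 0
    prompt = torch.tensor([[0, 0, 5, 6]])    # left-padded prompt ids
    response = torch.tensor([[7, 8, 0, 0]])  # right-padded response ids
    info = torch.tensor([[9, 9]])            # observation (info) ids

    concatenated = torch.cat([prompt, response, info], dim=1)
    # the info-masked twin overwrites observation ids with pad_id
    masked = torch.cat([prompt, response, torch.full_like(info, pad_id)], dim=1)
    # pads sort to the front; the stable sort keeps real tokens in order,
    # and gathering both tensors with the same indices preserves alignment
    order = (concatenated != pad_id).to(torch.int64).argsort(dim=1, stable=True)
    print(concatenated.gather(1, order))  # tensor([[0, 0, 0, 0, 5, 6, 7, 8, 9, 9]])
    print(masked.gather(1, order))        # tensor([[0, 0, 0, 0, 5, 6, 7, 8, 0, 0]])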
@@ -194,7 +225,7 @@ class LLMGenerationManager:
         """Run main LLM generation loop."""
 
         original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
-        original_right_side = {'responses': initial_input_ids[:, []]}
+        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
 
         active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
         active_num_list = [active_mask.sum().item()]
@@ -295,6 +326,10 @@ class LLMGenerationManager:
             self.tensor_fn.create_attention_mask(left_side['input_ids']),
             self.tensor_fn.create_attention_mask(final_output['responses'])
         ], dim=1)
+        final_output['info_mask'] = torch.cat([
+            self.tensor_fn.create_attention_mask(left_side['input_ids']),
+            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
+        ], dim=1)
 
         final_output['position_ids'] = self.tensor_fn.create_position_ids(
             final_output['attention_mask']
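Because observation ids were overwritten with pad_id in responses_with_info_mask, an attention-style mask over that tensor is 0 exactly on observation positions. A rough stand-in for what tensor_fn.create_attention_mask is assumed to do here (the real helper may differ):

    import torch

    pad_id = 0
    responses_with_info_mask = torch.tensor([[5, 6, 0, 0, 7]])  # 0s cover info ids
    info_mask = (responses_with_info_mask != pad_id).long()     # assumed semantics
    print(info_mask)  # tensor([[1, 1, 0, 0, 1]])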
@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
     response_length = responses.size(1)
     token_level_scores = data.batch['token_level_scores']
     batch_size = data.batch.batch_size[0]
-    attention_mask = data.batch['attention_mask']
+    attention_mask = data.batch['info_mask']
     response_mask = attention_mask[:, -response_length:]
 
     # compute kl between ref_policy and current policy
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
 def _compute_response_info(batch):
     response_length = batch.batch['responses'].shape[-1]
 
-    prompt_mask = batch.batch['attention_mask'][:, :-response_length]
-    response_mask = batch.batch['attention_mask'][:, -response_length:]
+    prompt_mask = batch.batch['info_mask'][:, :-response_length]
+    response_mask = batch.batch['info_mask'][:, -response_length:]
 
     prompt_length = prompt_mask.sum(-1).float()
     response_length = response_mask.sum(-1).float()  # (batch_size,)
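Effect of the swap in _compute_response_info: length metrics now count only model-generated tokens, since info_mask zeroes observation tokens that attention_mask treats as real. A toy comparison with made-up masks:

    import torch

    response_length = 4
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])  # counts the info token
    info_mask = torch.tensor([[1, 1, 1, 0, 1, 0]])       # zeroes the info token
    print(attention_mask[:, -response_length:].sum(-1).float())  # tensor([3.])
    print(info_mask[:, -response_length:].sum(-1).float())       # tensor([2.])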
@@ -867,50 +867,8 @@ class RayPPOTrainer(object):
                 response_length = batch.batch['responses'].shape[-1]
                 response_mask = batch.batch['attention_mask'][:, -response_length:]
 
-                # Initialize state mask
-                state_mask = torch.ones_like(response_mask)
-
-                responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
-
-                for i, response in enumerate(responses):
-                    # Find all pairs of start and end marker positions
-                    start_marker = self.config.algorithm.state_masking.start_state_marker
-                    end_marker = self.config.algorithm.state_masking.end_state_marker
-
-                    # Get all start and end positions
-                    start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
-                    end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
-
-                    # Convert character positions to token positions
-                    for start, end in zip(start_positions, end_positions):
-                        prefix_to_start = response[:start]
-                        state_section = response[start:end]
-
-                        start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
-                        state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
-
-                        start_token_pos = len(start_tokens)
-                        end_token_pos = start_token_pos + len(state_tokens)
-
-                        state_mask[i, start_token_pos:end_token_pos] = 0
-
-                loss_mask = state_mask * response_mask
+                loss_mask = batch.batch['info_mask'][:, -response_length:]
                 batch.batch['loss_mask'] = loss_mask
 
-                # # Debug print
-                # print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
-                # response_ids = batch.batch['responses'][0]
-                # unmasked_ids = response_ids[loss_mask[0] == 0]
-                # print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
-
-                # masked_ids = response_ids[loss_mask[0] == 1]
-                # print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
-
-                # masked_ids = response_ids[response_mask[0] == 1]
-                # print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
-
-                # masked_ids = response_ids[response_mask[0] == 0]
-                # print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))
-
                 metrics.update({
                     'state_tokens/total': loss_mask.sum().item(),
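With info_mask carried in the batch, the regex scan above collapses to a slice; the masked positions were fixed when the observation ids were appended, so no decode/re-encode round trip remains. A toy slice with a made-up mask:

    import torch

    response_length = 4
    # info_mask over [prompt | response]; 0s mark observation tokens
    info_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 1, 1]])
    loss_mask = info_mask[:, -response_length:]
    print(loss_mask)  # tensor([[1, 0, 1, 1]])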