Merge pull request #21 from xiaobo-yang/yxb/fix-info-mask-bugs
Fix bugs related to loss mask, meta info, and response length
@@ -98,7 +98,7 @@ class LLMGenerationManager:
         return next_obs_ids
 
-    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
+    def _update_rolling_state(self, rollings: DataProto, cur_responses: torch.Tensor,
                           next_obs_ids: torch.Tensor) -> Dict:
         """Update rolling state with new responses and observations."""
         # Concatenate and handle padding
@@ -115,33 +115,64 @@ class LLMGenerationManager:
         # Cut to appropriate length
         effective_len = new_attention_mask.sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
 
-        return DataProto.from_dict({
+        new_rollings = DataProto.from_dict({
             'input_ids': new_input_ids[:, -max_len:],
             'position_ids': new_position_ids[:, -max_len:],
             'attention_mask': new_attention_mask[:, -max_len:]
         })
+        new_rollings.meta_info.update(rollings.meta_info)
+
+        return new_rollings
 
+    def _info_masked_concatenate_with_padding(self,
+                prompt: torch.Tensor,
+                prompt_with_mask: torch.Tensor,
+                response: torch.Tensor,
+                info: torch.Tensor = None,
+                pad_to_left: bool = True
+            ) -> torch.Tensor:
+        """Concatenate tensors and handle padding. Additionally, create a mask (info_mask) to cover the information block if it exists."""
+        pad_id = self.tokenizer.pad_token_id
+        tensors = [prompt, response]
+        tensors_with_mask = [prompt_with_mask, response]
+        if info is not None:
+            tensors.append(info)
+            info_mask = torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device)  # information mask
+            tensors_with_mask.append(info_mask)
+
+        concatenated = torch.cat(tensors, dim=1)
+        concatenated_with_info = torch.cat(tensors_with_mask, dim=1)
+        mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
+        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
+        padded_tensor = concatenated.gather(1, sorted_indices)
+        padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)
+
+        return padded_tensor, padded_tensor_with_info
 
     def _update_right_side(self, right_side: Dict,
                           cur_responses: torch.Tensor,
                           next_obs_ids: torch.Tensor = None) -> Dict:
         """Update right side state."""
         if next_obs_ids != None:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-                next_obs_ids
-            ], pad_to_left=False)
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                    right_side['responses'],
+                    right_side['responses_with_info_mask'],
+                    cur_responses,
+                    next_obs_ids,
+                    pad_to_left=False
+                )
         else:
-            responses = self.tensor_fn.concatenate_with_padding([
-                right_side['responses'],
-                cur_responses,
-            ], pad_to_left=False)
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
+                    right_side['responses'],
+                    right_side['responses_with_info_mask'],
+                    cur_responses,
+                    pad_to_left=False
+                )
 
         effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
 
-        return {'responses': responses[:, :max_len]}
+        return {'responses': responses[:, :max_len], 'responses_with_info_mask': responses_with_info_mask[:, :max_len]}
 
     def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
         """
@@ -194,7 +225,7 @@ class LLMGenerationManager:
         """Run main LLM generation loop."""
 
         original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
-        original_right_side = {'responses': initial_input_ids[:, []]}
+        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
 
         active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
         active_num_list = [active_mask.sum().item()]
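The only change to the loop setup is that the right-side accumulator now tracks a second, info-masked buffer. `initial_input_ids[:, []]` is simply an empty `(batch_size, 0)` slice with the right dtype and device, so both buffers start empty and are extended by the same amount every turn. A quick sanity check with a toy tensor:

```python
import torch

initial_input_ids = torch.randint(0, 100, (4, 16))  # toy (batch, seq) prompt ids
empty = initial_input_ids[:, []]                     # empty-list column index keeps the batch dim
print(empty.shape, empty.dtype)                      # torch.Size([4, 0]) torch.int64

# Both right-side buffers start empty and are appended to in lockstep each turn.
original_right_side = {
    'responses': empty,
    'responses_with_info_mask': empty.clone(),
}
```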
@@ -295,6 +326,10 @@ class LLMGenerationManager:
             self.tensor_fn.create_attention_mask(left_side['input_ids']),
             self.tensor_fn.create_attention_mask(final_output['responses'])
         ], dim=1)
+        final_output['info_mask'] = torch.cat([
+            self.tensor_fn.create_attention_mask(left_side['input_ids']),
+            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
+        ], dim=1)
 
         final_output['position_ids'] = self.tensor_fn.create_position_ids(
             final_output['attention_mask']
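With this addition, `final_output` carries two parallel masks over `[prompt | response]`: `attention_mask` built from the raw responses, and `info_mask` built from the copy whose retrieved-information tokens were overwritten with pad. Assuming `tensor_fn.create_attention_mask` simply marks non-pad positions with 1 (an assumption; it is not shown in this diff), the relationship is:

```python
import torch

pad_id = 0

def create_attention_mask(ids: torch.Tensor) -> torch.Tensor:
    # Assumed behaviour: 1 for real tokens, 0 for pad.
    return (ids != pad_id).long()

responses = torch.tensor([[21, 22, 31, 32, 33, 0]])              # 31-33 are retrieved info, 0 is pad
responses_with_info_mask = torch.tensor([[21, 22, 0, 0, 0, 0]])  # info overwritten with pad

attention_mask = create_attention_mask(responses)                # [[1, 1, 1, 1, 1, 0]]
info_mask = create_attention_mask(responses_with_info_mask)      # [[1, 1, 0, 0, 0, 0]]

# info_mask equals attention_mask with information tokens additionally zeroed out,
# which is what lets the trainer exclude them from KL, metrics, and the loss.
assert torch.all(info_mask <= attention_mask)
```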
@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
     response_length = responses.size(1)
     token_level_scores = data.batch['token_level_scores']
     batch_size = data.batch.batch_size[0]
-    attention_mask = data.batch['attention_mask']
+    attention_mask = data.batch['info_mask']
    response_mask = attention_mask[:, -response_length:]
 
     # compute kl between ref_policy and current policy
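Reading the mask from `info_mask` means the per-token KL penalty is applied and averaged only over tokens the policy actually generated; retrieved-information tokens inside the response are treated like padding. The actual KL estimator and adaptive controller live in `core_algos` and are not shown here, but the masking step is roughly this hedged sketch:

```python
import torch

def masked_kl_penalty(token_level_scores, log_probs, ref_log_probs, response_mask, beta):
    """Sketch only: subtract a masked per-token KL term from the token-level scores."""
    kld = log_probs - ref_log_probs            # simple k1-style estimate; the real estimator may differ
    kld = kld * response_mask                  # zeros out pad tokens AND retrieved-info tokens
    token_level_rewards = token_level_scores - beta * kld
    mean_kl = kld.sum(-1) / response_mask.sum(-1).clamp(min=1)
    return token_level_rewards, mean_kl.mean()

# Toy usage with made-up shapes.
scores = torch.zeros(2, 4)
logp, ref_logp = torch.randn(2, 4), torch.randn(2, 4)
resp_mask = torch.tensor([[1., 1., 0., 0.], [1., 1., 1., 0.]])
rewards, kl = masked_kl_penalty(scores, logp, ref_logp, resp_mask, beta=0.001)
```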
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
 def _compute_response_info(batch):
     response_length = batch.batch['responses'].shape[-1]
 
-    prompt_mask = batch.batch['attention_mask'][:, :-response_length]
-    response_mask = batch.batch['attention_mask'][:, -response_length:]
+    prompt_mask = batch.batch['info_mask'][:, :-response_length]
+    response_mask = batch.batch['info_mask'][:, -response_length:]
 
     prompt_length = prompt_mask.sum(-1).float()
     response_length = response_mask.sum(-1).float()  # (batch_size,)
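Because `_compute_response_info` now reads `info_mask`, the reported prompt/response lengths count only non-information tokens, which is the "response length" fix named in the commit message. For example:

```python
import torch

response_length = 6
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])  # prompt (2) + response (6), all attended to
info_mask      = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 1]])  # three retrieved-info tokens zeroed

print(attention_mask[:, -response_length:].sum(-1).float())  # tensor([6.])  old metric
print(info_mask[:, -response_length:].sum(-1).float())       # tensor([3.])  new metric: generated tokens only
```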
@@ -855,50 +855,8 @@ class RayPPOTrainer(object):
         response_length = batch.batch['responses'].shape[-1]
         response_mask = batch.batch['attention_mask'][:, -response_length:]
 
-        # Initialize state mask
-        state_mask = torch.ones_like(response_mask)
-
-        responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
-
-        for i, response in enumerate(responses):
-            # Find all pairs of start and end marker positions
-            start_marker = self.config.algorithm.state_masking.start_state_marker
-            end_marker = self.config.algorithm.state_masking.end_state_marker
-
-            # Get all start and end positions
-            start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
-            end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
-
-            # Convert character positions to token positions
-            for start, end in zip(start_positions, end_positions):
-                prefix_to_start = response[:start]
-                state_section = response[start:end]
-
-                start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
-                state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
-
-                start_token_pos = len(start_tokens)
-                end_token_pos = start_token_pos + len(state_tokens)
-
-                state_mask[i, start_token_pos:end_token_pos] = 0
-
-        loss_mask = state_mask * response_mask
+        loss_mask = batch.batch['info_mask'][:, -response_length:]
         batch.batch['loss_mask'] = loss_mask
 
-        # # Debug print
-        # print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
-        # response_ids = batch.batch['responses'][0]
-        # unmasked_ids = response_ids[loss_mask[0] == 0]
-        # print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
-
-        # masked_ids = response_ids[loss_mask[0] == 1]
-        # print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
-
-        # masked_ids = response_ids[response_mask[0] == 1]
-        # print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
-
-        # masked_ids = response_ids[response_mask[0] == 0]
-        # print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))
-
         metrics.update({
             'state_tokens/total': loss_mask.sum().item(),
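The deleted block re-tokenized the decoded response text to map state markers back to token positions; that mapping drifts whenever re-encoding a text prefix does not reproduce the original tokenization, which is the loss-mask bug this commit fixes. Slicing the already token-aligned `info_mask` sidesteps the problem entirely. Downstream, `loss_mask` would typically enter the policy loss as a masked mean, roughly like this sketch (hypothetical helper, not the trainer's exact code):

```python
import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Average only over positions the loss mask keeps (generated, non-information tokens).
    return (values * mask).sum() / mask.sum().clamp(min=1)

per_token_loss = torch.randn(2, 5)           # toy (batch, response_length) losses
loss_mask = torch.tensor([[1., 1., 0., 0., 1.],
                          [1., 1., 1., 0., 0.]])
print(masked_mean(per_token_loss, loss_mask))  # scalar loss over the unmasked tokens
```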