Merge pull request #21 from xiaobo-yang/yxb/fix-info-mask-bugs
Fix bugs related to loss mask, meta info, and response length
@@ -98,7 +98,7 @@ class LLMGenerationManager:
 
         return next_obs_ids
 
-    def _update_rolling_state(self, rollings, cur_responses: torch.Tensor,
+    def _update_rolling_state(self, rollings: DataProto, cur_responses: torch.Tensor,
                               next_obs_ids: torch.Tensor) -> Dict:
         """Update rolling state with new responses and observations."""
         # Concatenate and handle padding
@@ -115,33 +115,64 @@ class LLMGenerationManager:
         # Cut to appropriate length
         effective_len = new_attention_mask.sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
 
-        return DataProto.from_dict({
+        new_rollings = DataProto.from_dict({
             'input_ids': new_input_ids[:, -max_len:],
             'position_ids': new_position_ids[:, -max_len:],
             'attention_mask': new_attention_mask[:, -max_len:]
         })
+        new_rollings.meta_info.update(rollings.meta_info)
+
+        return new_rollings
 
+    def _info_masked_concatenate_with_padding(self,
+                prompt: torch.Tensor,
+                prompt_with_mask: torch.Tensor,
+                response: torch.Tensor,
+                info: torch.Tensor = None,
+                pad_to_left: bool = True
+            ) -> torch.Tensor:
+        """Concatenate tensors and handle padding. Additionally, create a mask (info_mask) to cover the information block if it exists."""
+        pad_id = self.tokenizer.pad_token_id
+        tensors = [prompt, response]
+        tensors_with_mask = [prompt_with_mask, response]
+        if info is not None:
+            tensors.append(info)
+            info_mask = torch.full(info.size(), pad_id, dtype=info.dtype, device=info.device)  # information mask
+            tensors_with_mask.append(info_mask)
+
+        concatenated = torch.cat(tensors, dim=1)
+        concatenated_with_info = torch.cat(tensors_with_mask, dim=1)
+        mask = concatenated != pad_id if pad_to_left else concatenated == pad_id
+        sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)
+        padded_tensor = concatenated.gather(1, sorted_indices)
+        padded_tensor_with_info = concatenated_with_info.gather(1, sorted_indices)
+
+        return padded_tensor, padded_tensor_with_info
+
     def _update_right_side(self, right_side: Dict,
                            cur_responses: torch.Tensor,
                            next_obs_ids: torch.Tensor = None) -> Dict:
         """Update right side state."""
         if next_obs_ids != None:
-            responses = self.tensor_fn.concatenate_with_padding([
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
                 right_side['responses'],
-                cur_responses,
-                next_obs_ids
-            ], pad_to_left=False)
+                right_side['responses_with_info_mask'],
+                cur_responses,
+                next_obs_ids,
+                pad_to_left=False
+            )
         else:
-            responses = self.tensor_fn.concatenate_with_padding([
+            responses, responses_with_info_mask = self._info_masked_concatenate_with_padding(
                 right_side['responses'],
-                cur_responses,
-            ], pad_to_left=False)
+                right_side['responses_with_info_mask'],
+                cur_responses,
+                pad_to_left=False
+            )
         effective_len = self.tensor_fn.create_attention_mask(responses).sum(dim=1).max()
         max_len = min(self.config.max_prompt_length, effective_len)
 
-        return {'responses': responses[:, :max_len]}
+        return {'responses': responses[:, :max_len], 'responses_with_info_mask': responses_with_info_mask[:, :max_len]}
 
     def _generate_with_gpu_padding(self, active_batch: DataProto) -> DataProto:
         """
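
Note (illustrative, not part of the diff): the pad-sorting trick in `_info_masked_concatenate_with_padding` pushes pad tokens to one side with a stable argsort over the pad mask, and applies the same permutation to the info-masked copy so both tensors stay token-aligned. A minimal standalone sketch, assuming a pad id of 0 and toy token values:

    import torch

    pad_id = 0  # assumed pad token id for this toy example

    # prompt (right-padded), model response, and retrieved "info" tokens
    prompt   = torch.tensor([[5, 6, pad_id, pad_id]])
    response = torch.tensor([[7, 8]])
    info     = torch.tensor([[9, 9, 9]])

    # the masked copy replaces the info block with pad ids
    info_mask = torch.full(info.size(), pad_id, dtype=info.dtype)

    concatenated           = torch.cat([prompt, response, info], dim=1)
    concatenated_with_info = torch.cat([prompt, response, info_mask], dim=1)

    # pad_to_left=False: pad tokens should end up on the right, so sort by
    # "is pad"; the stable sort preserves the relative order of real tokens
    mask = concatenated == pad_id
    sorted_indices = mask.to(torch.int64).argsort(dim=1, stable=True)

    print(concatenated.gather(1, sorted_indices))            # [[5, 6, 7, 8, 9, 9, 9, 0, 0]]
    print(concatenated_with_info.gather(1, sorted_indices))  # [[5, 6, 7, 8, 0, 0, 0, 0, 0]]

Because the same `sorted_indices` are applied to both tensors, the pad positions in the masked copy line up exactly with the info tokens in the real sequence, which is what lets `info_mask` later exclude them from the loss.
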
@@ -194,7 +225,7 @@ class LLMGenerationManager:
         """Run main LLM generation loop."""
 
         original_left_side = {'input_ids': initial_input_ids[:, -self.config.max_start_length:]}
-        original_right_side = {'responses': initial_input_ids[:, []]}
+        original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
 
         active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
         active_num_list = [active_mask.sum().item()]
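
Note (illustrative, not part of the diff): `initial_input_ids[:, []]` seeds both response buffers with an empty tensor that keeps the batch dimension, so later concatenations along dim 1 work without special-casing the first turn. A tiny sketch:

    import torch

    initial_input_ids = torch.randint(0, 100, (3, 5))
    empty = initial_input_ids[:, []]
    print(empty.shape)  # torch.Size([3, 0])
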
@@ -295,6 +326,10 @@ class LLMGenerationManager:
             self.tensor_fn.create_attention_mask(left_side['input_ids']),
             self.tensor_fn.create_attention_mask(final_output['responses'])
         ], dim=1)
+        final_output['info_mask'] = torch.cat([
+            self.tensor_fn.create_attention_mask(left_side['input_ids']),
+            self.tensor_fn.create_attention_mask(final_output['responses_with_info_mask'])
+        ], dim=1)
 
         final_output['position_ids'] = self.tensor_fn.create_position_ids(
             final_output['attention_mask']
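
Note (illustrative, not part of the diff): `final_output['info_mask']` is built the same way as `final_output['attention_mask']`, but from `responses_with_info_mask`, where retrieved-information tokens were already overwritten with the pad id, so those positions drop out of the mask. The helper below stands in for `self.tensor_fn.create_attention_mask`, assumed here to be a simple `!= pad_id` check:

    import torch

    pad_id = 0  # assumed pad token id

    def create_attention_mask(ids: torch.Tensor) -> torch.Tensor:
        # stand-in for tensor_fn.create_attention_mask: 1 for real tokens, 0 for pad
        return (ids != pad_id).long()

    # toy response in which the 9s are retrieved info tokens
    responses                = torch.tensor([[7, 8, 9, 9, 3, pad_id]])
    responses_with_info_mask = torch.tensor([[7, 8, pad_id, pad_id, 3, pad_id]])

    print(create_attention_mask(responses))                 # [[1, 1, 1, 1, 1, 0]]
    print(create_attention_mask(responses_with_info_mask))  # [[1, 1, 0, 0, 1, 0]]

The model still attends to the info tokens through `attention_mask`; `info_mask` only marks them for exclusion downstream.
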
@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
     response_length = responses.size(1)
     token_level_scores = data.batch['token_level_scores']
     batch_size = data.batch.batch_size[0]
-    attention_mask = data.batch['attention_mask']
+    attention_mask = data.batch['info_mask']
     response_mask = attention_mask[:, -response_length:]
 
     # compute kl between ref_policy and current policy
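
Note (illustrative, not the actual `apply_kl_penalty` body): switching the mask source to `info_mask` means the per-token KL penalty is accumulated only where the mask is 1, so retrieved-information tokens no longer contribute. A toy comparison, with an assumed KL coefficient:

    import torch

    kld = torch.tensor([[0.2, 0.1, 0.4, 0.3]])               # toy per-token KL over the response
    attention_mask_tail = torch.tensor([[1., 1., 1., 1.]])   # every response token counted
    info_mask_tail      = torch.tensor([[1., 1., 0., 0.]])   # info tokens masked out

    beta = 0.05  # assumed KL coefficient
    print((beta * kld * attention_mask_tail).sum(dim=1))  # tensor([0.0500])
    print((beta * kld * info_mask_tail).sum(dim=1))       # tensor([0.0150])
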
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
 def _compute_response_info(batch):
     response_length = batch.batch['responses'].shape[-1]
 
-    prompt_mask = batch.batch['attention_mask'][:, :-response_length]
-    response_mask = batch.batch['attention_mask'][:, -response_length:]
+    prompt_mask = batch.batch['info_mask'][:, :-response_length]
+    response_mask = batch.batch['info_mask'][:, -response_length:]
 
     prompt_length = prompt_mask.sum(-1).float()
     response_length = response_mask.sum(-1).float()  # (batch_size,)
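
Note (illustrative, not part of the diff): with `info_mask` in place of `attention_mask`, the reported prompt and response lengths count only genuine prompt tokens and model-generated tokens, not retrieved-information tokens. Values below are made up for the sketch:

    import torch

    response_length = 4
    # full-sequence info_mask: prompt part followed by the response part
    info_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0, 1]])  # last 4 columns are the response

    prompt_mask   = info_mask[:, :-response_length]
    response_mask = info_mask[:, -response_length:]

    print(prompt_mask.sum(-1).float())    # tensor([4.])
    print(response_mask.sum(-1).float())  # tensor([2.])  info tokens no longer counted
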
@@ -855,50 +855,8 @@ class RayPPOTrainer(object):
         response_length = batch.batch['responses'].shape[-1]
         response_mask = batch.batch['attention_mask'][:, -response_length:]
 
-        # Initialize state mask
-        state_mask = torch.ones_like(response_mask)
-
-        responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
-
-        for i, response in enumerate(responses):
-            # Find all pairs of start and end marker positions
-            start_marker = self.config.algorithm.state_masking.start_state_marker
-            end_marker = self.config.algorithm.state_masking.end_state_marker
-
-            # Get all start and end positions
-            start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
-            end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
-
-            # Convert character positions to token positions
-            for start, end in zip(start_positions, end_positions):
-                prefix_to_start = response[:start]
-                state_section = response[start:end]
-
-                start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
-                state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
-
-                start_token_pos = len(start_tokens)
-                end_token_pos = start_token_pos + len(state_tokens)
-
-                state_mask[i, start_token_pos:end_token_pos] = 0
-
-        loss_mask = state_mask * response_mask
+        loss_mask = batch.batch['info_mask'][:, -response_length:]
         batch.batch['loss_mask'] = loss_mask
-
-        # # Debug print
-        # print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
-        # response_ids = batch.batch['responses'][0]
-        # unmasked_ids = response_ids[loss_mask[0] == 0]
-        # print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
-
-        # masked_ids = response_ids[loss_mask[0] == 1]
-        # print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
-
-        # masked_ids = response_ids[response_mask[0] == 1]
-        # print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
-
-        # masked_ids = response_ids[response_mask[0] == 0]
-        # print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))
 
         metrics.update({
             'state_tokens/total': loss_mask.sum().item(),
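
Note (illustrative, not the trainer's actual loss code): the old path re-tokenized the decoded response to locate state markers, which can drift from the original token boundaries; the new path takes the mask directly from the token-level `info_mask`. A hedged sketch of how such a `loss_mask` is typically consumed, via a masked mean over response tokens:

    import torch

    per_token_loss = torch.tensor([[0.5, 0.3, 0.2, 0.4]])  # toy per-token policy loss
    loss_mask      = torch.tensor([[1.0, 1.0, 0.0, 1.0]])  # 0 marks retrieved-info tokens

    masked_loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
    print(masked_loss)  # tensor(0.4000): (0.5 + 0.3 + 0.4) / 3
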