Merge pull request #21 from xiaobo-yang/yxb/fix-info-mask-bugs
Fix bugs related to loss mask, meta info, and response length
This commit is contained in:
@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
|
||||
response_length = responses.size(1)
|
||||
token_level_scores = data.batch['token_level_scores']
|
||||
batch_size = data.batch.batch_size[0]
|
||||
attention_mask = data.batch['attention_mask']
|
||||
attention_mask = data.batch['info_mask']
|
||||
response_mask = attention_mask[:, -response_length:]
|
||||
|
||||
# compute kl between ref_policy and current policy
|
||||
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
|
||||
def _compute_response_info(batch):
|
||||
response_length = batch.batch['responses'].shape[-1]
|
||||
|
||||
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
|
||||
response_mask = batch.batch['attention_mask'][:, -response_length:]
|
||||
prompt_mask = batch.batch['info_mask'][:, :-response_length]
|
||||
response_mask = batch.batch['info_mask'][:, -response_length:]
|
||||
|
||||
prompt_length = prompt_mask.sum(-1).float()
|
||||
response_length = response_mask.sum(-1).float() # (batch_size,)
|
||||
@@ -855,50 +855,8 @@ class RayPPOTrainer(object):
|
||||
response_length = batch.batch['responses'].shape[-1]
|
||||
response_mask = batch.batch['attention_mask'][:, -response_length:]
|
||||
|
||||
# Initialize state mask
|
||||
state_mask = torch.ones_like(response_mask)
|
||||
|
||||
responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
|
||||
|
||||
for i, response in enumerate(responses):
|
||||
# Find all pairs of start and end marker positions
|
||||
start_marker = self.config.algorithm.state_masking.start_state_marker
|
||||
end_marker = self.config.algorithm.state_masking.end_state_marker
|
||||
|
||||
# Get all start and end positions
|
||||
start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
|
||||
end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
|
||||
|
||||
# Convert character positions to token positions
|
||||
for start, end in zip(start_positions, end_positions):
|
||||
prefix_to_start = response[:start]
|
||||
state_section = response[start:end]
|
||||
|
||||
start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
|
||||
state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
|
||||
|
||||
start_token_pos = len(start_tokens)
|
||||
end_token_pos = start_token_pos + len(state_tokens)
|
||||
|
||||
state_mask[i, start_token_pos:end_token_pos] = 0
|
||||
|
||||
loss_mask = state_mask * response_mask
|
||||
loss_mask = batch.batch['info_mask'][:, -response_length:]
|
||||
batch.batch['loss_mask'] = loss_mask
|
||||
|
||||
# # Debug print
|
||||
# print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
|
||||
# response_ids = batch.batch['responses'][0]
|
||||
# unmasked_ids = response_ids[loss_mask[0] == 0]
|
||||
# print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
|
||||
|
||||
# masked_ids = response_ids[loss_mask[0] == 1]
|
||||
# print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
|
||||
|
||||
# masked_ids = response_ids[response_mask[0] == 1]
|
||||
# print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
|
||||
|
||||
# masked_ids = response_ids[response_mask[0] == 0]
|
||||
# print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))
|
||||
|
||||
metrics.update({
|
||||
'state_tokens/total': loss_mask.sum().item(),
|
||||
|
||||
Reference in New Issue
Block a user