Merge pull request #21 from xiaobo-yang/yxb/fix-info-mask-bugs

Fix bugs related to loss mask, meta info, and response length
Bowen Jin authored 2025-03-18 19:33:50 -05:00 · committed by GitHub
2 changed files with 54 additions and 61 deletions
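
Note: the diff below swaps `attention_mask` for `info_mask` at several call sites, but the PR itself does not show how `info_mask` is constructed. A minimal sketch of the presumed idea, assuming it is an attention mask with externally injected tokens (e.g. retrieved passages) zeroed out; `build_info_mask` and `info_spans` are illustrative names, not from this repo:

```python
# Illustrative sketch only -- this PR does not show how `info_mask` is built.
# Presumed idea: start from attention_mask and zero the spans of externally
# injected tokens (e.g. retrieved passages) so they are still attended to,
# but excluded from KL, metrics, and the policy loss.
import torch

def build_info_mask(attention_mask, info_spans):
    """attention_mask: (B, L) 0/1 tensor; info_spans[b] is a list of
    (start, end) token-index ranges of injected information in sample b.
    Both names are hypothetical, not from this repo."""
    info_mask = attention_mask.clone()
    for b, spans in enumerate(info_spans):
        for start, end in spans:
            info_mask[b, start:end] = 0  # drop injected tokens from the mask
    return info_mask
```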

View File

@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
         response_length = responses.size(1)
         token_level_scores = data.batch['token_level_scores']
         batch_size = data.batch.batch_size[0]
-        attention_mask = data.batch['attention_mask']
+        attention_mask = data.batch['info_mask']
         response_mask = attention_mask[:, -response_length:]

         # compute kl between ref_policy and current policy
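
For context, a minimal sketch of the masked KL-penalty pattern this hunk feeds into, assuming per-token log-probs are already gathered; `beta` and the k1 estimator (`logprob - ref_logprob`) are assumptions, not necessarily what `core_algos` implements:

```python
# Minimal sketch of a masked token-level KL penalty; `beta` and the
# k1 estimator are assumptions, not necessarily what core_algos does.
import torch

def masked_kl_penalty(logprob, ref_logprob, response_mask, beta=0.001):
    kld = (logprob - ref_logprob) * response_mask  # per-token KL, non-response tokens zeroed
    penalty = -beta * kld                          # subtracted from token-level scores
    mean_kl = kld.sum(-1) / response_mask.sum(-1).clamp(min=1)  # per-sample metric
    return penalty, mean_kl
```

Slicing the mask with `[:, -response_length:]` means only response tokens can ever contribute, so switching the source mask to `info_mask` additionally excludes injected tokens inside the response.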
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
 def _compute_response_info(batch):
     response_length = batch.batch['responses'].shape[-1]
-    prompt_mask = batch.batch['attention_mask'][:, :-response_length]
-    response_mask = batch.batch['attention_mask'][:, -response_length:]
+    prompt_mask = batch.batch['info_mask'][:, :-response_length]
+    response_mask = batch.batch['info_mask'][:, -response_length:]
     prompt_length = prompt_mask.sum(-1).float()
     response_length = response_mask.sum(-1).float()  # (batch_size,)
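
A toy check of the slicing convention used here: responses occupy the last `response_length` columns, and everything before them is the prompt.

```python
# Toy check of the prompt/response mask slicing convention above.
import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0],
                               [0, 1, 1, 1, 0, 0]])
response_length = 3
prompt_mask = attention_mask[:, :-response_length]    # (B, L - T)
response_mask = attention_mask[:, -response_length:]  # (B, T)
print(prompt_mask.sum(-1).float())    # tensor([3., 2.])
print(response_mask.sum(-1).float())  # tensor([2., 1.])
```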
@@ -855,50 +855,8 @@ class RayPPOTrainer(object):
         response_length = batch.batch['responses'].shape[-1]
         response_mask = batch.batch['attention_mask'][:, -response_length:]

-        # Initialize state mask
-        state_mask = torch.ones_like(response_mask)
-        responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
-        for i, response in enumerate(responses):
-            # Find all pairs of start and end marker positions
-            start_marker = self.config.algorithm.state_masking.start_state_marker
-            end_marker = self.config.algorithm.state_masking.end_state_marker
-
-            # Get all start and end positions
-            start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
-            end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
-
-            # Convert character positions to token positions
-            for start, end in zip(start_positions, end_positions):
-                prefix_to_start = response[:start]
-                state_section = response[start:end]
-                start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
-                state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
-                start_token_pos = len(start_tokens)
-                end_token_pos = start_token_pos + len(state_tokens)
-                state_mask[i, start_token_pos:end_token_pos] = 0
-
-        loss_mask = state_mask * response_mask
+        loss_mask = batch.batch['info_mask'][:, -response_length:]
         batch.batch['loss_mask'] = loss_mask
-
-        # # Debug print
-        # print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
-        # response_ids = batch.batch['responses'][0]
-        # unmasked_ids = response_ids[loss_mask[0] == 0]
-        # print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
-        # masked_ids = response_ids[loss_mask[0] == 1]
-        # print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
-        # masked_ids = response_ids[response_mask[0] == 1]
-        # print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
-        # masked_ids = response_ids[response_mask[0] == 0]
-        # print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))

         metrics.update({
             'state_tokens/total': loss_mask.sum().item(),
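
The removed block converted character offsets to token offsets by re-encoding string prefixes, which silently assumes that tokenization composes across the marker boundary; the fix sidesteps this entirely by carrying `info_mask` through the batch as a tensor. A small probe of that failure mode, assuming a HuggingFace BPE tokenizer is available (`gpt2` and the markers are illustrative):

```python
# Probe of the pitfall in the removed block: it used len(encode(prefix)) as a
# token offset into the response ids, which assumes tokenization composes
# across the marker boundary. `gpt2` is illustrative; markers are made up.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
text = "Thought: <info> retrieved passage </info> answer"
ids = tok.encode(text, add_special_tokens=False)
start = text.index("<info>")
prefix_ids = tok.encode(text[:start], add_special_tokens=False)
# With a BPE tokenizer, the space before "<info>" can merge into the next
# token, so the prefix's token count need not be a valid offset:
print(ids[:len(prefix_ids)] == prefix_ids)  # often False
```

When that check fails, every span masked via `len(prefix_ids)` is shifted, which is consistent with the loss-mask issues the commit message describes; a mask built once at generation time, where the true token boundaries are known, avoids re-tokenizing altogether.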