Fix bugs related to loss mask, meta info, and response length

1. Construct the loss mask immediately after obtaining the observation to prevent encoding misalignment when converting back to tokens after text transformation.
2. Follow up on meta info to ensure that the test batch can apply do sample.
3. Remove the recording of info information for response length.
This commit is contained in:
xiaobo-yang
2025-03-14 13:44:42 +08:00
parent 118c6e7361
commit 32719b5119
2 changed files with 54 additions and 61 deletions

View File

@@ -93,7 +93,7 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
response_length = responses.size(1)
token_level_scores = data.batch['token_level_scores']
batch_size = data.batch.batch_size[0]
attention_mask = data.batch['attention_mask']
attention_mask = data.batch['info_mask']
response_mask = attention_mask[:, -response_length:]
# compute kl between ref_policy and current policy
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
def _compute_response_info(batch):
response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
response_mask = batch.batch['attention_mask'][:, -response_length:]
prompt_mask = batch.batch['info_mask'][:, :-response_length]
response_mask = batch.batch['info_mask'][:, -response_length:]
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float() # (batch_size,)
@@ -867,50 +867,8 @@ class RayPPOTrainer(object):
response_length = batch.batch['responses'].shape[-1]
response_mask = batch.batch['attention_mask'][:, -response_length:]
# Initialize state mask
state_mask = torch.ones_like(response_mask)
responses = [self.tokenizer.decode(resp, skip_special_tokens=False) for resp in batch.batch['responses']]
for i, response in enumerate(responses):
# Find all pairs of start and end marker positions
start_marker = self.config.algorithm.state_masking.start_state_marker
end_marker = self.config.algorithm.state_masking.end_state_marker
# Get all start and end positions
start_positions = [m.start() for m in re.finditer(re.escape(start_marker), response)]
end_positions = [m.start() + len(end_marker) for m in re.finditer(re.escape(end_marker), response)]
# Convert character positions to token positions
for start, end in zip(start_positions, end_positions):
prefix_to_start = response[:start]
state_section = response[start:end]
start_tokens = self.tokenizer.encode(prefix_to_start, add_special_tokens=False)
state_tokens = self.tokenizer.encode(state_section, add_special_tokens=False)
start_token_pos = len(start_tokens)
end_token_pos = start_token_pos + len(state_tokens)
state_mask[i, start_token_pos:end_token_pos] = 0
loss_mask = state_mask * response_mask
loss_mask = batch.batch['info_mask'][:, -response_length:]
batch.batch['loss_mask'] = loss_mask
# # Debug print
# print("\nRaw batch[0] (before masking):\n", self.tokenizer.decode(batch.batch['responses'][0]))
# response_ids = batch.batch['responses'][0]
# unmasked_ids = response_ids[loss_mask[0] == 0]
# print("\nMasked batch[0] (after masking):\n", self.tokenizer.decode(unmasked_ids))
# masked_ids = response_ids[loss_mask[0] == 1]
# print("\nUnmasked batch[0] (masked parts):\n", self.tokenizer.decode(masked_ids))
# masked_ids = response_ids[response_mask[0] == 1]
# print("\nresponse_mask[0] == 1:\n", self.tokenizer.decode(masked_ids))
# masked_ids = response_ids[response_mask[0] == 0]
# print("\nresponse_mask[0] == 0:\n", self.tokenizer.decode(masked_ids))
metrics.update({
'state_tokens/total': loss_mask.sum().item(),