response length include retrieval info

This commit is contained in:
PeterGriffinJin
2025-03-19 00:36:21 +00:00
parent 50cedb2c00
commit 8c7f04ca45

View File

@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
def _compute_response_info(batch):
response_length = batch.batch['responses'].shape[-1]
prompt_mask = batch.batch['info_mask'][:, :-response_length]
response_mask = batch.batch['info_mask'][:, -response_length:]
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
response_mask = batch.batch['attention_mask'][:, -response_length:]
prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float() # (batch_size,)