response length include retrieval info
This commit is contained in:
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
|
||||
def _compute_response_info(batch):
|
||||
response_length = batch.batch['responses'].shape[-1]
|
||||
|
||||
prompt_mask = batch.batch['info_mask'][:, :-response_length]
|
||||
response_mask = batch.batch['info_mask'][:, -response_length:]
|
||||
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
|
||||
response_mask = batch.batch['attention_mask'][:, -response_length:]
|
||||
|
||||
prompt_length = prompt_mask.sum(-1).float()
|
||||
response_length = response_mask.sum(-1).float() # (batch_size,)
|
||||
|
||||
Reference in New Issue
Block a user