response length include retrieval info
This commit is contained in:
@@ -163,8 +163,8 @@ def reduce_metrics(metrics: dict):
|
|||||||
def _compute_response_info(batch):
|
def _compute_response_info(batch):
|
||||||
response_length = batch.batch['responses'].shape[-1]
|
response_length = batch.batch['responses'].shape[-1]
|
||||||
|
|
||||||
prompt_mask = batch.batch['info_mask'][:, :-response_length]
|
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
|
||||||
response_mask = batch.batch['info_mask'][:, -response_length:]
|
response_mask = batch.batch['attention_mask'][:, -response_length:]
|
||||||
|
|
||||||
prompt_length = prompt_mask.sum(-1).float()
|
prompt_length = prompt_mask.sum(-1).float()
|
||||||
response_length = response_mask.sum(-1).float() # (batch_size,)
|
response_length = response_mask.sum(-1).float() # (batch_size,)
|
||||||
|
|||||||
Reference in New Issue
Block a user