Fix turns_stats logging bug: only log env/* metrics when their keys are present in batch.meta_info
This commit is contained in:
@@ -260,24 +260,22 @@ def compute_data_metrics(batch, use_critic=True):
|
||||
torch.min(prompt_length).detach().item(),
|
||||
'prompt_length/clip_ratio':
|
||||
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
|
||||
|
||||
# metrics for actions
|
||||
'env/number_of_actions/mean':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()),
|
||||
'env/number_of_actions/max':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()),
|
||||
'env/number_of_actions/min':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()),
|
||||
'env/finish_ratio':
|
||||
1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()),
|
||||
'env/number_of_valid_action':
|
||||
float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()),
|
||||
'env/ratio_of_valid_action':
|
||||
float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()),
|
||||
'env/number_of_valid_search':
|
||||
float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()),
|
||||
}
|
||||
|
||||
# metrics for actions
|
||||
if 'turns_stats' in batch.meta_info:
|
||||
metrics['env/number_of_actions/mean'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean())
|
||||
metrics['env/number_of_actions/max'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max())
|
||||
metrics['env/number_of_actions/min'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min())
|
||||
if 'active_mask' in batch.meta_info:
|
||||
metrics['env/finish_ratio'] = 1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean())
|
||||
if 'valid_action_stats' in batch.meta_info:
|
||||
metrics['env/number_of_valid_action'] = float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean())
|
||||
metrics['env/ratio_of_valid_action'] = float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean())
|
||||
if 'valid_search_stats' in batch.meta_info:
|
||||
metrics['env/number_of_valid_search'] = float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean())
|
||||
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user