Fix turns_stats logging bug: log each env/* metric only when its key is present in batch.meta_info

This commit is contained in:
PeterGriffinJin
2025-03-21 14:58:42 +00:00
parent 83d10313be
commit d874947732

View File

@@ -260,24 +260,22 @@ def compute_data_metrics(batch, use_critic=True):
torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
# metrics for actions
'env/number_of_actions/mean':
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()),
'env/number_of_actions/max':
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()),
'env/number_of_actions/min':
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()),
'env/finish_ratio':
1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()),
'env/number_of_valid_action':
float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()),
'env/ratio_of_valid_action':
float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()),
'env/number_of_valid_search':
float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()),
}
# metrics for actions
if 'turns_stats' in batch.meta_info:
metrics['env/number_of_actions/mean'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean())
metrics['env/number_of_actions/max'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max())
metrics['env/number_of_actions/min'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min())
if 'active_mask' in batch.meta_info:
metrics['env/finish_ratio'] = 1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean())
if 'valid_action_stats' in batch.meta_info:
metrics['env/number_of_valid_action'] = float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean())
metrics['env/ratio_of_valid_action'] = float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean())
if 'valid_search_stats' in batch.meta_info:
metrics['env/number_of_valid_search'] = float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean())
return metrics