Fix turns_stats logging bug: add each env/* metric only when its source key is present in batch.meta_info

This commit is contained in:
PeterGriffinJin
2025-03-21 14:58:42 +00:00
parent 83d10313be
commit d874947732

View File

@@ -260,23 +260,21 @@ def compute_data_metrics(batch, use_critic=True):
torch.min(prompt_length).detach().item(), | torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio': | 'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), | torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
 | }
# metrics for actions | # metrics for actions
'env/number_of_actions/mean': | if 'turns_stats' in batch.meta_info:
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()), | metrics['env/number_of_actions/mean'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean())
'env/number_of_actions/max': | metrics['env/number_of_actions/max'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max())
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()), | metrics['env/number_of_actions/min'] = float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min())
'env/number_of_actions/min': | if 'active_mask' in batch.meta_info:
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()), | metrics['env/finish_ratio'] = 1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean())
'env/finish_ratio': | if 'valid_action_stats' in batch.meta_info:
1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()), | metrics['env/number_of_valid_action'] = float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean())
'env/number_of_valid_action': | metrics['env/ratio_of_valid_action'] = float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean())
float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()), | if 'valid_search_stats' in batch.meta_info:
'env/ratio_of_valid_action': | metrics['env/number_of_valid_search'] = float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean())
float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()), |
'env/number_of_valid_search': |
float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()), |
} |
return metrics | return metrics