add action status

This commit is contained in:
PeterGriffinJin
2025-03-19 22:19:27 +00:00
parent 9ec2fa9892
commit 83d10313be
2 changed files with 39 additions and 16 deletions

View File

@@ -262,18 +262,20 @@ def compute_data_metrics(batch, use_critic=True):
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
# metrics for actions
# 'metric/total_env':
# int(np.array(batch.non_tensor_batch['total_env'], dtype=np.int16).sum()),
# 'metric/finished_env':
# int(np.array(batch.non_tensor_batch['finished_env'], dtype=np.int16).sum()),
# 'metric/traj_length':
# float(np.array(batch.non_tensor_batch['traj_length'], dtype=np.int16).mean()),
# 'metric/valid_action':
# float(np.array(batch.non_tensor_batch['valid_action'], dtype=np.int16).mean()),
# 'metric/effective_action':
# float(np.array(batch.non_tensor_batch['effective_action'], dtype=np.int16).mean()),
# 'metric/effective_action_ratio':
# float(np.array(batch.non_tensor_batch['effective_action_ratio'], dtype=np.float32).mean()),
'env/number_of_actions/mean':
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()),
'env/number_of_actions/max':
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()),
'env/number_of_actions/min':
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()),
'env/finish_ratio':
1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()),
'env/number_of_valid_action':
float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()),
'env/ratio_of_valid_action':
float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()),
'env/number_of_valid_search':
float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()),
}
return metrics