add action status
This commit is contained in:
@@ -262,18 +262,20 @@ def compute_data_metrics(batch, use_critic=True):
|
||||
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
|
||||
|
||||
# metrics for actions
|
||||
# 'metric/total_env':
|
||||
# int(np.array(batch.non_tensor_batch['total_env'], dtype=np.int16).sum()),
|
||||
# 'metric/finished_env':
|
||||
# int(np.array(batch.non_tensor_batch['finished_env'], dtype=np.int16).sum()),
|
||||
# 'metric/traj_length':
|
||||
# float(np.array(batch.non_tensor_batch['traj_length'], dtype=np.int16).mean()),
|
||||
# 'metric/valid_action':
|
||||
# float(np.array(batch.non_tensor_batch['valid_action'], dtype=np.int16).mean()),
|
||||
# 'metric/effective_action':
|
||||
# float(np.array(batch.non_tensor_batch['effective_action'], dtype=np.int16).mean()),
|
||||
# 'metric/effective_action_ratio':
|
||||
# float(np.array(batch.non_tensor_batch['effective_action_ratio'], dtype=np.float32).mean()),
|
||||
'env/number_of_actions/mean':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()),
|
||||
'env/number_of_actions/max':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()),
|
||||
'env/number_of_actions/min':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()),
|
||||
'env/finish_ratio':
|
||||
1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()),
|
||||
'env/number_of_valid_action':
|
||||
float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()),
|
||||
'env/ratio_of_valid_action':
|
||||
float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()),
|
||||
'env/number_of_valid_search':
|
||||
float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()),
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
Reference in New Issue
Block a user