add action status
This commit is contained in:
@@ -228,6 +228,9 @@ class LLMGenerationManager:
|
||||
original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
|
||||
|
||||
active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
|
||||
turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
|
||||
valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
|
||||
valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
|
||||
active_num_list = [active_mask.sum().item()]
|
||||
rollings = gen_batch
|
||||
|
||||
@@ -251,13 +254,16 @@ class LLMGenerationManager:
|
||||
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
||||
|
||||
# Execute in environment and process observations
|
||||
next_obs, dones = self.execute_predictions(
|
||||
next_obs, dones, valid_action, is_search = self.execute_predictions(
|
||||
responses_str, self.tokenizer.pad_token, active_mask
|
||||
)
|
||||
|
||||
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
||||
active_mask = active_mask * curr_active_mask
|
||||
active_num_list.append(active_mask.sum().item())
|
||||
turns_stats[curr_active_mask] += 1
|
||||
valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
|
||||
valid_search_stats += torch.tensor(is_search, dtype=torch.int)
|
||||
|
||||
next_obs_ids = self._process_next_obs(next_obs)
|
||||
|
||||
@@ -291,13 +297,20 @@ class LLMGenerationManager:
|
||||
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
||||
|
||||
# # Execute in environment and process observations
|
||||
_, dones = self.execute_predictions(
|
||||
_, dones, valid_action, is_search = self.execute_predictions(
|
||||
responses_str, self.tokenizer.pad_token, active_mask, do_search=False
|
||||
)
|
||||
|
||||
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
||||
active_mask = active_mask * curr_active_mask
|
||||
active_num_list.append(active_mask.sum().item())
|
||||
valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
|
||||
valid_search_stats += torch.tensor(is_search, dtype=torch.int)
|
||||
|
||||
meta_info['turns_stats'] = turns_stats.tolist()
|
||||
meta_info['active_mask'] = active_mask.tolist()
|
||||
meta_info['valid_action_stats'] = valid_action_stats.tolist()
|
||||
meta_info['valid_search_stats'] = valid_search_stats.tolist()
|
||||
|
||||
original_right_side = self._update_right_side(
|
||||
original_right_side,
|
||||
@@ -355,7 +368,7 @@ class LLMGenerationManager:
|
||||
List of observation strings
|
||||
"""
|
||||
cur_actions, contents = self.postprocess_predictions(predictions)
|
||||
next_obs, dones = [], []
|
||||
next_obs, dones, valid_action, is_search = [], [], [], []
|
||||
|
||||
search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
|
||||
if do_search:
|
||||
@@ -369,22 +382,30 @@ class LLMGenerationManager:
|
||||
if not active:
|
||||
next_obs.append('')
|
||||
dones.append(1)
|
||||
valid_action.append(0)
|
||||
is_search.append(0)
|
||||
else:
|
||||
if action == 'answer':
|
||||
next_obs.append('')
|
||||
dones.append(1)
|
||||
valid_action.append(1)
|
||||
is_search.append(0)
|
||||
elif action == 'search':
|
||||
next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
|
||||
dones.append(0)
|
||||
valid_action.append(1)
|
||||
is_search.append(1)
|
||||
else:
|
||||
next_obs.append(f'\nMy previous action is invalid. \
|
||||
If I want to search, I should put the query between <search> and </search>. \
|
||||
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
|
||||
dones.append(0)
|
||||
valid_action.append(0)
|
||||
is_search.append(0)
|
||||
|
||||
assert len(search_results) == 0
|
||||
|
||||
return next_obs, dones
|
||||
return next_obs, dones, valid_action, is_search
|
||||
|
||||
def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
|
||||
"""
|
||||
|
||||
@@ -262,18 +262,20 @@ def compute_data_metrics(batch, use_critic=True):
|
||||
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
|
||||
|
||||
# metrics for actions
|
||||
# 'metric/total_env':
|
||||
# int(np.array(batch.non_tensor_batch['total_env'], dtype=np.int16).sum()),
|
||||
# 'metric/finished_env':
|
||||
# int(np.array(batch.non_tensor_batch['finished_env'], dtype=np.int16).sum()),
|
||||
# 'metric/traj_length':
|
||||
# float(np.array(batch.non_tensor_batch['traj_length'], dtype=np.int16).mean()),
|
||||
# 'metric/valid_action':
|
||||
# float(np.array(batch.non_tensor_batch['valid_action'], dtype=np.int16).mean()),
|
||||
# 'metric/effective_action':
|
||||
# float(np.array(batch.non_tensor_batch['effective_action'], dtype=np.int16).mean()),
|
||||
# 'metric/effective_action_ratio':
|
||||
# float(np.array(batch.non_tensor_batch['effective_action_ratio'], dtype=np.float32).mean()),
|
||||
'env/number_of_actions/mean':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()),
|
||||
'env/number_of_actions/max':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()),
|
||||
'env/number_of_actions/min':
|
||||
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()),
|
||||
'env/finish_ratio':
|
||||
1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()),
|
||||
'env/number_of_valid_action':
|
||||
float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()),
|
||||
'env/ratio_of_valid_action':
|
||||
float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()),
|
||||
'env/number_of_valid_search':
|
||||
float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()),
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
Reference in New Issue
Block a user