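Add action status tracking: record per-rollout turn counts, valid-action counts, and valid-search counts, and report them as env/* metrics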
This commit is contained in:
@@ -228,6 +228,9 @@ class LLMGenerationManager:
|
|||||||
original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
|
original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
|
||||||
|
|
||||||
active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
|
active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
|
||||||
|
turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
|
||||||
|
valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
|
||||||
|
valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
|
||||||
active_num_list = [active_mask.sum().item()]
|
active_num_list = [active_mask.sum().item()]
|
||||||
rollings = gen_batch
|
rollings = gen_batch
|
||||||
|
|
||||||
@@ -251,13 +254,16 @@ class LLMGenerationManager:
|
|||||||
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
||||||
|
|
||||||
# Execute in environment and process observations
|
# Execute in environment and process observations
|
||||||
next_obs, dones = self.execute_predictions(
|
next_obs, dones, valid_action, is_search = self.execute_predictions(
|
||||||
responses_str, self.tokenizer.pad_token, active_mask
|
responses_str, self.tokenizer.pad_token, active_mask
|
||||||
)
|
)
|
||||||
|
|
||||||
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
||||||
active_mask = active_mask * curr_active_mask
|
active_mask = active_mask * curr_active_mask
|
||||||
active_num_list.append(active_mask.sum().item())
|
active_num_list.append(active_mask.sum().item())
|
||||||
|
turns_stats[curr_active_mask] += 1
|
||||||
|
valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
|
||||||
|
valid_search_stats += torch.tensor(is_search, dtype=torch.int)
|
||||||
|
|
||||||
next_obs_ids = self._process_next_obs(next_obs)
|
next_obs_ids = self._process_next_obs(next_obs)
|
||||||
|
|
||||||
@@ -291,13 +297,20 @@ class LLMGenerationManager:
|
|||||||
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
|
||||||
|
|
||||||
# # Execute in environment and process observations
|
# # Execute in environment and process observations
|
||||||
_, dones = self.execute_predictions(
|
_, dones, valid_action, is_search = self.execute_predictions(
|
||||||
responses_str, self.tokenizer.pad_token, active_mask, do_search=False
|
responses_str, self.tokenizer.pad_token, active_mask, do_search=False
|
||||||
)
|
)
|
||||||
|
|
||||||
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
|
||||||
active_mask = active_mask * curr_active_mask
|
active_mask = active_mask * curr_active_mask
|
||||||
active_num_list.append(active_mask.sum().item())
|
active_num_list.append(active_mask.sum().item())
|
||||||
|
valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
|
||||||
|
valid_search_stats += torch.tensor(is_search, dtype=torch.int)
|
||||||
|
|
||||||
|
meta_info['turns_stats'] = turns_stats.tolist()
|
||||||
|
meta_info['active_mask'] = active_mask.tolist()
|
||||||
|
meta_info['valid_action_stats'] = valid_action_stats.tolist()
|
||||||
|
meta_info['valid_search_stats'] = valid_search_stats.tolist()
|
||||||
|
|
||||||
original_right_side = self._update_right_side(
|
original_right_side = self._update_right_side(
|
||||||
original_right_side,
|
original_right_side,
|
||||||
@@ -355,7 +368,7 @@ class LLMGenerationManager:
|
|||||||
List of observation strings
|
List of observation strings
|
||||||
"""
|
"""
|
||||||
cur_actions, contents = self.postprocess_predictions(predictions)
|
cur_actions, contents = self.postprocess_predictions(predictions)
|
||||||
next_obs, dones = [], []
|
next_obs, dones, valid_action, is_search = [], [], [], []
|
||||||
|
|
||||||
search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
|
search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
|
||||||
if do_search:
|
if do_search:
|
||||||
@@ -369,22 +382,30 @@ class LLMGenerationManager:
|
|||||||
if not active:
|
if not active:
|
||||||
next_obs.append('')
|
next_obs.append('')
|
||||||
dones.append(1)
|
dones.append(1)
|
||||||
|
valid_action.append(0)
|
||||||
|
is_search.append(0)
|
||||||
else:
|
else:
|
||||||
if action == 'answer':
|
if action == 'answer':
|
||||||
next_obs.append('')
|
next_obs.append('')
|
||||||
dones.append(1)
|
dones.append(1)
|
||||||
|
valid_action.append(1)
|
||||||
|
is_search.append(0)
|
||||||
elif action == 'search':
|
elif action == 'search':
|
||||||
next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
|
next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
|
||||||
dones.append(0)
|
dones.append(0)
|
||||||
|
valid_action.append(1)
|
||||||
|
is_search.append(1)
|
||||||
else:
|
else:
|
||||||
next_obs.append(f'\nMy previous action is invalid. \
|
next_obs.append(f'\nMy previous action is invalid. \
|
||||||
If I want to search, I should put the query between <search> and </search>. \
|
If I want to search, I should put the query between <search> and </search>. \
|
||||||
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
|
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
|
||||||
dones.append(0)
|
dones.append(0)
|
||||||
|
valid_action.append(0)
|
||||||
|
is_search.append(0)
|
||||||
|
|
||||||
assert len(search_results) == 0
|
assert len(search_results) == 0
|
||||||
|
|
||||||
return next_obs, dones
|
return next_obs, dones, valid_action, is_search
|
||||||
|
|
||||||
def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
|
def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -262,18 +262,20 @@ def compute_data_metrics(batch, use_critic=True):
|
|||||||
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
|
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
|
||||||
|
|
||||||
# metrics for actions
|
# metrics for actions
|
||||||
# 'metric/total_env':
|
'env/number_of_actions/mean':
|
||||||
# int(np.array(batch.non_tensor_batch['total_env'], dtype=np.int16).sum()),
|
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()),
|
||||||
# 'metric/finished_env':
|
'env/number_of_actions/max':
|
||||||
# int(np.array(batch.non_tensor_batch['finished_env'], dtype=np.int16).sum()),
|
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()),
|
||||||
# 'metric/traj_length':
|
'env/number_of_actions/min':
|
||||||
# float(np.array(batch.non_tensor_batch['traj_length'], dtype=np.int16).mean()),
|
float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()),
|
||||||
# 'metric/valid_action':
|
'env/finish_ratio':
|
||||||
# float(np.array(batch.non_tensor_batch['valid_action'], dtype=np.int16).mean()),
|
1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()),
|
||||||
# 'metric/effective_action':
|
'env/number_of_valid_action':
|
||||||
# float(np.array(batch.non_tensor_batch['effective_action'], dtype=np.int16).mean()),
|
float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()),
|
||||||
# 'metric/effective_action_ratio':
|
'env/ratio_of_valid_action':
|
||||||
# float(np.array(batch.non_tensor_batch['effective_action_ratio'], dtype=np.float32).mean()),
|
float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()),
|
||||||
|
'env/number_of_valid_search':
|
||||||
|
float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()),
|
||||||
}
|
}
|
||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|||||||
Reference in New Issue
Block a user