add action status

Author: PeterGriffinJin
Date: 2025-03-19 22:19:27 +00:00
parent 9ec2fa9892
commit 83d10313be
2 changed files with 39 additions and 16 deletions


@@ -228,6 +228,9 @@ class LLMGenerationManager:
original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
+valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
+valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
active_num_list = [active_mask.sum().item()]
rollings = gen_batch
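
For orientation, this hunk sizes all of the per-rollout bookkeeping tensors to the batch dimension, so each statistic stays aligned with one trajectory. A minimal sketch of that pattern, using a stand-in batch_size in place of gen_batch.batch['input_ids'].shape[0]:

    import torch

    batch_size = 4  # stand-in for gen_batch.batch['input_ids'].shape[0]

    active_mask = torch.ones(batch_size, dtype=torch.bool)         # rollout still generating?
    turns_stats = torch.ones(batch_size, dtype=torch.int)          # turns taken per rollout
    valid_action_stats = torch.zeros(batch_size, dtype=torch.int)  # well-formed actions per rollout
    valid_search_stats = torch.zeros(batch_size, dtype=torch.int)  # search calls per rollout
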
@@ -251,13 +254,16 @@ class LLMGenerationManager:
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
# Execute in environment and process observations
-next_obs, dones = self.execute_predictions(
+next_obs, dones, valid_action, is_search = self.execute_predictions(
responses_str, self.tokenizer.pad_token, active_mask
)
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
active_mask = active_mask * curr_active_mask
active_num_list.append(active_mask.sum().item())
turns_stats[curr_active_mask] += 1
+valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
+valid_search_stats += torch.tensor(is_search, dtype=torch.int)
next_obs_ids = self._process_next_obs(next_obs)
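
The per-turn update works on parallel Python lists returned by execute_predictions and accumulates them into the integer tensors. A small self-contained sketch with made-up flag values (in the diff they come from the environment step above):

    import torch

    # Hypothetical per-example flags for one turn of a batch of 4 rollouts.
    dones = [0, 1, 0, 1]
    valid_action = [1, 1, 0, 1]
    is_search = [1, 0, 0, 0]

    active_mask = torch.tensor([True, True, True, False])
    turns_stats = torch.ones(4, dtype=torch.int)
    valid_action_stats = torch.zeros(4, dtype=torch.int)
    valid_search_stats = torch.zeros(4, dtype=torch.int)

    curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
    active_mask = active_mask * curr_active_mask    # finished rollouts drop out
    turns_stats[curr_active_mask] += 1              # only still-active rollouts gain a turn
    valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
    valid_search_stats += torch.tensor(is_search, dtype=torch.int)

    print(active_mask.tolist())         # [True, False, True, False]
    print(valid_action_stats.tolist())  # [1, 1, 0, 1]
    print(valid_search_stats.tolist())  # [1, 0, 0, 0]

Keeping valid_action and is_search as separate counters lets the trainer later report format compliance (ratio_of_valid_action) and retrieval usage (number_of_valid_search) as distinct signals.
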
@@ -291,13 +297,20 @@ class LLMGenerationManager:
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
# # Execute in environment and process observations
-_, dones = self.execute_predictions(
+_, dones, valid_action, is_search = self.execute_predictions(
responses_str, self.tokenizer.pad_token, active_mask, do_search=False
)
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
active_mask = active_mask * curr_active_mask
active_num_list.append(active_mask.sum().item())
+valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
+valid_search_stats += torch.tensor(is_search, dtype=torch.int)
meta_info['turns_stats'] = turns_stats.tolist()
meta_info['active_mask'] = active_mask.tolist()
+meta_info['valid_action_stats'] = valid_action_stats.tolist()
+meta_info['valid_search_stats'] = valid_search_stats.tolist()
original_right_side = self._update_right_side(
original_right_side,
@@ -355,7 +368,7 @@ class LLMGenerationManager:
List of observation strings
"""
cur_actions, contents = self.postprocess_predictions(predictions)
-next_obs, dones = [], []
+next_obs, dones, valid_action, is_search = [], [], [], []
search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
if do_search:
@@ -369,22 +382,30 @@ class LLMGenerationManager:
if not active:
next_obs.append('')
dones.append(1)
+valid_action.append(0)
+is_search.append(0)
else:
if action == 'answer':
next_obs.append('')
dones.append(1)
+valid_action.append(1)
+is_search.append(0)
elif action == 'search':
next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
dones.append(0)
+valid_action.append(1)
+is_search.append(1)
else:
next_obs.append(f'\nMy previous action is invalid. \
If I want to search, I should put the query between <search> and </search>. \
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
dones.append(0)
+valid_action.append(0)
+is_search.append(0)
assert len(search_results) == 0
-return next_obs, dones
+return next_obs, dones, valid_action, is_search
def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
"""


@@ -262,18 +262,20 @@ def compute_data_metrics(batch, use_critic=True):
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
# metrics for actions
-# 'metric/total_env':
-# int(np.array(batch.non_tensor_batch['total_env'], dtype=np.int16).sum()),
-# 'metric/finished_env':
-# int(np.array(batch.non_tensor_batch['finished_env'], dtype=np.int16).sum()),
-# 'metric/traj_length':
-# float(np.array(batch.non_tensor_batch['traj_length'], dtype=np.int16).mean()),
-# 'metric/valid_action':
-# float(np.array(batch.non_tensor_batch['valid_action'], dtype=np.int16).mean()),
-# 'metric/effective_action':
-# float(np.array(batch.non_tensor_batch['effective_action'], dtype=np.int16).mean()),
-# 'metric/effective_action_ratio':
-# float(np.array(batch.non_tensor_batch['effective_action_ratio'], dtype=np.float32).mean()),
+'env/number_of_actions/mean':
+float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).mean()),
+'env/number_of_actions/max':
+float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).max()),
+'env/number_of_actions/min':
+float(np.array(batch.meta_info['turns_stats'], dtype=np.int16).min()),
+'env/finish_ratio':
+1 - float(np.array(batch.meta_info['active_mask'], dtype=np.int16).mean()),
+'env/number_of_valid_action':
+float(np.array(batch.meta_info['valid_action_stats'], dtype=np.int16).mean()),
+'env/ratio_of_valid_action':
+float((np.array(batch.meta_info['valid_action_stats'], dtype=np.int16) / np.array(batch.meta_info['turns_stats'], dtype=np.int16)).mean()),
+'env/number_of_valid_search':
+float(np.array(batch.meta_info['valid_search_stats'], dtype=np.int16).mean()),
}
return metrics
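
Downstream, these metrics are plain numpy reductions over the per-rollout lists stored in meta_info. A worked example with illustrative numbers (four rollouts), using the keys added above:

    import numpy as np

    meta_info = {
        'turns_stats': [2, 1, 2, 1],
        'active_mask': [True, False, True, False],
        'valid_action_stats': [1, 1, 0, 1],
        'valid_search_stats': [1, 0, 0, 0],
    }

    turns = np.array(meta_info['turns_stats'], dtype=np.int16)
    valid_actions = np.array(meta_info['valid_action_stats'], dtype=np.int16)

    metrics = {
        'env/number_of_actions/mean': float(turns.mean()),                                          # 1.5
        'env/finish_ratio': 1 - float(np.array(meta_info['active_mask'], dtype=np.int16).mean()),   # 0.5
        'env/number_of_valid_action': float(valid_actions.mean()),                                  # 0.75
        'env/ratio_of_valid_action': float((valid_actions / turns).mean()),                         # 0.625
        'env/number_of_valid_search': float(np.array(meta_info['valid_search_stats'], dtype=np.int16).mean()),  # 0.25
    }

The ratio is computed per rollout (valid actions divided by turns taken) before averaging, so long trajectories do not dominate the format-compliance signal.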