add action status

This commit is contained in:
PeterGriffinJin
2025-03-19 22:19:27 +00:00
parent 9ec2fa9892
commit 83d10313be
2 changed files with 39 additions and 16 deletions

View File

@@ -228,6 +228,9 @@ class LLMGenerationManager:
original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
active_num_list = [active_mask.sum().item()]
rollings = gen_batch
@@ -251,13 +254,16 @@ class LLMGenerationManager:
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
# Execute in environment and process observations
next_obs, dones = self.execute_predictions(
next_obs, dones, valid_action, is_search = self.execute_predictions(
responses_str, self.tokenizer.pad_token, active_mask
)
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
active_mask = active_mask * curr_active_mask
active_num_list.append(active_mask.sum().item())
turns_stats[curr_active_mask] += 1
valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
valid_search_stats += torch.tensor(is_search, dtype=torch.int)
next_obs_ids = self._process_next_obs(next_obs)
@@ -291,13 +297,20 @@ class LLMGenerationManager:
responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
# # Execute in environment and process observations
_, dones = self.execute_predictions(
_, dones, valid_action, is_search = self.execute_predictions(
responses_str, self.tokenizer.pad_token, active_mask, do_search=False
)
curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
active_mask = active_mask * curr_active_mask
active_num_list.append(active_mask.sum().item())
valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
valid_search_stats += torch.tensor(is_search, dtype=torch.int)
meta_info['turns_stats'] = turns_stats.tolist()
meta_info['active_mask'] = active_mask.tolist()
meta_info['valid_action_stats'] = valid_action_stats.tolist()
meta_info['valid_search_stats'] = valid_search_stats.tolist()
original_right_side = self._update_right_side(
original_right_side,
@@ -355,7 +368,7 @@ class LLMGenerationManager:
List of observation strings
"""
cur_actions, contents = self.postprocess_predictions(predictions)
next_obs, dones = [], []
next_obs, dones, valid_action, is_search = [], [], [], []
search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
if do_search:
@@ -369,22 +382,30 @@ class LLMGenerationManager:
if not active:
next_obs.append('')
dones.append(1)
valid_action.append(0)
is_search.append(0)
else:
if action == 'answer':
next_obs.append('')
dones.append(1)
valid_action.append(1)
is_search.append(0)
elif action == 'search':
next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
dones.append(0)
valid_action.append(1)
is_search.append(1)
else:
next_obs.append(f'\nMy previous action is invalid. \
If I want to search, I should put the query between <search> and </search>. \
If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
dones.append(0)
valid_action.append(0)
is_search.append(0)
assert len(search_results) == 0
return next_obs, dones
return next_obs, dones, valid_action, is_search
def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
"""