add action status
@@ -228,6 +228,9 @@ class LLMGenerationManager:
         original_right_side = {'responses': initial_input_ids[:, []], 'responses_with_info_mask': initial_input_ids[:, []]}
 
         active_mask = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.bool)
+        turns_stats = torch.ones(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
+        valid_action_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
+        valid_search_stats = torch.zeros(gen_batch.batch['input_ids'].shape[0], dtype=torch.int)
         active_num_list = [active_mask.sum().item()]
         rollings = gen_batch
 
@@ -251,13 +254,16 @@ class LLMGenerationManager:
             responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
 
             # Execute in environment and process observations
-            next_obs, dones = self.execute_predictions(
+            next_obs, dones, valid_action, is_search = self.execute_predictions(
                 responses_str, self.tokenizer.pad_token, active_mask
             )
 
             curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
             active_mask = active_mask * curr_active_mask
             active_num_list.append(active_mask.sum().item())
+            turns_stats[curr_active_mask] += 1
+            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
+            valid_search_stats += torch.tensor(is_search, dtype=torch.int)
 
             next_obs_ids = self._process_next_obs(next_obs)
 
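
For intuition, a minimal self-contained sketch (not part of this commit; batch size and per-turn values are invented) of how the counters initialized above accumulate across one turn of the loop:

import torch

batch_size = 3
active_mask = torch.ones(batch_size, dtype=torch.bool)
turns_stats = torch.ones(batch_size, dtype=torch.int)
valid_action_stats = torch.zeros(batch_size, dtype=torch.int)
valid_search_stats = torch.zeros(batch_size, dtype=torch.int)

# One simulated turn: sample 0 answers (done), samples 1 and 2 issue valid searches.
dones = [1, 0, 0]
valid_action = [1, 1, 1]
is_search = [0, 1, 1]

curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
active_mask = active_mask * curr_active_mask           # sample 0 drops out of the rollout
turns_stats[curr_active_mask] += 1                     # only still-active samples get another turn
valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
valid_search_stats += torch.tensor(is_search, dtype=torch.int)

print(active_mask.tolist())          # [False, True, True]
print(turns_stats.tolist())          # [1, 2, 2]
print(valid_action_stats.tolist())   # [1, 1, 1]
print(valid_search_stats.tolist())   # [0, 1, 1]
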
@@ -291,13 +297,20 @@ class LLMGenerationManager:
             responses_ids, responses_str = self.tensor_fn._example_level_pad(responses_ids, responses_str, active_mask)
 
             # # Execute in environment and process observations
-            _, dones = self.execute_predictions(
+            _, dones, valid_action, is_search = self.execute_predictions(
                 responses_str, self.tokenizer.pad_token, active_mask, do_search=False
             )
 
             curr_active_mask = torch.tensor([not done for done in dones], dtype=torch.bool)
             active_mask = active_mask * curr_active_mask
             active_num_list.append(active_mask.sum().item())
+            valid_action_stats += torch.tensor(valid_action, dtype=torch.int)
+            valid_search_stats += torch.tensor(is_search, dtype=torch.int)
+
+            meta_info['turns_stats'] = turns_stats.tolist()
+            meta_info['active_mask'] = active_mask.tolist()
+            meta_info['valid_action_stats'] = valid_action_stats.tolist()
+            meta_info['valid_search_stats'] = valid_search_stats.tolist()
 
             original_right_side = self._update_right_side(
                 original_right_side,
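
The four lists stored in meta_info above are per-sample counters over the whole rollout. A hypothetical downstream consumer (the function name and metric keys are illustrative, not from this repository) could reduce them to scalar metrics for logging:

import numpy as np

def summarize_action_stats(meta_info: dict) -> dict:
    """Collapse the per-sample rollout statistics into scalar metrics."""
    return {
        'rollout/turns_mean': float(np.mean(meta_info['turns_stats'])),
        'rollout/valid_actions_mean': float(np.mean(meta_info['valid_action_stats'])),
        'rollout/searches_mean': float(np.mean(meta_info['valid_search_stats'])),
        'rollout/finished_ratio': 1.0 - float(np.mean(meta_info['active_mask'])),
    }

# Example: metrics = summarize_action_stats(meta_info), then hand the dict to a logger.
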
@@ -355,7 +368,7 @@ class LLMGenerationManager:
             List of observation strings
         """
         cur_actions, contents = self.postprocess_predictions(predictions)
-        next_obs, dones = [], []
+        next_obs, dones, valid_action, is_search = [], [], [], []
 
         search_queries = [content for action, content in zip(cur_actions, contents) if action == 'search']
         if do_search:
@@ -369,22 +382,30 @@ class LLMGenerationManager:
             if not active:
                 next_obs.append('')
                 dones.append(1)
+                valid_action.append(0)
+                is_search.append(0)
             else:
                 if action == 'answer':
                     next_obs.append('')
                     dones.append(1)
+                    valid_action.append(1)
+                    is_search.append(0)
                 elif action == 'search':
                     next_obs.append(f'\n\n<information>{search_results.pop(0).strip()}</information>\n\n')
                     dones.append(0)
+                    valid_action.append(1)
+                    is_search.append(1)
                 else:
                     next_obs.append(f'\nMy previous action is invalid. \
 If I want to search, I should put the query between <search> and </search>. \
 If I want to give the final answer, I should put the answer between <answer> and </answer>. Let me try again.\n')
                     dones.append(0)
+                    valid_action.append(0)
+                    is_search.append(0)
 
         assert len(search_results) == 0
 
-        return next_obs, dones
+        return next_obs, dones, valid_action, is_search
 
     def postprocess_predictions(self, predictions: List[Any]) -> Tuple[List[int], List[bool]]:
         """