diff --git a/mm_agents/gta1_agent.py b/mm_agents/gta1_agent.py index 4d621fc..aa34746 100644 --- a/mm_agents/gta1_agent.py +++ b/mm_agents/gta1_agent.py @@ -1034,7 +1034,7 @@ class GTA1Agent: assert len(valid_responses) > int(self.N_SEQ) * 0.5, f"Not enough valid responses generated {len(valid_responses)}" if self.N_SEQ > 1: - history_cache = [f"Observation:\n{o}\nThought:\n{t}\nAction:\n{a}" for a,t,o in zip(self.actions, self.thoughts, self.observations)] + history_cache = [f"Observation:\n{o}\nThought:\n{t}\nAction:\n{a}" for a,t,o in zip(self.actions, self.thoughts, self.observation_captions)] planner_response = self.select(instruction, Image.open(BytesIO(obs['screenshot'])), valid_responses, history_cache) else: planner_response = valid_responses[0]