diff --git a/mm_agents/jedi_3b_agent.py b/mm_agents/jedi_3b_agent.py index 3b456e2..f0b28c9 100644 --- a/mm_agents/jedi_3b_agent.py +++ b/mm_agents/jedi_3b_agent.py @@ -146,16 +146,26 @@ class JediAgent3B: ], }) - planner_response = self.call_llm( - { - "model": self.planner_model, - "messages": messages, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "temperature": self.temperature, - }, - self.planner_model, - ) + if self.planner_model == "o3": + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_completion_tokens": self.max_tokens + }, + self.planner_model, + ) + else: + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature, + }, + self.planner_model, + ) logger.info(f"Planner Output: {planner_response}") codes = self.parse_code_from_planner_response(planner_response) @@ -170,16 +180,26 @@ class JediAgent3B: {"type": "text", "text": "You didn't generate valid actions. Please try again."} ] }) - planner_response = self.call_llm( - { - "model": self.planner_model, - "messages": messages, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "temperature": self.temperature, - }, - self.planner_model, - ) + if self.planner_model == "o3": + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_completion_tokens": self.max_tokens + }, + self.planner_model, + ) + else: + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature, + }, + self.planner_model, + ) logger.info(f"Retry Planner Output: {planner_response}") codes = self.parse_code_from_planner_response(planner_response) retry_count += 1 @@ -204,7 +224,7 @@ class JediAgent3B: self.thoughts.append(thought) self.observation_captions.append(observation_caption) self.current_step += 1 - return planner_response, pyautogui_actions, {} + return planner_response, pyautogui_actions def parse_observation_caption_from_planner_response(self, input_string: str) -> str: pattern = r"Observation:\n(.*?)\n" @@ -394,7 +414,7 @@ class JediAgent3B: max_tries=10, ) def call_llm(self, payload, model): - if model.startswith("gpt"): + if model.startswith("gpt") or model.startswith("o3"): headers = { "Content-Type": "application/json", "Authorization": f"Bearer {OPENAI_API_KEY}", diff --git a/mm_agents/jedi_7b_agent.py b/mm_agents/jedi_7b_agent.py index ba4bed9..8894f78 100644 --- a/mm_agents/jedi_7b_agent.py +++ b/mm_agents/jedi_7b_agent.py @@ -144,17 +144,27 @@ class JediAgent7B: }, ], }) - - planner_response = self.call_llm( - { - "model": self.planner_model, - "messages": messages, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "temperature": self.temperature, - }, - self.planner_model, - ) + + if self.planner_model == "o3": + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_completion_tokens": self.max_tokens + }, + self.planner_model, + ) + else: + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature, + }, + self.planner_model, + ) logger.info(f"Planner Output: {planner_response}") codes = self.parse_code_from_planner_response(planner_response) @@ -169,16 +179,26 @@ class JediAgent7B: {"type": "text", "text": "You didn't generate valid actions. Please try again."} ] }) - planner_response = self.call_llm( - { - "model": self.planner_model, - "messages": messages, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "temperature": self.temperature, - }, - self.planner_model, - ) + if self.planner_model == "o3": + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_completion_tokens": self.max_tokens + }, + self.planner_model, + ) + else: + planner_response = self.call_llm( + { + "model": self.planner_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature, + }, + self.planner_model, + ) logger.info(f"Retry Planner Output: {planner_response}") codes = self.parse_code_from_planner_response(planner_response) retry_count += 1 @@ -203,7 +223,7 @@ class JediAgent7B: self.thoughts.append(thought) self.observation_captions.append(observation_caption) self.current_step += 1 - return planner_response, pyautogui_actions, {} + return planner_response, pyautogui_actions def parse_observation_caption_from_planner_response(self, input_string: str) -> str: pattern = r"Observation:\n(.*?)\n" @@ -379,7 +399,7 @@ class JediAgent7B: max_tries=10, ) def call_llm(self, payload, model): - if model.startswith("gpt"): + if model.startswith("gpt") or model.startswith("o3"): headers = { "Content-Type": "application/json", "Authorization": f"Bearer {OPENAI_API_KEY}"