feat: update jedi agent with support for o3 as planner

This commit is contained in:
MillanK0817
2025-07-30 14:06:37 +08:00
parent 99fa3b7cb9
commit 4ae9d41da4
2 changed files with 85 additions and 45 deletions

View File

@@ -146,16 +146,26 @@ class JediAgent3B:
], ],
}) })
planner_response = self.call_llm( if self.planner_model == "o3":
{ planner_response = self.call_llm(
"model": self.planner_model, {
"messages": messages, "model": self.planner_model,
"max_tokens": self.max_tokens, "messages": messages,
"top_p": self.top_p, "max_completion_tokens": self.max_tokens
"temperature": self.temperature, },
}, self.planner_model,
self.planner_model, )
) else:
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Planner Output: {planner_response}") logger.info(f"Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response) codes = self.parse_code_from_planner_response(planner_response)
@@ -170,16 +180,26 @@ class JediAgent3B:
{"type": "text", "text": "You didn't generate valid actions. Please try again."} {"type": "text", "text": "You didn't generate valid actions. Please try again."}
] ]
}) })
planner_response = self.call_llm( if self.planner_model == "o3":
{ planner_response = self.call_llm(
"model": self.planner_model, {
"messages": messages, "model": self.planner_model,
"max_tokens": self.max_tokens, "messages": messages,
"top_p": self.top_p, "max_completion_tokens": self.max_tokens
"temperature": self.temperature, },
}, self.planner_model,
self.planner_model, )
) else:
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Retry Planner Output: {planner_response}") logger.info(f"Retry Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response) codes = self.parse_code_from_planner_response(planner_response)
retry_count += 1 retry_count += 1
@@ -204,7 +224,7 @@ class JediAgent3B:
self.thoughts.append(thought) self.thoughts.append(thought)
self.observation_captions.append(observation_caption) self.observation_captions.append(observation_caption)
self.current_step += 1 self.current_step += 1
return planner_response, pyautogui_actions, {} return planner_response, pyautogui_actions
def parse_observation_caption_from_planner_response(self, input_string: str) -> str: def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
pattern = r"Observation:\n(.*?)\n" pattern = r"Observation:\n(.*?)\n"
@@ -394,7 +414,7 @@ class JediAgent3B:
max_tries=10, max_tries=10,
) )
def call_llm(self, payload, model): def call_llm(self, payload, model):
if model.startswith("gpt"): if model.startswith("gpt") or model.startswith("o3"):
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}", "Authorization": f"Bearer {OPENAI_API_KEY}",

View File

@@ -144,17 +144,27 @@ class JediAgent7B:
}, },
], ],
}) })
planner_response = self.call_llm( if self.planner_model == "o3":
{ planner_response = self.call_llm(
"model": self.planner_model, {
"messages": messages, "model": self.planner_model,
"max_tokens": self.max_tokens, "messages": messages,
"top_p": self.top_p, "max_completion_tokens": self.max_tokens
"temperature": self.temperature, },
}, self.planner_model,
self.planner_model, )
) else:
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Planner Output: {planner_response}") logger.info(f"Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response) codes = self.parse_code_from_planner_response(planner_response)
@@ -169,16 +179,26 @@ class JediAgent7B:
{"type": "text", "text": "You didn't generate valid actions. Please try again."} {"type": "text", "text": "You didn't generate valid actions. Please try again."}
] ]
}) })
planner_response = self.call_llm( if self.planner_model == "o3":
{ planner_response = self.call_llm(
"model": self.planner_model, {
"messages": messages, "model": self.planner_model,
"max_tokens": self.max_tokens, "messages": messages,
"top_p": self.top_p, "max_completion_tokens": self.max_tokens
"temperature": self.temperature, },
}, self.planner_model,
self.planner_model, )
) else:
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature,
},
self.planner_model,
)
logger.info(f"Retry Planner Output: {planner_response}") logger.info(f"Retry Planner Output: {planner_response}")
codes = self.parse_code_from_planner_response(planner_response) codes = self.parse_code_from_planner_response(planner_response)
retry_count += 1 retry_count += 1
@@ -203,7 +223,7 @@ class JediAgent7B:
self.thoughts.append(thought) self.thoughts.append(thought)
self.observation_captions.append(observation_caption) self.observation_captions.append(observation_caption)
self.current_step += 1 self.current_step += 1
return planner_response, pyautogui_actions, {} return planner_response, pyautogui_actions
def parse_observation_caption_from_planner_response(self, input_string: str) -> str: def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
pattern = r"Observation:\n(.*?)\n" pattern = r"Observation:\n(.*?)\n"
@@ -379,7 +399,7 @@ class JediAgent7B:
max_tries=10, max_tries=10,
) )
def call_llm(self, payload, model): def call_llm(self, payload, model):
if model.startswith("gpt"): if model.startswith("gpt") or model.startswith("o3"):
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}" "Authorization": f"Bearer {OPENAI_API_KEY}"