feat: add support for o3 as the planner model in the Jedi agents (3B and 7B)

This commit is contained in:
MillanK0817
2025-07-30 14:06:37 +08:00
parent 99fa3b7cb9
commit 4ae9d41da4
2 changed files with 85 additions and 45 deletions

View File

@@ -146,6 +146,16 @@ class JediAgent3B:
], ],
}) })
if self.planner_model == "o3":
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_completion_tokens": self.max_tokens
},
self.planner_model,
)
else:
planner_response = self.call_llm( planner_response = self.call_llm(
{ {
"model": self.planner_model, "model": self.planner_model,
@@ -170,6 +180,16 @@ class JediAgent3B:
{"type": "text", "text": "You didn't generate valid actions. Please try again."} {"type": "text", "text": "You didn't generate valid actions. Please try again."}
] ]
}) })
if self.planner_model == "o3":
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_completion_tokens": self.max_tokens
},
self.planner_model,
)
else:
planner_response = self.call_llm( planner_response = self.call_llm(
{ {
"model": self.planner_model, "model": self.planner_model,
@@ -204,7 +224,7 @@ class JediAgent3B:
self.thoughts.append(thought) self.thoughts.append(thought)
self.observation_captions.append(observation_caption) self.observation_captions.append(observation_caption)
self.current_step += 1 self.current_step += 1
return planner_response, pyautogui_actions, {} return planner_response, pyautogui_actions
def parse_observation_caption_from_planner_response(self, input_string: str) -> str: def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
pattern = r"Observation:\n(.*?)\n" pattern = r"Observation:\n(.*?)\n"
@@ -394,7 +414,7 @@ class JediAgent3B:
max_tries=10, max_tries=10,
) )
def call_llm(self, payload, model): def call_llm(self, payload, model):
if model.startswith("gpt"): if model.startswith("gpt") or model.startswith("o3"):
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}", "Authorization": f"Bearer {OPENAI_API_KEY}",

View File

@@ -145,6 +145,16 @@ class JediAgent7B:
], ],
}) })
if self.planner_model == "o3":
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_completion_tokens": self.max_tokens
},
self.planner_model,
)
else:
planner_response = self.call_llm( planner_response = self.call_llm(
{ {
"model": self.planner_model, "model": self.planner_model,
@@ -169,6 +179,16 @@ class JediAgent7B:
{"type": "text", "text": "You didn't generate valid actions. Please try again."} {"type": "text", "text": "You didn't generate valid actions. Please try again."}
] ]
}) })
if self.planner_model == "o3":
planner_response = self.call_llm(
{
"model": self.planner_model,
"messages": messages,
"max_completion_tokens": self.max_tokens
},
self.planner_model,
)
else:
planner_response = self.call_llm( planner_response = self.call_llm(
{ {
"model": self.planner_model, "model": self.planner_model,
@@ -203,7 +223,7 @@ class JediAgent7B:
self.thoughts.append(thought) self.thoughts.append(thought)
self.observation_captions.append(observation_caption) self.observation_captions.append(observation_caption)
self.current_step += 1 self.current_step += 1
return planner_response, pyautogui_actions, {} return planner_response, pyautogui_actions
def parse_observation_caption_from_planner_response(self, input_string: str) -> str: def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
pattern = r"Observation:\n(.*?)\n" pattern = r"Observation:\n(.*?)\n"
@@ -379,7 +399,7 @@ class JediAgent7B:
max_tries=10, max_tries=10,
) )
def call_llm(self, payload, model): def call_llm(self, payload, model):
if model.startswith("gpt"): if model.startswith("gpt") or model.startswith("o3"):
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {OPENAI_API_KEY}" "Authorization": f"Bearer {OPENAI_API_KEY}"