feat: update jedi agent with support for o3 as planner

MillanK0817
2025-07-30 14:06:37 +08:00
parent 99fa3b7cb9
commit 4ae9d41da4
2 changed files with 85 additions and 45 deletions
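
In both agent classes the planner request is now built differently when the planner is o3: OpenAI's o3 reasoning model expects max_completion_tokens instead of max_tokens and does not accept sampling parameters such as temperature or top_p, so the payload construction branches on self.planner_model. A minimal sketch of that branch (the helper name build_planner_payload is illustrative, not part of the commit):

    def build_planner_payload(planner_model, messages, max_tokens, top_p, temperature):
        # o3 is a reasoning model: it takes max_completion_tokens and rejects
        # sampling parameters such as temperature/top_p.
        if planner_model == "o3":
            return {
                "model": planner_model,
                "messages": messages,
                "max_completion_tokens": max_tokens,
            }
        # Other chat models keep the original sampling parameters.
        return {
            "model": planner_model,
            "messages": messages,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "temperature": temperature,
        }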

@@ -146,16 +146,26 @@ class JediAgent3B:
             ],
         })
-        planner_response = self.call_llm(
-            {
-                "model": self.planner_model,
-                "messages": messages,
-                "max_tokens": self.max_tokens,
-                "top_p": self.top_p,
-                "temperature": self.temperature,
-            },
-            self.planner_model,
-        )
+        if self.planner_model == "o3":
+            planner_response = self.call_llm(
+                {
+                    "model": self.planner_model,
+                    "messages": messages,
+                    "max_completion_tokens": self.max_tokens
+                },
+                self.planner_model,
+            )
+        else:
+            planner_response = self.call_llm(
+                {
+                    "model": self.planner_model,
+                    "messages": messages,
+                    "max_tokens": self.max_tokens,
+                    "top_p": self.top_p,
+                    "temperature": self.temperature,
+                },
+                self.planner_model,
+            )
         logger.info(f"Planner Output: {planner_response}")
         codes = self.parse_code_from_planner_response(planner_response)
@@ -170,16 +180,26 @@ class JediAgent3B:
                     {"type": "text", "text": "You didn't generate valid actions. Please try again."}
                 ]
             })
-            planner_response = self.call_llm(
-                {
-                    "model": self.planner_model,
-                    "messages": messages,
-                    "max_tokens": self.max_tokens,
-                    "top_p": self.top_p,
-                    "temperature": self.temperature,
-                },
-                self.planner_model,
-            )
+            if self.planner_model == "o3":
+                planner_response = self.call_llm(
+                    {
+                        "model": self.planner_model,
+                        "messages": messages,
+                        "max_completion_tokens": self.max_tokens
+                    },
+                    self.planner_model,
+                )
+            else:
+                planner_response = self.call_llm(
+                    {
+                        "model": self.planner_model,
+                        "messages": messages,
+                        "max_tokens": self.max_tokens,
+                        "top_p": self.top_p,
+                        "temperature": self.temperature,
+                    },
+                    self.planner_model,
+                )
             logger.info(f"Retry Planner Output: {planner_response}")
             codes = self.parse_code_from_planner_response(planner_response)
             retry_count += 1
@@ -204,7 +224,7 @@ class JediAgent3B:
         self.thoughts.append(thought)
         self.observation_captions.append(observation_caption)
         self.current_step += 1
-        return planner_response, pyautogui_actions, {}
+        return planner_response, pyautogui_actions

     def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
         pattern = r"Observation:\n(.*?)\n"
@@ -394,7 +414,7 @@ class JediAgent3B:
         max_tries=10,
     )
     def call_llm(self, payload, model):
-        if model.startswith("gpt"):
+        if model.startswith("gpt") or model.startswith("o3"):
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {OPENAI_API_KEY}",

@@ -144,17 +144,27 @@ class JediAgent7B:
                 },
             ],
         })
-        planner_response = self.call_llm(
-            {
-                "model": self.planner_model,
-                "messages": messages,
-                "max_tokens": self.max_tokens,
-                "top_p": self.top_p,
-                "temperature": self.temperature,
-            },
-            self.planner_model,
-        )
+        if self.planner_model == "o3":
+            planner_response = self.call_llm(
+                {
+                    "model": self.planner_model,
+                    "messages": messages,
+                    "max_completion_tokens": self.max_tokens
+                },
+                self.planner_model,
+            )
+        else:
+            planner_response = self.call_llm(
+                {
+                    "model": self.planner_model,
+                    "messages": messages,
+                    "max_tokens": self.max_tokens,
+                    "top_p": self.top_p,
+                    "temperature": self.temperature,
+                },
+                self.planner_model,
+            )
         logger.info(f"Planner Output: {planner_response}")
         codes = self.parse_code_from_planner_response(planner_response)
@@ -169,16 +179,26 @@ class JediAgent7B:
                     {"type": "text", "text": "You didn't generate valid actions. Please try again."}
                 ]
             })
-            planner_response = self.call_llm(
-                {
-                    "model": self.planner_model,
-                    "messages": messages,
-                    "max_tokens": self.max_tokens,
-                    "top_p": self.top_p,
-                    "temperature": self.temperature,
-                },
-                self.planner_model,
-            )
+            if self.planner_model == "o3":
+                planner_response = self.call_llm(
+                    {
+                        "model": self.planner_model,
+                        "messages": messages,
+                        "max_completion_tokens": self.max_tokens
+                    },
+                    self.planner_model,
+                )
+            else:
+                planner_response = self.call_llm(
+                    {
+                        "model": self.planner_model,
+                        "messages": messages,
+                        "max_tokens": self.max_tokens,
+                        "top_p": self.top_p,
+                        "temperature": self.temperature,
+                    },
+                    self.planner_model,
+                )
             logger.info(f"Retry Planner Output: {planner_response}")
             codes = self.parse_code_from_planner_response(planner_response)
             retry_count += 1
@@ -203,7 +223,7 @@ class JediAgent7B:
         self.thoughts.append(thought)
         self.observation_captions.append(observation_caption)
         self.current_step += 1
-        return planner_response, pyautogui_actions, {}
+        return planner_response, pyautogui_actions

     def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
         pattern = r"Observation:\n(.*?)\n"
@@ -379,7 +399,7 @@ class JediAgent7B:
         max_tries=10,
     )
     def call_llm(self, payload, model):
-        if model.startswith("gpt"):
+        if model.startswith("gpt") or model.startswith("o3"):
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {OPENAI_API_KEY}"