feat: update jedi agent with support for o3 as planner
This commit is contained in:
@@ -146,16 +146,26 @@ class JediAgent3B:
|
||||
],
|
||||
})
|
||||
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
if self.planner_model == "o3":
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_completion_tokens": self.max_tokens
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
else:
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
|
||||
logger.info(f"Planner Output: {planner_response}")
|
||||
codes = self.parse_code_from_planner_response(planner_response)
|
||||
@@ -170,16 +180,26 @@ class JediAgent3B:
|
||||
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
|
||||
]
|
||||
})
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
if self.planner_model == "o3":
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_completion_tokens": self.max_tokens
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
else:
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
logger.info(f"Retry Planner Output: {planner_response}")
|
||||
codes = self.parse_code_from_planner_response(planner_response)
|
||||
retry_count += 1
|
||||
@@ -204,7 +224,7 @@ class JediAgent3B:
|
||||
self.thoughts.append(thought)
|
||||
self.observation_captions.append(observation_caption)
|
||||
self.current_step += 1
|
||||
return planner_response, pyautogui_actions, {}
|
||||
return planner_response, pyautogui_actions
|
||||
|
||||
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
|
||||
pattern = r"Observation:\n(.*?)\n"
|
||||
@@ -394,7 +414,7 @@ class JediAgent3B:
|
||||
max_tries=10,
|
||||
)
|
||||
def call_llm(self, payload, model):
|
||||
if model.startswith("gpt"):
|
||||
if model.startswith("gpt") or model.startswith("o3"):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
||||
|
||||
@@ -144,17 +144,27 @@ class JediAgent7B:
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
|
||||
if self.planner_model == "o3":
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_completion_tokens": self.max_tokens
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
else:
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
|
||||
logger.info(f"Planner Output: {planner_response}")
|
||||
codes = self.parse_code_from_planner_response(planner_response)
|
||||
@@ -169,16 +179,26 @@ class JediAgent7B:
|
||||
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
|
||||
]
|
||||
})
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
if self.planner_model == "o3":
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_completion_tokens": self.max_tokens
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
else:
|
||||
planner_response = self.call_llm(
|
||||
{
|
||||
"model": self.planner_model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
},
|
||||
self.planner_model,
|
||||
)
|
||||
logger.info(f"Retry Planner Output: {planner_response}")
|
||||
codes = self.parse_code_from_planner_response(planner_response)
|
||||
retry_count += 1
|
||||
@@ -203,7 +223,7 @@ class JediAgent7B:
|
||||
self.thoughts.append(thought)
|
||||
self.observation_captions.append(observation_caption)
|
||||
self.current_step += 1
|
||||
return planner_response, pyautogui_actions, {}
|
||||
return planner_response, pyautogui_actions
|
||||
|
||||
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
|
||||
pattern = r"Observation:\n(.*?)\n"
|
||||
@@ -379,7 +399,7 @@ class JediAgent7B:
|
||||
max_tries=10,
|
||||
)
|
||||
def call_llm(self, payload, model):
|
||||
if model.startswith("gpt"):
|
||||
if model.startswith("gpt") or model.startswith("o3"):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {OPENAI_API_KEY}"
|
||||
|
||||
Reference in New Issue
Block a user