feat: update jedi agent with support for o3 as planner
This commit is contained in:
@@ -146,16 +146,26 @@ class JediAgent3B:
|
|||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
||||||
planner_response = self.call_llm(
|
if self.planner_model == "o3":
|
||||||
{
|
planner_response = self.call_llm(
|
||||||
"model": self.planner_model,
|
{
|
||||||
"messages": messages,
|
"model": self.planner_model,
|
||||||
"max_tokens": self.max_tokens,
|
"messages": messages,
|
||||||
"top_p": self.top_p,
|
"max_completion_tokens": self.max_tokens
|
||||||
"temperature": self.temperature,
|
},
|
||||||
},
|
self.planner_model,
|
||||||
self.planner_model,
|
)
|
||||||
)
|
else:
|
||||||
|
planner_response = self.call_llm(
|
||||||
|
{
|
||||||
|
"model": self.planner_model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
|
self.planner_model,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Planner Output: {planner_response}")
|
logger.info(f"Planner Output: {planner_response}")
|
||||||
codes = self.parse_code_from_planner_response(planner_response)
|
codes = self.parse_code_from_planner_response(planner_response)
|
||||||
@@ -170,16 +180,26 @@ class JediAgent3B:
|
|||||||
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
|
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
planner_response = self.call_llm(
|
if self.planner_model == "o3":
|
||||||
{
|
planner_response = self.call_llm(
|
||||||
"model": self.planner_model,
|
{
|
||||||
"messages": messages,
|
"model": self.planner_model,
|
||||||
"max_tokens": self.max_tokens,
|
"messages": messages,
|
||||||
"top_p": self.top_p,
|
"max_completion_tokens": self.max_tokens
|
||||||
"temperature": self.temperature,
|
},
|
||||||
},
|
self.planner_model,
|
||||||
self.planner_model,
|
)
|
||||||
)
|
else:
|
||||||
|
planner_response = self.call_llm(
|
||||||
|
{
|
||||||
|
"model": self.planner_model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
|
self.planner_model,
|
||||||
|
)
|
||||||
logger.info(f"Retry Planner Output: {planner_response}")
|
logger.info(f"Retry Planner Output: {planner_response}")
|
||||||
codes = self.parse_code_from_planner_response(planner_response)
|
codes = self.parse_code_from_planner_response(planner_response)
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
@@ -204,7 +224,7 @@ class JediAgent3B:
|
|||||||
self.thoughts.append(thought)
|
self.thoughts.append(thought)
|
||||||
self.observation_captions.append(observation_caption)
|
self.observation_captions.append(observation_caption)
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
return planner_response, pyautogui_actions, {}
|
return planner_response, pyautogui_actions
|
||||||
|
|
||||||
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
|
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
|
||||||
pattern = r"Observation:\n(.*?)\n"
|
pattern = r"Observation:\n(.*?)\n"
|
||||||
@@ -394,7 +414,7 @@ class JediAgent3B:
|
|||||||
max_tries=10,
|
max_tries=10,
|
||||||
)
|
)
|
||||||
def call_llm(self, payload, model):
|
def call_llm(self, payload, model):
|
||||||
if model.startswith("gpt"):
|
if model.startswith("gpt") or model.startswith("o3"):
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
"Authorization": f"Bearer {OPENAI_API_KEY}",
|
||||||
|
|||||||
@@ -144,17 +144,27 @@ class JediAgent7B:
|
|||||||
},
|
},
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
||||||
planner_response = self.call_llm(
|
if self.planner_model == "o3":
|
||||||
{
|
planner_response = self.call_llm(
|
||||||
"model": self.planner_model,
|
{
|
||||||
"messages": messages,
|
"model": self.planner_model,
|
||||||
"max_tokens": self.max_tokens,
|
"messages": messages,
|
||||||
"top_p": self.top_p,
|
"max_completion_tokens": self.max_tokens
|
||||||
"temperature": self.temperature,
|
},
|
||||||
},
|
self.planner_model,
|
||||||
self.planner_model,
|
)
|
||||||
)
|
else:
|
||||||
|
planner_response = self.call_llm(
|
||||||
|
{
|
||||||
|
"model": self.planner_model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
|
self.planner_model,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Planner Output: {planner_response}")
|
logger.info(f"Planner Output: {planner_response}")
|
||||||
codes = self.parse_code_from_planner_response(planner_response)
|
codes = self.parse_code_from_planner_response(planner_response)
|
||||||
@@ -169,16 +179,26 @@ class JediAgent7B:
|
|||||||
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
|
{"type": "text", "text": "You didn't generate valid actions. Please try again."}
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
planner_response = self.call_llm(
|
if self.planner_model == "o3":
|
||||||
{
|
planner_response = self.call_llm(
|
||||||
"model": self.planner_model,
|
{
|
||||||
"messages": messages,
|
"model": self.planner_model,
|
||||||
"max_tokens": self.max_tokens,
|
"messages": messages,
|
||||||
"top_p": self.top_p,
|
"max_completion_tokens": self.max_tokens
|
||||||
"temperature": self.temperature,
|
},
|
||||||
},
|
self.planner_model,
|
||||||
self.planner_model,
|
)
|
||||||
)
|
else:
|
||||||
|
planner_response = self.call_llm(
|
||||||
|
{
|
||||||
|
"model": self.planner_model,
|
||||||
|
"messages": messages,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
},
|
||||||
|
self.planner_model,
|
||||||
|
)
|
||||||
logger.info(f"Retry Planner Output: {planner_response}")
|
logger.info(f"Retry Planner Output: {planner_response}")
|
||||||
codes = self.parse_code_from_planner_response(planner_response)
|
codes = self.parse_code_from_planner_response(planner_response)
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
@@ -203,7 +223,7 @@ class JediAgent7B:
|
|||||||
self.thoughts.append(thought)
|
self.thoughts.append(thought)
|
||||||
self.observation_captions.append(observation_caption)
|
self.observation_captions.append(observation_caption)
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
return planner_response, pyautogui_actions, {}
|
return planner_response, pyautogui_actions
|
||||||
|
|
||||||
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
|
def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
|
||||||
pattern = r"Observation:\n(.*?)\n"
|
pattern = r"Observation:\n(.*?)\n"
|
||||||
@@ -379,7 +399,7 @@ class JediAgent7B:
|
|||||||
max_tries=10,
|
max_tries=10,
|
||||||
)
|
)
|
||||||
def call_llm(self, payload, model):
|
def call_llm(self, payload, model):
|
||||||
if model.startswith("gpt"):
|
if model.startswith("gpt") or model.startswith("o3"):
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {OPENAI_API_KEY}"
|
"Authorization": f"Bearer {OPENAI_API_KEY}"
|
||||||
|
|||||||
Reference in New Issue
Block a user