85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
from typing import Dict
|
|
|
|
import PIL.Image
|
|
import google.generativeai as genai
|
|
|
|
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
|
|
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
|
|
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
|
|
|
|
|
|
class GeminiPro_Agent:
    """Agent that asks a Gemini vision model for the next desktop action.

    The agent keeps a running ``trajectory`` (a list of ``{"role", "parts"}``
    dicts) that starts with the system prompt for the chosen action space and
    grows by one user turn per observation and one assistant turn per
    successfully parsed response.
    """

    def __init__(self, api_key, model='gemini-pro-vision', max_tokens=300, action_space="computer_13"):
        """Configure the Gemini client and seed the trajectory.

        Args:
            api_key: API key handed to ``genai.configure``.
            model: Model name handed to ``genai.GenerativeModel``.
            max_tokens: Upper bound on generated tokens per response.
            action_space: Either ``"computer_13"`` (structured actions) or
                ``"pyautogui"`` (code actions); selects the system prompt
                and the parser used by ``parse_actions``.

        Raises:
            KeyError: If ``action_space`` is not one of the supported values.
        """
        # ``configure`` only accepts keyword arguments; a positional key
        # raises TypeError.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model)
        self.max_tokens = max_tokens
        self.action_space = action_space

        system_prompts = {
            "computer_13": SYS_PROMPT_ACTION,
            "pyautogui": SYS_PROMPT_CODE,
        }
        # NOTE(review): the Gemini API documents "user" and "model" as chat
        # roles; the "system"/"assistant" labels are kept so existing readers
        # of ``trajectory`` see the same structure -- confirm the API accepts
        # them or map them before sending.
        self.trajectory = [
            {
                "role": "system",
                "parts": [system_prompts[action_space]],
            }
        ]

    def predict(self, obs: Dict):
        """Predict the next action(s) based on the current observation.

        Args:
            obs: Observation dict. ``obs["screenshot"]`` is passed to
                ``PIL.Image.open`` and ``obs["instruction"]`` is formatted
                into the user prompt.

        Returns:
            The value produced by ``parse_actions`` for the model's reply, or
            ``None`` when the reply text cannot be read or parsed.
        """
        img = PIL.Image.open(obs["screenshot"])
        self.trajectory.append({
            "role": "user",
            "parts": ["To accomplish the task '{}' and given the current screenshot, what's the next step?".format(
                obs["instruction"]), img]
        })

        # Build a printable view of the trajectory: the text part of every
        # turn, with a placeholder standing in for any attached image.
        traj_to_show = []
        for message in self.trajectory:
            parts = message["parts"]
            traj_to_show.append(parts[0])
            if len(parts) > 1:
                traj_to_show.append("screenshot_obs")

        print("Trajectory:", traj_to_show)

        # The output-length limit is part of the generation config;
        # ``generate_content`` has no ``max_tokens`` keyword.
        response = self.model.generate_content(
            self.trajectory,
            generation_config={"max_output_tokens": self.max_tokens},
        )

        # Reading ``.text`` can itself fail (e.g. when the reply carries no
        # text part), so it is guarded separately from parsing.
        try:
            response_text = response.text
        except Exception as exc:
            print("Failed to read text from response:", exc)
            return None

        try:
            actions = self.parse_actions(response_text)
        except Exception:
            # todo: add error handling
            print("Failed to parse action from response:", response_text)
            actions = None

        return actions

    def parse_actions(self, response: str):
        """Parse the model's reply and record it in the trajectory.

        Example reply for the ``computer_13`` action space::

            ```json
            {
              "action_type": "CLICK",
              "click_type": "RIGHT"
            }
            ```

        Args:
            response: Raw text returned by the model.

        Returns:
            Whatever ``parse_actions_from_string`` (``computer_13``) or
            ``parse_code_from_string`` (``pyautogui``) returns for the reply.

        Raises:
            ValueError: If ``self.action_space`` is not a supported value.
        """
        # Reject an unknown action space before touching the trajectory so a
        # failed call leaves the agent state unchanged.
        if self.action_space == "computer_13":
            actions = parse_actions_from_string(response)
        elif self.action_space == "pyautogui":
            actions = parse_code_from_string(response)
        else:
            raise ValueError("Unsupported action space: {}".format(self.action_space))

        # add action into the trajectory
        self.trajectory.append({
            "role": "assistant",
            "parts": [response]
        })

        return actions
|