# Source: sci-gui-agent-benchmark/mm_agents/gemini_pro_vision_agent.py
# Timestamp: 2024-01-20 18:55:21 +08:00
# Original listing: 116 lines, 3.8 KiB, Python

# todo: needs to be refactored
import time
from typing import Dict, List
import PIL.Image
import google.generativeai as genai
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
class GeminiProV_Agent:
    """GUI agent that asks a Gemini vision model for the next desktop action(s).

    The conversation is kept in ``self.trajectory`` as a list of
    ``{"role": ..., "parts": [...]}`` dicts. The first entry is the system
    prompt (selected by ``action_space``) with the task instruction appended.
    User turns carry a question plus a screenshot. Assistant turns carry the
    raw model reply.
    """

    def __init__(self, api_key, instruction, model='gemini-pro-vision', max_tokens=300, temperature=0.0,
                 action_space="computer_13"):
        """Configure the Gemini client and seed the trajectory with the system prompt.

        Args:
            api_key: API key passed to ``genai.configure``.
            instruction: Task description appended to the system prompt.
            model: Name passed to ``genai.GenerativeModel``.
            max_tokens: Used as ``max_output_tokens`` for every generation.
            temperature: Sampling temperature for every generation.
            action_space: ``"computer_13"`` (parsed with ``parse_actions_from_string``)
                or ``"pyautogui"`` (parsed with ``parse_code_from_string``).

        Raises:
            ValueError: If ``action_space`` is not one of the supported values.
        """
        system_prompts = {
            "computer_13": SYS_PROMPT_ACTION,
            "pyautogui": SYS_PROMPT_CODE,
        }
        # Validate before touching the remote client. This is the same error
        # parse_actions raises, instead of a bare KeyError from the dict lookup.
        if action_space not in system_prompts:
            raise ValueError("Invalid action space: {}".format(action_space))

        genai.configure(api_key=api_key)
        self.instruction = instruction
        self.model = genai.GenerativeModel(model)
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.action_space = action_space
        self.trajectory = [
            {
                "role": "system",
                "parts": [
                    system_prompts[action_space]
                    + "\nHere is the instruction for the task: {}".format(self.instruction)
                ]
            }
        ]

    def predict(self, obs: Dict) -> List:
        """Predict the next action(s) from the current observation.

        The whole text history is flattened into one user message, because
        multi-round conversation is not used here. Only the latest desktop
        screenshot is attached.

        Args:
            obs: Observation dict; ``obs["screenshot"]`` is passed to
                ``PIL.Image.open`` (presumably a file path - confirm with the env).

        Returns:
            The parsed actions, or ``[]`` if the reply has no readable text or
            cannot be parsed.
        """
        img = PIL.Image.open(obs["screenshot"])
        self.trajectory.append({
            "role": "user",
            "parts": ["What's the next step that you will do to help with the task?", img]
        })

        # todo: Remove this step once the Gemini supports multi-round conversation
        message_for_gemini = {
            "role": "user",
            "parts": [self._format_history(), img]
        }

        print("Trajectory:", self._trajectory_summary())

        response = self._generate_with_retry(message_for_gemini)

        try:
            response_text = response.text
        except Exception as e:
            # NOTE(review): presumably raised when the reply carries no text part
            # (e.g. a blocked generation) - confirm against the SDK version in use.
            print("Failed to read text from response:", e)
            return []

        try:
            actions = self.parse_actions(response_text)
        except Exception as e:
            print("Failed to parse action from response:", response_text, e)
            actions = []
        return actions

    def _format_history(self):
        """Render the trajectory as one ``<|im_start|>role ... <|im_end|>`` text blob.

        The label comes from each message's stored ``"role"``, not from its
        position. A failed step leaves a user turn with no assistant reply, and
        a parity-based label would then mislabel every later message.
        """
        return "".join(
            "<|im_start|>{}\n{}\n<|im_end|>\n".format(message["role"], message["parts"][0])
            for message in self.trajectory
        )

    def _trajectory_summary(self):
        """Return the text parts of the trajectory, with ``"screenshot_obs"`` marking attached images."""
        summary = []
        for message in self.trajectory:
            summary.append(message["parts"][0])
            if len(message["parts"]) > 1:
                summary.append("screenshot_obs")
        return summary

    def _generate_with_retry(self, message, retry_interval=5):
        """Call ``self.model.generate_content`` until it succeeds.

        Retries indefinitely, sleeping ``retry_interval`` seconds between attempts.
        """
        while True:
            try:
                return self.model.generate_content(
                    message,
                    generation_config={
                        "max_output_tokens": self.max_tokens,
                        "temperature": self.temperature
                    }
                )
            except Exception as e:
                # Catch Exception rather than using a bare except, so that
                # KeyboardInterrupt/SystemExit can still break out of the loop.
                print("Failed to generate response, retrying...", e)
                time.sleep(retry_interval)

    def parse_actions(self, response: str):
        """Parse a raw model reply into actions and record it in the trajectory.

        Args:
            response: Raw text returned by the model.

        Returns:
            The output of ``parse_actions_from_string`` (``"computer_13"``) or
            ``parse_code_from_string`` (``"pyautogui"``).

        Raises:
            ValueError: If ``self.action_space`` is not supported.
        """
        if self.action_space == "computer_13":
            actions = parse_actions_from_string(response)
        elif self.action_space == "pyautogui":
            actions = parse_code_from_string(response)
        else:
            raise ValueError("Invalid action space: " + self.action_space)

        # The reply is recorded only after parsing succeeds.
        self.trajectory.append({
            "role": "assistant",
            "parts": [response]
        })
        return actions