# Files: sci-gui-agent-benchmark/mm_agents/gpt_4v_agent.py (166 lines, 5.4 KiB, Python)

import base64
import json
import re
from typing import Dict
import requests
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 string.

    The file is read in binary mode and the base64 bytes are decoded as
    UTF-8 so the result can be embedded directly in a text payload.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
def parse_actions_from_string(input_string):
    """Extract JSON action dictionaries from a model reply.

    Three formats are tried, in this order:

    1. One or more fenced blocks tagged ``json``.
    2. One or more untagged fenced blocks.
    3. The whole string as a single JSON document.

    Every fenced block found by the first matching format is parsed, so a
    reply may carry several actions.

    Args:
        input_string: Raw text of the model reply.

    Returns:
        A list with one parsed JSON value per block (a single-element list
        when the whole string is parsed as JSON).

    Raises:
        ValueError: If the selected format contains invalid JSON. Failures
            are always raised rather than returned as an error string, so a
            caller never receives a ``str`` where a list is expected.
    """
    # The tagged pattern goes first: the untagged one would otherwise never
    # be the better match for a reply that labels its block as JSON.
    fence_patterns = (r'```json\s+(.*?)\s+```', r'```\s+(.*?)\s+```')
    for fence_pattern in fence_patterns:
        matches = re.findall(fence_pattern, input_string, re.DOTALL)
        if not matches:
            continue
        try:
            return [json.loads(match) for match in matches]
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON: {e}") from e

    # No fenced block at all: the reply itself must be a JSON document.
    try:
        return [json.loads(input_string)]
    except json.JSONDecodeError as e:
        raise ValueError("Invalid response format: " + input_string) from e
def parse_code_from_string(input_string):
    """Return the contents of every triple-backtick block in ``input_string``.

    An optional language tag directly after the opening fence (a word
    followed by whitespace, e.g. ``python``) is skipped. The captured text is
    returned exactly as written, without stripping, in order of appearance.
    An input with no fenced block yields an empty list.
    """
    # Non-greedy capture so adjacent blocks are not merged; DOTALL lets the
    # captured code span multiple lines.
    fence_regex = re.compile(r"```(?:\w+\s+)?(.*?)```", re.DOTALL)
    snippets = []
    for found in fence_regex.finditer(input_string):
        snippets.append(found.group(1))
    return snippets
class GPT4v_Agent:
    """Agent that asks an OpenAI vision chat model for the next GUI action.

    The full conversation is kept in ``self.trajectory`` (the system prompt,
    then user observations and assistant replies) and is sent in its entirety
    on every call to :meth:`predict`.
    """

    def __init__(self, api_key, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13"):
        """Store request settings and seed the trajectory with the system prompt.

        Args:
            api_key: Token sent as ``Bearer`` in the ``Authorization`` header.
            model: Model name placed in the request payload.
            max_tokens: ``max_tokens`` value placed in the request payload.
            action_space: ``"computer_13"`` (JSON actions) or ``"pyautogui"``
                (code snippets); selects both the system prompt and the parser.

        Raises:
            KeyError: If ``action_space`` is not one of the two values above.
        """
        self.model = model
        self.max_tokens = max_tokens
        self.action_space = action_space

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        self.trajectory = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": {
                            "computer_13": SYS_PROMPT_ACTION,
                            "pyautogui": SYS_PROMPT_CODE
                        }[action_space]
                    },
                ]
            }
        ]

    def predict(self, obs: Dict):
        """Predict the next action(s) based on the current observation.

        Args:
            obs: Mapping with a ``"screenshot"`` entry (path of the image file
                to send) and an ``"instruction"`` entry (the task text).

        Returns:
            The parsed actions, or ``None`` when the response carries no
            message content or the content cannot be parsed.
        """
        base64_image = encode_image(obs["screenshot"])
        self.trajectory.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "To accomplish the task '{}' and given the current screenshot, what's the next step?".format(
                        obs["instruction"])
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                }
            ]
        })

        # Log a compact view of the conversation: the text of each message,
        # with a placeholder instead of the base64 image data.
        traj_to_show = []
        for message in self.trajectory:
            traj_to_show.append(message["content"][0]["text"])
            if len(message["content"]) > 1:
                traj_to_show.append("screenshot_obs")
        print("Trajectory:", traj_to_show)

        payload = {
            "model": self.model,
            "messages": self.trajectory,
            "max_tokens": self.max_tokens
        }
        response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)

        # Extract the message text separately from parsing it. An error
        # payload has no 'choices', and the failure report must not evaluate
        # the same lookup again or it would raise from inside the handler.
        try:
            content = response.json()['choices'][0]['message']['content']
        except (ValueError, KeyError, IndexError, TypeError):
            print("Failed to parse action from response:", response.text)
            return None

        try:
            actions = self.parse_actions(content)
        except Exception:
            # todo: add error handling
            print("Failed to parse action from response:", content)
            actions = None

        return actions

    def parse_actions(self, response: str):
        """Parse a model reply into actions and record it in the trajectory.

        Example reply for the ``computer_13`` action space::

            ```json
            {
              "action_type": "CLICK",
              "click_type": "RIGHT"
            }
            ```

        Args:
            response: Text content of the assistant message.

        Returns:
            Whatever the parser for the configured action space returns:
            ``parse_actions_from_string`` for ``"computer_13"``,
            ``parse_code_from_string`` for ``"pyautogui"``.

        Raises:
            ValueError: If ``self.action_space`` is not a supported value.
        """
        # parse from the response
        if self.action_space == "computer_13":
            actions = parse_actions_from_string(response)
        elif self.action_space == "pyautogui":
            actions = parse_code_from_string(response)
        else:
            # Fail before touching the trajectory instead of hitting an
            # unbound local at the return statement.
            raise ValueError(f"Unsupported action space: {self.action_space}")

        # add action into the trajectory
        self.trajectory.append({
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": response
                },
            ]
        })

        return actions