167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
# FIXME: needs to be rewritten for the new action space
|
|
|
|
import os
|
|
import re
|
|
import base64
|
|
from desktop_env.envs.desktop_env import Action, MouseClick
|
|
import json
|
|
import requests
|
|
from mm_agents.gpt_4v_prompt import SYS_PROMPT
|
|
|
|
|
|
# Function to encode the image
|
|
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is opened in binary mode, its bytes are base64-encoded, and the
    encoded bytes are decoded as UTF-8 so the result is a ``str``.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
|
|
|
|
|
|
def parse_actions_from_string(input_string):
    """Extract JSON action objects from a model reply.

    The reply is searched in this order, stopping at the first form found:

    1. every fenced block opened with ```json,
    2. every fenced block opened with a plain ```,
    3. the whole string, treated as a single JSON document.

    Args:
        input_string: Raw text of the model reply.

    Returns:
        A list of decoded JSON values: one per fenced block, or a one-element
        list when the whole string was parsed.

    Raises:
        ValueError: If a fenced block is not valid JSON, or if there is no
            fenced block and the whole string is not valid JSON.
    """
    # The more specific ```json fence is tried before the generic ``` fence.
    fence_patterns = (r'```json\s+(.*?)\s+```', r'```\s+(.*?)\s+```')

    for pattern in fence_patterns:
        matches = re.findall(pattern, input_string, re.DOTALL)
        if not matches:
            continue
        try:
            # All fenced blocks are parsed, not only the first one.
            return [json.loads(match) for match in matches]
        except json.JSONDecodeError as e:
            # Raise instead of returning the message: callers iterate the
            # result as a list of dicts, so a returned string would be
            # misread as a sequence of actions.
            raise ValueError(f"Failed to parse JSON: {e}") from e

    try:
        return [json.loads(input_string)]
    except json.JSONDecodeError as e:
        raise ValueError("Invalid response format: " + input_string) from e
|
|
|
|
|
|
class GPT4v_Agent:
    """Agent that asks an OpenAI chat-completions model for the next action.

    The full conversation (system prompt, user turns with screenshots, and
    assistant replies) is kept in ``self.trajectory`` and resent on every
    call to :meth:`predict`.
    """

    def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
        """Store request settings and seed the trajectory with the system prompt.

        Args:
            api_key: Key sent as the bearer token in the ``Authorization`` header.
            instruction: Task text interpolated into every user prompt.
            model: Model name placed in the request payload.
            max_tokens: ``max_tokens`` value placed in the request payload.
        """
        self.instruction = instruction
        self.model = model
        self.max_tokens = max_tokens

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        self.trajectory = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYS_PROMPT
                    },
                ]
            }
        ]

    @staticmethod
    def _summarize_trajectory(trajectory):
        """Return a printable summary of ``trajectory``.

        For each message the text of its first content part is listed; a
        ``"screenshot_obs"`` placeholder follows when the message has more
        than one content part, so image data is never printed.
        """
        summary = []
        for message in trajectory:
            parts = message["content"]
            summary.append(parts[0]["text"])
            if len(parts) > 1:
                summary.append("screenshot_obs")
        return summary

    def predict(self, obs):
        """Send the current screenshot to the model and return the parsed actions.

        Args:
            obs: Value passed to ``encode_image``; the screenshot to attach to
                the user turn.

        Returns:
            The list produced by :meth:`parse_actions`, or ``None`` when the
            reply could not be extracted from the response or could not be
            parsed into actions.
        """
        base64_image = encode_image(obs)
        self.trajectory.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's the next step for instruction '{}'?".format(self.instruction)
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                }
            ]
        })
        print("Trajectory:", self._summarize_trajectory(self.trajectory))

        payload = {
            "model": self.model,
            "messages": self.trajectory,
            "max_tokens": self.max_tokens
        }
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=self.headers,
            json=payload,
        )

        # Extract the reply text on its own, so a response without the
        # expected structure (e.g. an error payload) is reported here instead
        # of raising a second time from inside the failure handler.
        try:
            reply = response.json()['choices'][0]['message']['content']
        except (ValueError, KeyError, IndexError, TypeError):
            print("Failed to parse action from response:", response.text)
            return None

        try:
            return self.parse_actions(reply)
        except Exception:
            # Best effort: any parsing problem is reported and signalled with
            # None. KeyboardInterrupt/SystemExit are deliberately not caught.
            print("Failed to parse action from response:", reply)
            return None

    def parse_actions(self, response: str):
        """Parse a model reply into action dicts and record it in the trajectory.

        Example reply::

            ```json
            {
              "action_type": "CLICK",
              "click_type": "RIGHT"
            }
            ```

        Args:
            response: Raw text of the assistant reply.

        Returns:
            A list of dicts, each holding ``"action_type"`` (the ``Action``
            enum value) plus the fields relevant to that action type.

        Raises:
            Whatever ``parse_actions_from_string`` raises for an unparseable
            reply, and ``KeyError`` when an action names an unknown enum
            member or lacks a field its action type requires.
        """
        # parse from the response
        actions = parse_actions_from_string(response)

        # add action into the trajectory
        self.trajectory.append({
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": response
                },
            ]
        })

        # Action types that carry a mouse button in "click_type".
        click_like = (Action.CLICK.value, Action.MOUSE_DOWN.value, Action.MOUSE_UP.value)

        # parse action
        parsed_actions = []
        for action in actions:
            action_type = Action[action['action_type']].value
            parsed_action = {"action_type": action_type}

            if action_type in click_like:
                parsed_action["click_type"] = MouseClick[action['click_type']].value

            if action_type == Action.MOUSE_MOVE.value:
                parsed_action["x"] = action["x"]
                parsed_action["y"] = action["y"]

            if action_type == Action.KEY.value:
                # NOTE: passed through as-is; single and multiple keys are
                # not distinguished here.
                parsed_action["key"] = action["key"]

            if action_type == Action.TYPE.value:
                parsed_action["text"] = action["text"]

            parsed_actions.append(parsed_action)

        return parsed_actions
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo run: the OpenAI API key is read from the environment, then the
    # agent is asked for the next step given a single screenshot file.
    demo_agent = GPT4v_Agent(
        api_key=os.environ.get("OPENAI_API_KEY"),
        instruction="Open Google Sheet",
    )
    predicted_actions = demo_agent.predict(obs="stackoverflow.png")
    print(predicted_actions)
|