Initialize GPT-4v agent, and prompt for current observation space
This commit is contained in:
128
mm_agents/gpt_4v_agent.py
Normal file
128
mm_agents/gpt_4v_agent.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import base64
|
||||
from desktop_env.envs.desktop_env import Action, MouseClick
|
||||
import json5
|
||||
import requests
|
||||
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
|
||||
class GPT4v_Agent:
|
||||
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
|
||||
self.instruction = instruction
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
# load prompt from file
|
||||
self.prompt = ""
|
||||
with open("gpt_4v_prompt.txt", "r") as f:
|
||||
self.prompt = f.read()
|
||||
|
||||
self.trajectory = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.prompt
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def predict(self, obs):
|
||||
base64_image = encode_image(obs)
|
||||
self.trajectory.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's the next step for instruction '{}'?".format(self.instruction)
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self.trajectory,
|
||||
"max_tokens": self.max_tokens
|
||||
}
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
|
||||
|
||||
action = self.parse_action(response.json()['choices'][0]['message']['content'])
|
||||
|
||||
return action
|
||||
|
||||
def parse_action(self, response: str):
|
||||
# response example
|
||||
"""
|
||||
```json
|
||||
{
|
||||
"action_type": "CLICK",
|
||||
"click_type": "RIGHT"
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
# parse from the response
|
||||
if response.startswith("```json"):
|
||||
action = json5.loads(response[7:-3])
|
||||
elif response.startswith("```"):
|
||||
action = json5.loads(response[3:-3])
|
||||
else:
|
||||
action = json5.loads(response)
|
||||
|
||||
# add action into the trajectory
|
||||
self.trajectory.append({
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": response
|
||||
},
|
||||
]
|
||||
})
|
||||
|
||||
# parse action
|
||||
parsed_action = {}
|
||||
action_type = Action[action['action_type']].value
|
||||
parsed_action["action_type"] = action_type
|
||||
|
||||
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
||||
parsed_action["click_type"] = MouseClick[action['click_type']].value
|
||||
|
||||
if action_type == Action.MOUSE_MOVE.value:
|
||||
parsed_action["x"] = action["x"]
|
||||
parsed_action["y"] = action["y"]
|
||||
|
||||
# fixme: could these two actions be merged??
|
||||
if action_type == Action.KEY.value:
|
||||
parsed_action["key"] = [ord(c) for c in action["key"]]
|
||||
|
||||
if action_type == Action.TYPE.value:
|
||||
parsed_action["text"] = [ord(c) for c in action["text"]]
|
||||
|
||||
return parsed_action
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# OpenAI API Key
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
|
||||
print(agent.predict(obs="stackoverflow.png"))
|
||||
|
||||
52
mm_agents/gpt_4v_prompt.txt
Normal file
52
mm_agents/gpt_4v_prompt.txt
Normal file
@@ -0,0 +1,52 @@
|
||||
You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection.
|
||||
For each step, you will get an observation of an image, which is the screenshot of the computer screen. And you will predict the action of the computer based on the image.
|
||||
Here is the description of the action space:
|
||||
|
||||
Firstly you need to predict the class of your action, select from one below:
|
||||
- **MOUSE_MOVE**: move the mouse to a specific position
|
||||
- **CLICK**: click on the screen
|
||||
- **MOUSE_DOWN**: press the mouse button
|
||||
- **MOUSE_UP**: release the mouse button
|
||||
- **KEY**: press a key on the keyboard
|
||||
- **KEY_DOWN**: press a key on the keyboard
|
||||
- **KEY_UP**: release a key on the keyboard
|
||||
- **TYPE**: type a string on the keyboard
|
||||
|
||||
Then you need to predict the parameters of your action:
|
||||
- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor
|
||||
for example, format as:
|
||||
```
|
||||
{
|
||||
"action_type": "MOUSE_MOVE",
|
||||
"x": 1319.11,
|
||||
"y": 65.06
|
||||
}
|
||||
```
|
||||
- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse:
|
||||
for example, format as:
|
||||
```
|
||||
{
|
||||
"action_type": "CLICK",
|
||||
"click_type": "LEFT"
|
||||
}
|
||||
```
|
||||
- For [KEY, KEY_DOWN, KEY_UP, TYPE], you need to choose a(multiple) key(s) from the keyboard, select from [A-Z, 0-9, F1-F12, ESC, TAB, ENTER, SPACE, BACKSPACE, SHIFT, CTRL, ALT, UP, DOWN, LEFT, RIGHT, CAPSLOCK, NUMLOCK, SCROLLLOCK, INSERT, DELETE, HOME, END, PAGEUP, PAGEDOWN]:
|
||||
for example, format as:
|
||||
```
|
||||
{
|
||||
"action_type": "TYPE",
|
||||
"text": [
|
||||
"w",
|
||||
"i",
|
||||
"k",
|
||||
"i",
|
||||
"p",
|
||||
"e",
|
||||
"d",
|
||||
"i",
|
||||
"a"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
For every setup, you should only return the action_type and the parameters of your action as a dict, without any other things.
|
||||
Reference in New Issue
Block a user