Initialize GPT-4v agent, and prompt for current observation space
This commit is contained in:
128
mm_agents/gpt_4v_agent.py
Normal file
128
mm_agents/gpt_4v_agent.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import base64
|
||||
from desktop_env.envs.desktop_env import Action, MouseClick
|
||||
import json5
|
||||
import requests
|
||||
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
|
||||
class GPT4v_Agent:
|
||||
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
|
||||
self.instruction = instruction
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
# load prompt from file
|
||||
self.prompt = ""
|
||||
with open("gpt_4v_prompt.txt", "r") as f:
|
||||
self.prompt = f.read()
|
||||
|
||||
self.trajectory = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.prompt
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def predict(self, obs):
|
||||
base64_image = encode_image(obs)
|
||||
self.trajectory.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What's the next step for instruction '{}'?".format(self.instruction)
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": self.trajectory,
|
||||
"max_tokens": self.max_tokens
|
||||
}
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
|
||||
|
||||
action = self.parse_action(response.json()['choices'][0]['message']['content'])
|
||||
|
||||
return action
|
||||
|
||||
def parse_action(self, response: str):
|
||||
# response example
|
||||
"""
|
||||
```json
|
||||
{
|
||||
"action_type": "CLICK",
|
||||
"click_type": "RIGHT"
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
# parse from the response
|
||||
if response.startswith("```json"):
|
||||
action = json5.loads(response[7:-3])
|
||||
elif response.startswith("```"):
|
||||
action = json5.loads(response[3:-3])
|
||||
else:
|
||||
action = json5.loads(response)
|
||||
|
||||
# add action into the trajectory
|
||||
self.trajectory.append({
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": response
|
||||
},
|
||||
]
|
||||
})
|
||||
|
||||
# parse action
|
||||
parsed_action = {}
|
||||
action_type = Action[action['action_type']].value
|
||||
parsed_action["action_type"] = action_type
|
||||
|
||||
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
||||
parsed_action["click_type"] = MouseClick[action['click_type']].value
|
||||
|
||||
if action_type == Action.MOUSE_MOVE.value:
|
||||
parsed_action["x"] = action["x"]
|
||||
parsed_action["y"] = action["y"]
|
||||
|
||||
# fixme: could these two actions be merged??
|
||||
if action_type == Action.KEY.value:
|
||||
parsed_action["key"] = [ord(c) for c in action["key"]]
|
||||
|
||||
if action_type == Action.TYPE.value:
|
||||
parsed_action["text"] = [ord(c) for c in action["text"]]
|
||||
|
||||
return parsed_action
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# OpenAI API Key
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
|
||||
print(agent.predict(obs="stackoverflow.png"))
|
||||
|
||||
Reference in New Issue
Block a user