pass test case
This commit is contained in:
@@ -4,7 +4,6 @@ import os
|
||||
import re
|
||||
import base64
|
||||
import PIL.Image
|
||||
from desktop_env.envs.desktop_env import Action, MouseClick
|
||||
import json
|
||||
import requests
|
||||
|
||||
@@ -15,7 +14,7 @@ import argparse
|
||||
from seem.modeling.BaseModel import BaseModel as BaseModel_Seem
|
||||
from seem.utils.distributed import init_distributed as init_distributed_seem
|
||||
from seem.modeling import build_model as build_model_seem
|
||||
from task_adapter.seem.tasks import interactive_seem_m2m_auto, inference_seem_pano, inference_seem_interactive
|
||||
from task_adapter.seem.tasks import inference_seem_pano
|
||||
|
||||
# semantic sam
|
||||
from semantic_sam.BaseModel import BaseModel
|
||||
@@ -28,9 +27,9 @@ from task_adapter.semantic_sam.tasks import inference_semsam_m2m_auto, prompt_sw
|
||||
# sam
|
||||
from segment_anything import sam_model_registry
|
||||
from task_adapter.sam.tasks.inference_sam_m2m_auto import inference_sam_m2m_auto
|
||||
from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m2m_interactive
|
||||
|
||||
from scipy.ndimage import label
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
|
||||
SYS_PROMPT = '''
|
||||
@@ -45,7 +44,7 @@ Firstly you need to predict the class of your action, select from one below:
|
||||
for example, format as:
|
||||
```
|
||||
{
|
||||
"action_type": "MOUSE_MOVE",
|
||||
"action_type": "CLICK",
|
||||
"label": 7
|
||||
}
|
||||
```
|
||||
@@ -107,7 +106,11 @@ def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs
|
||||
else:
|
||||
level = [6, 1, 2, 3, 4, 5]
|
||||
|
||||
label_mode = 'a' if label_mode == 'Alphabet' else '1'
|
||||
if label_mode == 'Alphabet':
|
||||
label_mode = 'a'
|
||||
else:
|
||||
label_mode = '1'
|
||||
|
||||
text_size, hole_scale, island_scale = 1280, 100, 100
|
||||
text, text_part, text_thresh = '', '', '0.0'
|
||||
|
||||
@@ -126,11 +129,15 @@ def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs
|
||||
model = model_seem
|
||||
output, mask = inference_seem_pano(model, image, text_size, label_mode, alpha, anno_mode)
|
||||
|
||||
return output
|
||||
return output, mask
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image):
|
||||
return base64.b64encode(image).decode('utf-8')
|
||||
pil_img = PIL.Image.fromarray(image)
|
||||
buff = BytesIO()
|
||||
pil_img.save(buff, format="JPEG")
|
||||
new_image_string = base64.b64encode(buff.getvalue()).decode("utf-8")
|
||||
return new_image_string
|
||||
|
||||
def parse_actions_from_string(input_string):
|
||||
# Search for a JSON string within the input string
|
||||
@@ -187,7 +194,8 @@ class GPT4v_Agent:
|
||||
]
|
||||
|
||||
def predict(self, obs):
|
||||
obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Number", anno_mode=["Mark", "Box"])
|
||||
obs, mask = inference(obs, slider=3.0, mode="Automatic", alpha=0.1, label_mode="Number", anno_mode=["Mark", "Box"])
|
||||
PIL.Image.fromarray(obs).save("desktop.jpeg")
|
||||
base64_image = encode_image(obs)
|
||||
self.trajectory.append({
|
||||
"role": "user",
|
||||
@@ -218,14 +226,14 @@ class GPT4v_Agent:
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
|
||||
|
||||
try:
|
||||
actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
|
||||
actions = self.parse_actions(response.json()['choices'][0]['message']['content'], mask)
|
||||
except:
|
||||
print("Failed to parse action from response:", response.json()['choices'][0]['message']['content'])
|
||||
actions = None
|
||||
|
||||
return actions
|
||||
|
||||
def parse_actions(self, response: str):
|
||||
def parse_actions(self, response: str, mask):
|
||||
# response example
|
||||
"""
|
||||
```json
|
||||
@@ -238,6 +246,7 @@ class GPT4v_Agent:
|
||||
|
||||
# parse from the response
|
||||
actions = parse_actions_from_string(response)
|
||||
print(actions)
|
||||
|
||||
# add action into the trajectory
|
||||
self.trajectory.append({
|
||||
@@ -253,24 +262,14 @@ class GPT4v_Agent:
|
||||
# parse action
|
||||
parsed_actions = []
|
||||
for action in actions:
|
||||
parsed_action = {}
|
||||
action_type = Action[action['action_type']].value
|
||||
parsed_action["action_type"] = action_type
|
||||
action_type = action['action_type']
|
||||
if action_type == "CLICK":
|
||||
label = int(action['label'])
|
||||
x, y, w, h = mask[label-1]['bbox']
|
||||
parsed_actions.append({"action_type": action_type, "x": int(x + w//2) , "y": int(y + h//2)})
|
||||
|
||||
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
||||
parsed_action["click_type"] = MouseClick[action['click_type']].value
|
||||
|
||||
if action_type == Action.MOUSE_MOVE.value:
|
||||
parsed_action["x"] = action["x"]
|
||||
parsed_action["y"] = action["y"]
|
||||
|
||||
if action_type == Action.KEY.value:
|
||||
parsed_action["key"] = action["key"] # handle the condition of single key and multiple keys
|
||||
|
||||
if action_type == Action.TYPE.value:
|
||||
parsed_action["text"] = action["text"]
|
||||
|
||||
parsed_actions.append(parsed_action)
|
||||
if action_type == "TYPE":
|
||||
parsed_actions.append({"action_type": action_type, "text": action["text"]})
|
||||
|
||||
return parsed_actions
|
||||
|
||||
@@ -279,6 +278,6 @@ if __name__ == '__main__':
|
||||
# OpenAI API Key
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
|
||||
obs = PIL.Image.open('stackoverflow.png')
|
||||
agent = GPT4v_Agent(api_key=api_key, instruction="Open Firefox")
|
||||
obs = PIL.Image.open('desktop.png')
|
||||
print(agent.predict(obs=obs))
|
||||
Reference in New Issue
Block a user