# sci-gui-agent-benchmark/mm_agents/SoM_agent.py
# FIXME: needs to be rewritten for the new action space
import os
import re
import base64
import PIL.Image
import json
import requests
import torch
import argparse
# seem
from seem.modeling.BaseModel import BaseModel as BaseModel_Seem
from seem.utils.distributed import init_distributed as init_distributed_seem
from seem.modeling import build_model as build_model_seem
from task_adapter.seem.tasks import inference_seem_pano
# semantic sam
from semantic_sam.BaseModel import BaseModel
from semantic_sam import build_model
from semantic_sam.utils.dist import init_distributed_mode
from semantic_sam.utils.arguments import load_opt_from_config_file
from semantic_sam.utils.constants import COCO_PANOPTIC_CLASSES
from task_adapter.semantic_sam.tasks import inference_semsam_m2m_auto, prompt_switch
# sam
from segment_anything import sam_model_registry
from task_adapter.sam.tasks.inference_sam_m2m_auto import inference_sam_m2m_auto
from scipy.ndimage import label
from io import BytesIO
import numpy as np

SYS_PROMPT = '''
You will act as an agent which follows my instructions and performs desktop computer tasks as instructed. You must have good knowledge of computers and a good internet connection.
At each step you will get an observation in the form of an image, which is a screenshot of the computer screen, and you will predict the next action on the computer based on that image.
First you need to predict the class of your action, selected from the ones below:
- **CLICK**: click on the screen element marked with the specified integer label
- **TYPE**: type a string on the keyboard
- For CLICK, you need to predict the correct integer label shown on the screenshot,
for example, formatted as:
```
{
    "action_type": "CLICK",
    "label": 7
}
```
- For TYPE, you need to specify the text you want to type,
for example, formatted as:
```
{
    "action_type": "TYPE",
    "text": "hello world"
}
```
For every step, you should only return the action_type and the parameters of your action as a dict, without anything else. You MUST wrap the dict with backticks (\`).
You can predict multiple actions at one step, but you should only return one action for each step.
You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered invalid and you will get a penalty.
'''
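
# Illustrative sketch (not part of the original prompt): the system prompt above is meant to
# elicit a backtick-fenced JSON dict like the one below, which parse_actions_from_string() can
# recover. The label value 7 is a made-up placeholder; real labels come from the marks drawn on
# the screenshot.
#
#   ```json
#   {
#       "action_type": "CLICK",
#       "label": 7
#   }
#   ```
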
# build args
semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"
semsam_ckpt = "./swinl_only_sam_many2many.pth"
sam_ckpt = "./sam_vit_h_4b8939.pth"
seem_ckpt = "./seem_focall_v1.pt"
opt_semsam = load_opt_from_config_file(semsam_cfg)
opt_seem = load_opt_from_config_file(seem_cfg)
opt_seem = init_distributed_seem(opt_seem)
# build model
model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()

# precompute SEEM text embeddings for the COCO panoptic classes (plus background)
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)

@torch.no_grad()
def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
    # the slider selects the segmentation backbone: <1.5 -> SEEM, >2.5 -> SAM, otherwise Semantic-SAM
    if slider < 1.5:
        model_name = 'seem'
    elif slider > 2.5:
        model_name = 'sam'
    else:
        model_name = 'semantic-sam'

    # for Semantic-SAM, the fractional part of the slider picks the granularity level(s)
    if slider < 1.5 + 0.14:
        level = [1]
    elif slider < 1.5 + 0.28:
        level = [2]
    elif slider < 1.5 + 0.42:
        level = [3]
    elif slider < 1.5 + 0.56:
        level = [4]
    elif slider < 1.5 + 0.70:
        level = [5]
    elif slider < 1.5 + 0.84:
        level = [6]
    else:
        level = [6, 1, 2, 3, 4, 5]

    if label_mode == 'Alphabet':
        label_mode = 'a'
    else:
        label_mode = '1'

    text_size, hole_scale, island_scale = 1280, 100, 100
    text, text_part, text_thresh = '', '', '0.0'
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        semantic = False
        if model_name == 'semantic-sam':
            model = model_semsam
            output, mask = inference_semsam_m2m_auto(model, image, level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
        elif model_name == 'sam':
            model = model_sam
            output, mask = inference_sam_m2m_auto(model, image, text_size, label_mode, alpha, anno_mode)
        elif model_name == 'seem':
            model = model_seem
            output, mask = inference_seem_pano(model, image, text_size, label_mode, alpha, anno_mode)
    return output, mask
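
# Usage sketch (illustrative, not executed at import time): this is how GPT4v_Agent.predict below
# drives inference(). The slider value and keyword arguments mirror that call; the descriptions of
# `output` (annotated image array) and `mask` (list of dicts carrying a 'bbox' entry) are
# assumptions inferred from how the return values are consumed later in this file.
#
#   annotated, marks = inference(screenshot, slider=3.0, mode="Automatic", alpha=0.1,
#                                label_mode="Number", anno_mode=["Mark", "Box"])
#   # annotated: image with numbered marks drawn on it
#   # marks[i]['bbox']: (x, y, w, h) box for the region labelled i + 1
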
# Function to encode the image as a base64 JPEG string
def encode_image(image):
    pil_img = PIL.Image.fromarray(image)
    buff = BytesIO()
    pil_img.save(buff, format="JPEG")
    new_image_string = base64.b64encode(buff.getvalue()).decode("utf-8")
    return new_image_string

def parse_actions_from_string(input_string):
    # Search for fenced JSON blocks within the input string
    actions = []
    matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL)
    if matches:
        # Parse every ```json ... ``` block into a dictionary
        try:
            for match in matches:
                action_dict = json.loads(match)
                actions.append(action_dict)
            return actions
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON: {e}")
    else:
        # Fall back to plain ``` ... ``` blocks
        matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL)
        if matches:
            try:
                for match in matches:
                    action_dict = json.loads(match)
                    actions.append(action_dict)
                return actions
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse JSON: {e}")
        else:
            # Last resort: treat the whole response as a bare JSON dict
            try:
                action_dict = json.loads(input_string)
                return [action_dict]
            except json.JSONDecodeError:
                raise ValueError("Invalid response format: " + input_string)
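
# Example (hedged, with made-up values): a reply such as
#   '```json\n{"action_type": "CLICK", "label": 3}\n```'
# is parsed into [{"action_type": "CLICK", "label": 3}]. A bare JSON dict with no code fence falls
# through to the final json.loads branch, and anything else raises ValueError.
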

class GPT4v_Agent:
    def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
        self.instruction = instruction
        self.model = model
        self.max_tokens = max_tokens
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        self.trajectory = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYS_PROMPT
                    },
                ]
            }
        ]

    def predict(self, obs):
        # annotate the screenshot with Set-of-Mark labels (slider=3.0 selects the SAM backbone)
        obs, mask = inference(obs, slider=3.0, mode="Automatic", alpha=0.1, label_mode="Number", anno_mode=["Mark", "Box"])
        PIL.Image.fromarray(obs).save("desktop.jpeg")
        base64_image = encode_image(obs)

        self.trajectory.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's the next step for instruction '{}'?".format(self.instruction)
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                }
            ]
        })

        # print a compact view of the trajectory (text only, screenshots elided)
        traj_to_show = []
        for i in range(len(self.trajectory)):
            traj_to_show.append(self.trajectory[i]["content"][0]["text"])
            if len(self.trajectory[i]["content"]) > 1:
                traj_to_show.append("screenshot_obs")
        print("Trajectory:", traj_to_show)

        payload = {
            "model": self.model,
            "messages": self.trajectory,
            "max_tokens": self.max_tokens
        }
        response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
        try:
            actions = self.parse_actions(response.json()['choices'][0]['message']['content'], mask)
        except Exception:
            print("Failed to parse action from response:", response.json())
            actions = None
        return actions

    def parse_actions(self, response: str, mask):
        # example response:
        """
        ```json
        {
            "action_type": "CLICK",
            "label": 7
        }
        ```
        """
        # parse the action dicts out of the response text
        actions = parse_actions_from_string(response)
        print(actions)

        # add the raw assistant response to the trajectory
        self.trajectory.append({
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": response
                },
            ]
        })

        # ground each action: CLICK labels become screen coordinates, TYPE keeps its text
        parsed_actions = []
        for action in actions:
            action_type = action['action_type']
            if action_type == "CLICK":
                # map the predicted mark label to the center of its segmentation bounding box
                label = int(action['label'])
                x, y, w, h = mask[label - 1]['bbox']
                parsed_actions.append({"action_type": action_type, "x": int(x + w // 2), "y": int(y + h // 2)})
            if action_type == "TYPE":
                parsed_actions.append({"action_type": action_type, "text": action["text"]})
        return parsed_actions
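
    # Worked example (hedged, made-up numbers): if the model answers with label 7 and
    # mask[6]['bbox'] == (100, 40, 60, 20), the CLICK is grounded at the bbox center
    # (100 + 60 // 2, 40 + 20 // 2) == (130, 50), i.e. it becomes
    # {"action_type": "CLICK", "x": 130, "y": 50}.
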

if __name__ == '__main__':
    # OpenAI API Key
    api_key = os.environ.get("OPENAI_API_KEY")

    agent = GPT4v_Agent(api_key=api_key, instruction="Open Firefox")
    obs = PIL.Image.open('desktop.png')
    print(agent.predict(obs=obs))
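
    # Note on prerequisites (gathered from this file, not new requirements): running this entry
    # point assumes OPENAI_API_KEY is set in the environment, a CUDA device is available for the
    # three segmentation models, the checkpoints referenced at the top of the file
    # (swinl_only_sam_many2many.pth, sam_vit_h_4b8939.pth, seem_focall_v1.pt) are present, and a
    # 'desktop.png' screenshot exists in the working directory.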