update SoM_agent

Author: Hilbert-Johnson
Date: 2023-12-31 19:13:17 +08:00
parent f04e625ad9
commit 7560f4dc46
19 changed files with 3729 additions and 49 deletions


@@ -3,10 +3,10 @@
 import os
 import re
 import base64
+import PIL.Image
 from desktop_env.envs.desktop_env import Action, MouseClick
 import json
 import requests
-from mm_agents.gpt_4v_prompt import SYS_PROMPT
 import torch
 import argparse
@@ -33,9 +33,37 @@ from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m
 from scipy.ndimage import label
 import numpy as np
+SYS_PROMPT = '''
+You will act as an agent that follows my instructions and performs desktop computer tasks as instructed. You are expected to have good knowledge of computers and a working internet connection.
+At each step you will receive an observation in the form of an image, which is a screenshot of the computer screen, and you will predict the next action based on that image.
+First, you need to predict the class of your action, selecting from the options below:
+- **CLICK**: click on the screen at the specified integer label
+- **TYPE**: type a string on the keyboard
+- For CLICK, you need to predict the correct integer label shown on the screenshot,
+for example, formatted as:
+```
+{
+  "action_type": "CLICK",
+  "label": 7
+}
+```
+- For TYPE, you need to specify the text you want to type,
+for example, formatted as:
+```
+{
+  "action_type": "TYPE",
+  "text": "hello world"
+}
+```
+For every step, you should return only the action_type and the parameters of your action as a dict, without anything else. You MUST wrap the dict in backticks (`).
+You may plan multiple actions ahead, but you must return only one action per step.
+You MUST choose from, and ONLY from, the action space above; otherwise your action will be considered invalid and you will receive a penalty.
 '''
-build args
-'''
+# build args
 semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
 seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"
@@ -47,9 +75,7 @@ opt_semsam = load_opt_from_config_file(semsam_cfg)
 opt_seem = load_opt_from_config_file(seem_cfg)
 opt_seem = init_distributed_seem(opt_seem)
-'''
-build model
-'''
+# build model
 model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
 model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
 model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()
@@ -65,65 +91,46 @@ def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs
     elif slider > 2.5:
         model_name = 'sam'
     else:
-        if mode == 'Automatic':
-            model_name = 'semantic-sam'
-            if slider < 1.5 + 0.14:
-                level = [1]
-            elif slider < 1.5 + 0.28:
-                level = [2]
-            elif slider < 1.5 + 0.42:
-                level = [3]
-            elif slider < 1.5 + 0.56:
-                level = [4]
-            elif slider < 1.5 + 0.70:
-                level = [5]
-            elif slider < 1.5 + 0.84:
-                level = [6]
-            else:
-                level = [6, 1, 2, 3, 4, 5]
+        model_name = 'semantic-sam'
+        if slider < 1.5 + 0.14:
+            level = [1]
+        elif slider < 1.5 + 0.28:
+            level = [2]
+        elif slider < 1.5 + 0.42:
+            level = [3]
+        elif slider < 1.5 + 0.56:
+            level = [4]
+        elif slider < 1.5 + 0.70:
+            level = [5]
+        elif slider < 1.5 + 0.84:
+            level = [6]
         else:
-            model_name = 'sam'
+            level = [6, 1, 2, 3, 4, 5]
-    if label_mode == 'Alphabet':
-        label_mode = 'a'
-    else:
-        label_mode = '1'
-    text_size, hole_scale, island_scale = 640, 100, 100
+    label_mode = 'a' if label_mode == 'Alphabet' else '1'
+    text_size, hole_scale, island_scale = 1280, 100, 100
     text, text_part, text_thresh = '', '', '0.0'
     with torch.autocast(device_type='cuda', dtype=torch.float16):
         semantic = False
-        if mode == "Interactive":
-            labeled_array, num_features = label(np.asarray(image['mask'].convert('L')))
-            spatial_masks = torch.stack([torch.from_numpy(labeled_array == i+1) for i in range(num_features)])
         if model_name == 'semantic-sam':
             model = model_semsam
-            output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
+            output, mask = inference_semsam_m2m_auto(model, image, level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
         elif model_name == 'sam':
             model = model_sam
-            if mode == "Automatic":
-                output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
-            elif mode == "Interactive":
-                output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_sam_m2m_auto(model, image, text_size, label_mode, alpha, anno_mode)
         elif model_name == 'seem':
             model = model_seem
-            if mode == "Automatic":
-                output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
-            elif mode == "Interactive":
-                output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+            output, mask = inference_seem_pano(model, image, text_size, label_mode, alpha, anno_mode)
     return output
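For reference, the slider heuristic above maps values below 1.5 to SEEM, values above 2.5 to SAM, and everything in between to Semantic-SAM, with the granularity level picked from 0.14-wide bands. At the slider=2.0 that predict() passes below, 1.5 + 0.42 <= 2.0 < 1.5 + 0.56, so level = [4]. A standalone restatement of that banding, for illustration only:

```
# Mirrors the slider thresholds in inference() above (illustrative only).
def slider_to_level(slider: float) -> list:
    bands = [(1.64, [1]), (1.78, [2]), (1.92, [3]),
             (2.06, [4]), (2.20, [5]), (2.34, [6])]
    for upper, level in bands:
        if slider < upper:
            return level
    return [6, 1, 2, 3, 4, 5]


print(slider_to_level(2.0))  # -> [4]
```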
 # Function to encode the image
-def encode_image(image_path):
-    with open(image_path, "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode('utf-8')
+def encode_image(image):
+    return base64.b64encode(image).decode('utf-8')
 def parse_actions_from_string(input_string):
     # Search for a JSON string within the input string
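One caveat on the new encode_image: base64.b64encode only accepts bytes-like objects, while predict() below feeds it the annotated image that inference() returns (a PIL or NumPy image, not raw bytes). If the observation is a PIL image, it would need to be serialized first; a hedged sketch of one way to do that, not what this commit does:

```
import base64
import io

import PIL.Image


def encode_pil_image(image: PIL.Image.Image) -> str:
    # Serialize the image to PNG bytes in memory, then base64-encode,
    # since base64.b64encode() rejects PIL objects.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```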
@@ -156,7 +163,6 @@ def parse_actions_from_string(input_string):
     except json.JSONDecodeError as e:
         raise ValueError("Invalid response format: " + input_string)
-
 class GPT4v_Agent:
     def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
         self.instruction = instruction
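Since SYS_PROMPT instructs the model to wrap its dict in backticks, parse_actions_from_string presumably strips that fence before handing the payload to json.loads; its body is largely elided from this diff, so the following is only a minimal sketch of that extraction under those assumptions:

```
import json
import re


def extract_action(reply: str) -> dict:
    # Prefer a backtick-fenced block, as SYS_PROMPT requests; otherwise
    # try the whole reply. Mirrors the ValueError seen in the hunk above.
    match = re.search(r"`+\s*(\{.*?\})\s*`+", reply, re.DOTALL)
    candidate = match.group(1) if match else reply
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        raise ValueError("Invalid response format: " + reply)


print(extract_action('`{"action_type": "TYPE", "text": "hello world"}`'))
```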
@@ -181,7 +187,7 @@ class GPT4v_Agent:
         ]
     def predict(self, obs):
-        obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Alphabet", anno_mode=["Mask", "Mark"])
+        obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Number", anno_mode=["Mark", "Box"])
         base64_image = encode_image(obs)
         self.trajectory.append({
             "role": "user",
@@ -274,4 +280,5 @@ if __name__ == '__main__':
     api_key = os.environ.get("OPENAI_API_KEY")
     agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
-    print(agent.predict(obs="stackoverflow.png"))
+    obs = PIL.Image.open('stackoverflow.png')
+    print(agent.predict(obs=obs))
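Taken together, a hedged sketch of how this agent could sit in an observe-act loop. It assumes predict() returns the model's raw reply and that a separate grounding step maps the predicted integer label back to pixel coordinates via the marks drawn by inference(); neither is confirmed by this diff, and parse_actions_from_string is the module-level function from the hunks above:

```
import PIL.Image


def run_step(agent, screenshot_path):
    obs = PIL.Image.open(screenshot_path)
    reply = agent.predict(obs=obs)              # raw model reply (assumed)
    action = parse_actions_from_string(reply)   # dict per SYS_PROMPT
    if action["action_type"] == "CLICK":
        # Grounding the label to coordinates needs the mark masks from
        # inference(), which this diff does not expose to the caller.
        pass
    elif action["action_type"] == "TYPE":
        pass  # forward action["text"] to the desktop environment
    return action
```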