Merge remote-tracking branch 'origin/main'

mm_agents/SoM_agent.py (new file, +277 lines)

@@ -0,0 +1,277 @@
# fixme: needs to be rewritten for the new action space

import os
import re
import base64
from desktop_env.envs.desktop_env import Action, MouseClick
import json
import requests
from mm_agents.gpt_4v_prompt import SYS_PROMPT

import torch
import argparse

# seem
from seem.modeling.BaseModel import BaseModel as BaseModel_Seem
from seem.utils.distributed import init_distributed as init_distributed_seem
from seem.modeling import build_model as build_model_seem
from task_adapter.seem.tasks import interactive_seem_m2m_auto, inference_seem_pano, inference_seem_interactive

# semantic sam
from semantic_sam.BaseModel import BaseModel
from semantic_sam import build_model
from semantic_sam.utils.dist import init_distributed_mode
from semantic_sam.utils.arguments import load_opt_from_config_file
from semantic_sam.utils.constants import COCO_PANOPTIC_CLASSES
from task_adapter.semantic_sam.tasks import inference_semsam_m2m_auto, prompt_switch

# sam
from segment_anything import sam_model_registry
from task_adapter.sam.tasks.inference_sam_m2m_auto import inference_sam_m2m_auto
from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m2m_interactive

from scipy.ndimage import label
import numpy as np

'''
build args
'''
semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"

semsam_ckpt = "./swinl_only_sam_many2many.pth"
sam_ckpt = "./sam_vit_h_4b8939.pth"
seem_ckpt = "./seem_focall_v1.pt"
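# NOTE (assumption, not stated in the file): the config files and checkpoint paths above are
# expected to already exist locally; they are the SEEM, Semantic-SAM and SAM weights that this
# module loads at import time for Set-of-Mark annotation.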

opt_semsam = load_opt_from_config_file(semsam_cfg)
opt_seem = load_opt_from_config_file(seem_cfg)
opt_seem = init_distributed_seem(opt_seem)

'''
build model
'''
model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()

with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)
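# The call above warms up SEEM's language encoder by precomputing text embeddings for the
# COCO panoptic class names (plus "background") before any inference request arrives.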

@torch.no_grad()
def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
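    """
    Annotate an observation with Set-of-Mark style marks.

    Model selection, as implemented below (this summary is an interpretation of the code,
    not documented in the original): slider < 1.5 -> SEEM; slider > 2.5 -> SAM; values in
    between -> Semantic-SAM in 'Automatic' mode (the exact slider value picks the
    granularity level), or SAM in 'Interactive' mode.
    """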
    if slider < 1.5:
        model_name = 'seem'
    elif slider > 2.5:
        model_name = 'sam'
    else:
        if mode == 'Automatic':
            model_name = 'semantic-sam'
            if slider < 1.5 + 0.14:
                level = [1]
            elif slider < 1.5 + 0.28:
                level = [2]
            elif slider < 1.5 + 0.42:
                level = [3]
            elif slider < 1.5 + 0.56:
                level = [4]
            elif slider < 1.5 + 0.70:
                level = [5]
            elif slider < 1.5 + 0.84:
                level = [6]
            else:
                level = [6, 1, 2, 3, 4, 5]
        else:
            model_name = 'sam'

    if label_mode == 'Alphabet':
        label_mode = 'a'
    else:
        label_mode = '1'

    text_size, hole_scale, island_scale = 640, 100, 100
    text, text_part, text_thresh = '', '', '0.0'

    with torch.autocast(device_type='cuda', dtype=torch.float16):
        semantic = False

        if mode == "Interactive":
            labeled_array, num_features = label(np.asarray(image['mask'].convert('L')))
            spatial_masks = torch.stack([torch.from_numpy(labeled_array == i+1) for i in range(num_features)])

        if model_name == 'semantic-sam':
            model = model_semsam
            output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)

        elif model_name == 'sam':
            model = model_sam
            if mode == "Automatic":
                output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
            elif mode == "Interactive":
                output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)

        elif model_name == 'seem':
            model = model_seem
            if mode == "Automatic":
                output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
            elif mode == "Interactive":
                output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)

    return output
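# Illustrative call (mirrors the one in GPT4v_Agent.predict below); the argument values are
# examples only, not defaults from the original file:
#   annotated = inference(obs, slider=2.0, mode="Automatic", alpha=0.1,
#                         label_mode="Alphabet", anno_mode=["Mask", "Mark"])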

# Function to encode the image
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def parse_actions_from_string(input_string):
    # Search for JSON code blocks within the input string
    actions = []
    matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL)
    if matches:
        # Parse each fenced JSON block into a dictionary
        try:
            for match in matches:
                action_dict = json.loads(match)
                actions.append(action_dict)
            return actions
        except json.JSONDecodeError as e:
            return f"Failed to parse JSON: {e}"
    else:
        matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL)
        if matches:
            # Fall back to unlabeled code fences and parse each block as JSON
            try:
                for match in matches:
                    action_dict = json.loads(match)
                    actions.append(action_dict)
                return actions
            except json.JSONDecodeError as e:
                return f"Failed to parse JSON: {e}"
        else:
            # No code fences at all: try to parse the whole string as JSON
            try:
                action_dict = json.loads(input_string)
                return [action_dict]
            except json.JSONDecodeError as e:
                raise ValueError("Invalid response format: " + input_string)
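# Illustrative example (values are made up): a model reply such as
#   "Sure.\n```json\n{\"action_type\": \"CLICK\", \"click_type\": \"LEFT\"}\n```"
# is parsed into [{"action_type": "CLICK", "click_type": "LEFT"}].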


class GPT4v_Agent:
    def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
        self.instruction = instruction
        self.model = model
        self.max_tokens = max_tokens

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        self.trajectory = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": SYS_PROMPT
                    },
                ]
            }
        ]
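        # self.trajectory follows the OpenAI chat-completions message format: each entry is a
        # role plus a list of content parts (text here; image_url parts are appended in predict).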

    def predict(self, obs):
        obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Alphabet", anno_mode=["Mask", "Mark"])
        base64_image = encode_image(obs)
        self.trajectory.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "What's the next step for instruction '{}'?".format(self.instruction)
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{base64_image}"
                    }
                }
            ]
        })
        traj_to_show = []
        for i in range(len(self.trajectory)):
            traj_to_show.append(self.trajectory[i]["content"][0]["text"])
            if len(self.trajectory[i]["content"]) > 1:
                traj_to_show.append("screenshot_obs")
        print("Trajectory:", traj_to_show)
        payload = {
            "model": self.model,
            "messages": self.trajectory,
            "max_tokens": self.max_tokens
        }
        response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)

        try:
            actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
        except Exception:
            print("Failed to parse action from response:", response.json()['choices'][0]['message']['content'])
            actions = None

        return actions

    def parse_actions(self, response: str):
        # response example:
        """
        ```json
        {
            "action_type": "CLICK",
            "click_type": "RIGHT"
        }
        ```
        """

        # parse the action dicts out of the response text
        actions = parse_actions_from_string(response)

        # add the assistant response to the trajectory
        self.trajectory.append({
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": response
                },
            ]
        })

        # convert each action dict into the desktop_env action format
        parsed_actions = []
        for action in actions:
            parsed_action = {}
            action_type = Action[action['action_type']].value
            parsed_action["action_type"] = action_type

            if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
                parsed_action["click_type"] = MouseClick[action['click_type']].value

            if action_type == Action.MOUSE_MOVE.value:
                parsed_action["x"] = action["x"]
                parsed_action["y"] = action["y"]

            if action_type == Action.KEY.value:
                parsed_action["key"] = action["key"]  # may be a single key or a list of keys

            if action_type == Action.TYPE.value:
                parsed_action["text"] = action["text"]

            parsed_actions.append(parsed_action)

        return parsed_actions
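
    # Illustrative action dicts this method accepts (field names follow the parsing above;
    # the concrete values are made up):
    #   {"action_type": "CLICK", "click_type": "RIGHT"}
    #   {"action_type": "MOUSE_MOVE", "x": 120, "y": 340}
    #   {"action_type": "KEY", "key": "ctrl+s"}
    #   {"action_type": "TYPE", "text": "hello world"}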


if __name__ == '__main__':
    # OpenAI API key
    api_key = os.environ.get("OPENAI_API_KEY")

    agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
    print(agent.predict(obs="stackoverflow.png"))