diff --git a/mm_agents/SoM_agent.py b/mm_agents/SoM_agent.py
new file mode 100644
index 0000000..d4a265c
--- /dev/null
+++ b/mm_agents/SoM_agent.py
@@ -0,0 +1,277 @@
+# FIXME: needs to be rewritten for the new action space
+
+import base64
+import io
+import json
+import os
+import re
+
+import numpy as np
+import requests
+import torch
+from PIL import Image
+from scipy.ndimage import label
+
+from desktop_env.envs.desktop_env import Action, MouseClick
+from mm_agents.gpt_4v_prompt import SYS_PROMPT
+
+# SEEM
+from seem.modeling.BaseModel import BaseModel as BaseModel_Seem
+from seem.utils.distributed import init_distributed as init_distributed_seem
+from seem.modeling import build_model as build_model_seem
+from task_adapter.seem.tasks import interactive_seem_m2m_auto, inference_seem_pano, inference_seem_interactive
+
+# Semantic-SAM
+from semantic_sam.BaseModel import BaseModel
+from semantic_sam import build_model
+from semantic_sam.utils.dist import init_distributed_mode
+from semantic_sam.utils.arguments import load_opt_from_config_file
+from semantic_sam.utils.constants import COCO_PANOPTIC_CLASSES
+from task_adapter.semantic_sam.tasks import inference_semsam_m2m_auto, prompt_switch
+
+# SAM
+from segment_anything import sam_model_registry
+from task_adapter.sam.tasks.inference_sam_m2m_auto import inference_sam_m2m_auto
+from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m2m_interactive
+
+'''
+build args
+'''
+semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
+seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"
+
+semsam_ckpt = "./swinl_only_sam_many2many.pth"
+sam_ckpt = "./sam_vit_h_4b8939.pth"
+seem_ckpt = "./seem_focall_v1.pt"
+
+opt_semsam = load_opt_from_config_file(semsam_cfg)
+opt_seem = load_opt_from_config_file(seem_cfg)
+opt_seem = init_distributed_seem(opt_seem)
+
+'''
+build model
+'''
+model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
+model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
+model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()
+
+# precompute SEEM text embeddings for the panoptic vocabulary once at load time
+with torch.no_grad():
+    with torch.autocast(device_type='cuda', dtype=torch.float16):
+        model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)
+
+
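+# Routing note, summarizing the slider logic in inference() below: slider < 1.5
+# selects SEEM, slider > 2.5 selects SAM, and values in between select
+# Semantic-SAM in Automatic mode (Interactive mode falls back to SAM), with the
+# position inside the middle range mapped to a single granularity level in
+# [1]..[6] and the top of the range running all levels. For example, the
+# slider=2.0 default used by GPT4v_Agent.predict resolves to Semantic-SAM at
+# level [4].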
+@torch.no_grad()
+def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
+    if slider < 1.5:
+        model_name = 'seem'
+    elif slider > 2.5:
+        model_name = 'sam'
+    else:
+        if mode == 'Automatic':
+            model_name = 'semantic-sam'
+            if slider < 1.5 + 0.14:
+                level = [1]
+            elif slider < 1.5 + 0.28:
+                level = [2]
+            elif slider < 1.5 + 0.42:
+                level = [3]
+            elif slider < 1.5 + 0.56:
+                level = [4]
+            elif slider < 1.5 + 0.70:
+                level = [5]
+            elif slider < 1.5 + 0.84:
+                level = [6]
+            else:
+                level = [6, 1, 2, 3, 4, 5]
+        else:
+            model_name = 'sam'
+
+    if label_mode == 'Alphabet':
+        label_mode = 'a'
+    else:
+        label_mode = '1'
+
+    text_size, hole_scale, island_scale = 640, 100, 100
+    text, text_part, text_thresh = '', '', '0.0'
+
+    with torch.autocast(device_type='cuda', dtype=torch.float16):
+        semantic = False
+
+        if mode == "Interactive":
+            # split the user-drawn mask into connected components, one spatial prompt each
+            labeled_array, num_features = label(np.asarray(image['mask'].convert('L')))
+            spatial_masks = torch.stack([torch.from_numpy(labeled_array == i + 1) for i in range(num_features)])
+
+        if model_name == 'semantic-sam':
+            model = model_semsam
+            output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh,
+                                                     text_size, hole_scale, island_scale, semantic,
+                                                     label_mode=label_mode, alpha=alpha, anno_mode=anno_mode,
+                                                     *args, **kwargs)
+
+        elif model_name == 'sam':
+            model = model_sam
+            if mode == "Automatic":
+                output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
+            elif mode == "Interactive":
+                output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+
+        elif model_name == 'seem':
+            model = model_seem
+            if mode == "Automatic":
+                output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
+            elif mode == "Interactive":
+                output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
+
+    return output
+
+
+def encode_image(image_path):
+    """Encode an image file on disk as base64 for the OpenAI API."""
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode('utf-8')
+
+
+def parse_actions_from_string(input_string):
+    """Extract a list of action dicts from a response that may wrap its JSON in code fences."""
+    # prefer ```json fences, fall back to bare ``` fences
+    matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL)
+    if not matches:
+        matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL)
+    if matches:
+        # each fenced block is parsed as one action dict
+        try:
+            return [json.loads(match) for match in matches]
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Failed to parse JSON: {e}")
+    # no fences: the whole response should be a single JSON object
+    try:
+        return [json.loads(input_string)]
+    except json.JSONDecodeError:
+        raise ValueError("Invalid response format: " + input_string)
+
+
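+# Illustrative round trip for the helper above (the action schema follows the
+# example in GPT4v_Agent.parse_actions below):
+#   parse_actions_from_string('```json\n{"action_type": "CLICK", "click_type": "RIGHT"}\n```')
+#   -> [{'action_type': 'CLICK', 'click_type': 'RIGHT'}]
+
+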
+class GPT4v_Agent:
+    def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
+        self.instruction = instruction
+        self.model = model
+        self.max_tokens = max_tokens
+
+        self.headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}"
+        }
+
+        self.trajectory = [
+            {
+                "role": "system",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": SYS_PROMPT
+                    },
+                ]
+            }
+        ]
+
+    def predict(self, obs):
+        # the SoM models expect a dict with a PIL 'image' entry; accept a bare
+        # screenshot path too ('mask' is only read in Interactive mode)
+        if isinstance(obs, str):
+            obs = {"image": Image.open(obs).convert("RGB"), "mask": None}
+        annotated = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Alphabet", anno_mode=["Mask", "Mark"])
+
+        # inference() returns the annotated image itself (an RGB array from the
+        # task adapters), not a file path, so encode it from memory
+        buffer = io.BytesIO()
+        Image.fromarray(np.uint8(annotated)).save(buffer, format="JPEG")
+        base64_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
+
+        self.trajectory.append({
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "What's the next step for instruction '{}'?".format(self.instruction)
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{base64_image}"
+                    }
+                }
+            ]
+        })
+
+        traj_to_show = []
+        for step in self.trajectory:
+            traj_to_show.append(step["content"][0]["text"])
+            if len(step["content"]) > 1:
+                traj_to_show.append("screenshot_obs")
+        print("Trajectory:", traj_to_show)
+
+        payload = {
+            "model": self.model,
+            "messages": self.trajectory,
+            "max_tokens": self.max_tokens
+        }
+        response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
+
+        try:
+            actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
+        except (KeyError, IndexError, ValueError):
+            print("Failed to parse action from response:", response.text)
+            actions = None
+
+        return actions
+
+    def parse_actions(self, response: str):
+        # response example:
+        """
+        ```json
+        {
+            "action_type": "CLICK",
+            "click_type": "RIGHT"
+        }
+        ```
+        """
+        # parse the action dicts out of the (possibly fenced) response
+        actions = parse_actions_from_string(response)
+
+        # record the raw response in the trajectory
+        self.trajectory.append({
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "text",
+                    "text": response
+                },
+            ]
+        })
+
+        # map each dict onto the desktop_env action space
+        parsed_actions = []
+        for action in actions:
+            parsed_action = {}
+            action_type = Action[action['action_type']].value
+            parsed_action["action_type"] = action_type
+
+            if action_type in (Action.CLICK.value, Action.MOUSE_DOWN.value, Action.MOUSE_UP.value):
+                parsed_action["click_type"] = MouseClick[action['click_type']].value
+
+            if action_type == Action.MOUSE_MOVE.value:
+                parsed_action["x"] = action["x"]
+                parsed_action["y"] = action["y"]
+
+            if action_type == Action.KEY.value:
+                parsed_action["key"] = action["key"]  # may be a single key or a key combination
+
+            if action_type == Action.TYPE.value:
+                parsed_action["text"] = action["text"]
+
+            parsed_actions.append(parsed_action)
+
+        return parsed_actions
+
+
+if __name__ == '__main__':
+    # OpenAI API key
+    api_key = os.environ.get("OPENAI_API_KEY")
+
+    agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
+    print(agent.predict(obs="stackoverflow.png"))
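+
+# Expected result shape (illustrative): parse_actions maps the model's JSON
+# onto the desktop_env action space, so a CLICK/RIGHT response comes back as
+# [{"action_type": Action.CLICK.value, "click_type": MouseClick.RIGHT.value}].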