From 09f3e776aede4a912a972c3f57c7929e33a961df Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Sat, 20 Jan 2024 00:13:46 +0800 Subject: [PATCH] Initialize all baselines: screenshot, a11y tree, both, SoM, SeeAct --- experiment_pure_text.py | 2 + experiment.py => experiment_screenshot.py | 0 mm_agents/SoM_agent.py | 283 ---- .../heuristic_retrieve.py | 27 +- mm_agents/gpt_4_agent.py | 195 --- mm_agents/gpt_4_prompt_action.py | 244 --- mm_agents/gpt_4_prompt_code.py | 18 - mm_agents/gpt_4v_agent.py | 373 ++++- mm_agents/gpt_4v_prompt_action.py | 244 --- mm_agents/gpt_4v_prompt_code.py | 18 - mm_agents/prompts.py | 862 ++++++++++ mm_agents/sam_test.py | 124 -- mm_agents/visualizer.py | 1405 +++++++++++++++++ requirements.txt | 1 + 14 files changed, 2588 insertions(+), 1208 deletions(-) rename experiment.py => experiment_screenshot.py (100%) delete mode 100644 mm_agents/SoM_agent.py delete mode 100644 mm_agents/gpt_4_agent.py delete mode 100644 mm_agents/gpt_4_prompt_action.py delete mode 100644 mm_agents/gpt_4_prompt_code.py delete mode 100644 mm_agents/gpt_4v_prompt_action.py delete mode 100644 mm_agents/gpt_4v_prompt_code.py create mode 100644 mm_agents/prompts.py delete mode 100644 mm_agents/sam_test.py create mode 100644 mm_agents/visualizer.py diff --git a/experiment_pure_text.py b/experiment_pure_text.py index cfcbd46..4fd19b1 100644 --- a/experiment_pure_text.py +++ b/experiment_pure_text.py @@ -62,6 +62,8 @@ def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_tr env.controller.start_recording() while not done and step_num < max_steps: + with open("accessibility_tree.xml", "w", encoding="utf-8") as f: + f.write(observation["accessibility_tree"]) actions = agent.predict(observation) step_num += 1 for action in actions: diff --git a/experiment.py b/experiment_screenshot.py similarity index 100% rename from experiment.py rename to experiment_screenshot.py diff --git a/mm_agents/SoM_agent.py b/mm_agents/SoM_agent.py deleted file mode 100644 index e3b3e59..0000000 --- a/mm_agents/SoM_agent.py +++ /dev/null @@ -1,283 +0,0 @@ -# fixme: Need to be rewrite on new action space - -import os -import re -import base64 -import PIL.Image -import json -import requests - -import torch -import argparse - -# seem -from seem.modeling.BaseModel import BaseModel as BaseModel_Seem -from seem.utils.distributed import init_distributed as init_distributed_seem -from seem.modeling import build_model as build_model_seem -from task_adapter.seem.tasks import inference_seem_pano - -# semantic sam -from semantic_sam.BaseModel import BaseModel -from semantic_sam import build_model -from semantic_sam.utils.dist import init_distributed_mode -from semantic_sam.utils.arguments import load_opt_from_config_file -from semantic_sam.utils.constants import COCO_PANOPTIC_CLASSES -from task_adapter.semantic_sam.tasks import inference_semsam_m2m_auto, prompt_switch - -# sam -from segment_anything import sam_model_registry -from task_adapter.sam.tasks.inference_sam_m2m_auto import inference_sam_m2m_auto - -from scipy.ndimage import label -from io import BytesIO -import numpy as np - -SYS_PROMPT = ''' -You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection. -For each step, you will get an observation of an image, which is the screenshot of the computer screen. And you will predict the action of the computer based on the image. - -Firstly you need to predict the class of your action, select from one below: -- **CLICK**: click on the screen with the specified integer label -- **TYPE**: type a string on the keyboard - -- For CLICK, you need to predict the correct integer label shown on the screenshot -for example, format as: -``` -{ - "action_type": "CLICK", - "label": 7 -} -``` -- For TYPE, you need to specify the text you want to type -for example, format as: -``` -{ - "action_type": "TYPE", - "text": "hello world" -} -``` - -For every step, you should only return the action_type and the parameters of your action as a dict, without any other things. You MUST wrap the dict with backticks (\`). -You can predict multiple actions at one step, but you should only return one action for each step. -You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty. -''' - -# build args -semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml" -seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml" - -semsam_ckpt = "./swinl_only_sam_many2many.pth" -sam_ckpt = "./sam_vit_h_4b8939.pth" -seem_ckpt = "./seem_focall_v1.pt" - -opt_semsam = load_opt_from_config_file(semsam_cfg) -opt_seem = load_opt_from_config_file(seem_cfg) -opt_seem = init_distributed_seem(opt_seem) - -# build model -model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda() -model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda() -model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda() - -with torch.no_grad(): - with torch.autocast(device_type='cuda', dtype=torch.float16): - model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True) - -@torch.no_grad() -def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs): - if slider < 1.5: - model_name = 'seem' - elif slider > 2.5: - model_name = 'sam' - else: - model_name = 'semantic-sam' - if slider < 1.5 + 0.14: - level = [1] - elif slider < 1.5 + 0.28: - level = [2] - elif slider < 1.5 + 0.42: - level = [3] - elif slider < 1.5 + 0.56: - level = [4] - elif slider < 1.5 + 0.70: - level = [5] - elif slider < 1.5 + 0.84: - level = [6] - else: - level = [6, 1, 2, 3, 4, 5] - - if label_mode == 'Alphabet': - label_mode = 'a' - else: - label_mode = '1' - - text_size, hole_scale, island_scale = 1280, 100, 100 - text, text_part, text_thresh = '', '', '0.0' - - with torch.autocast(device_type='cuda', dtype=torch.float16): - semantic = False - - if model_name == 'semantic-sam': - model = model_semsam - output, mask = inference_semsam_m2m_auto(model, image, level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs) - - elif model_name == 'sam': - model = model_sam - output, mask = inference_sam_m2m_auto(model, image, text_size, label_mode, alpha, anno_mode) - - elif model_name == 'seem': - model = model_seem - output, mask = inference_seem_pano(model, image, text_size, label_mode, alpha, anno_mode) - - return output, mask - -# Function to encode the image -def encode_image(image): - pil_img = PIL.Image.fromarray(image) - buff = BytesIO() - pil_img.save(buff, format="JPEG") - new_image_string = base64.b64encode(buff.getvalue()).decode("utf-8") - return new_image_string - -def parse_actions_from_string(input_string): - # Search for a JSON string within the input string - actions = [] - matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL) - if matches: - # Assuming there's only one match, parse the JSON string into a dictionary - try: - for match in matches: - action_dict = json.loads(match) - actions.append(action_dict) - return actions - except json.JSONDecodeError as e: - return f"Failed to parse JSON: {e}" - else: - matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL) - if matches: - # Assuming there's only one match, parse the JSON string into a dictionary - try: - for match in matches: - action_dict = json.loads(match) - actions.append(action_dict) - return actions - except json.JSONDecodeError as e: - return f"Failed to parse JSON: {e}" - else: - try: - action_dict = json.loads(input_string) - return [action_dict] - except json.JSONDecodeError as e: - raise ValueError("Invalid response format: " + input_string) - -class GPT4v_Agent: - def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300): - self.instruction = instruction - self.model = model - self.max_tokens = max_tokens - - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } - - self.trajectory = [ - { - "role": "system", - "content": [ - { - "type": "text", - "text": SYS_PROMPT - }, - ] - } - ] - - def predict(self, obs): - obs, mask = inference(obs, slider=3.0, mode="Automatic", alpha=0.1, label_mode="Number", anno_mode=["Mark", "Box"]) - PIL.Image.fromarray(obs).save("desktop.jpeg") - base64_image = encode_image(obs) - self.trajectory.append({ - "role": "user", - "content": [ - { - "type": "text", - "text": "What's the next step for instruction '{}'?".format(self.instruction) - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" - } - } - ] - }) - traj_to_show = [] - for i in range(len(self.trajectory)): - traj_to_show.append(self.trajectory[i]["content"][0]["text"]) - if len(self.trajectory[i]["content"]) > 1: - traj_to_show.append("screenshot_obs") - print("Trajectory:", traj_to_show) - payload = { - "model": self.model, - "messages": self.trajectory, - "max_tokens": self.max_tokens - } - response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload) - - try: - actions = self.parse_actions(response.json()['choices'][0]['message']['content'], mask) - except: - print("Failed to parse action from response:", response.json()['choices'][0]['message']['content']) - actions = None - - return actions - - def parse_actions(self, response: str, mask): - # response example - """ - ```json - { - "action_type": "CLICK", - "click_type": "RIGHT" - } - ``` - """ - - # parse from the response - actions = parse_actions_from_string(response) - print(actions) - - # add action into the trajectory - self.trajectory.append({ - "role": "assistant", - "content": [ - { - "type": "text", - "text": response - }, - ] - }) - - # parse action - parsed_actions = [] - for action in actions: - action_type = action['action_type'] - if action_type == "CLICK": - label = int(action['label']) - x, y, w, h = mask[label-1]['bbox'] - parsed_actions.append({"action_type": action_type, "x": int(x + w//2) , "y": int(y + h//2)}) - - if action_type == "TYPE": - parsed_actions.append({"action_type": action_type, "text": action["text"]}) - - return parsed_actions - - -if __name__ == '__main__': - # OpenAI API Key - api_key = os.environ.get("OPENAI_API_KEY") - - agent = GPT4v_Agent(api_key=api_key, instruction="Open Firefox") - obs = PIL.Image.open('desktop.png') - print(agent.predict(obs=obs)) \ No newline at end of file diff --git a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py index d6f83eb..c59060c 100644 --- a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py +++ b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py @@ -41,10 +41,12 @@ def filter_nodes(nodes): elif node.tag == 'text': continue else: - coords = tuple(map(int, node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord').strip('()').split(', '))) + coords = tuple( + map(int, node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord').strip('()').split(', '))) if coords[0] < 0 or coords[1] < 0: continue - size = tuple(map(int, node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size').strip('()').split(', '))) + size = tuple( + map(int, node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size').strip('()').split(', '))) if size[0] <= 0 or size[1] <= 0: continue # Node is not a 'panel', add to the list. @@ -57,6 +59,9 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path): # Load the screenshot image image = Image.open(image_file_path) draw = ImageDraw.Draw(image) + marks = [] + + # todo: change the image tagger to align with SoM paper # Optional: Load a font. If you don't specify a font, a default one will be used. try: @@ -95,8 +100,26 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path): text_position = (coords[0], bottom_right[1]) # Adjust Y to be above the bottom right draw.text(text_position, str(index), font=font, fill="purple") + # each mark is an x, y, w, h tuple + marks.append([coords[0], coords[1], size[0], size[1]]) + except ValueError as e: pass # Save the result image.save(output_image_file_path) + return marks + + +def print_nodes_with_indent(nodes, indent=0): + for node in nodes: + print(' ' * indent, node.tag, node.attrib) + print_nodes_with_indent(node, indent + 2) + + +if __name__ == '__main__': + with open('chrome_desktop_example_1.xml', 'r', encoding='utf-8') as f: + xml_file_str = f.read() + + nodes = ET.fromstring(xml_file_str) + print_nodes_with_indent(nodes) diff --git a/mm_agents/gpt_4_agent.py b/mm_agents/gpt_4_agent.py deleted file mode 100644 index aa19185..0000000 --- a/mm_agents/gpt_4_agent.py +++ /dev/null @@ -1,195 +0,0 @@ -import base64 -import json -import re -import time -from typing import Dict, List - -import requests - -from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes -from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION -from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE - - -# Function to encode the image -def encode_image(image_path): - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') - - -def parse_actions_from_string(input_string): - # Search for a JSON string within the input string - actions = [] - matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL) - if matches: - # Assuming there's only one match, parse the JSON string into a dictionary - try: - for match in matches: - action_dict = json.loads(match) - actions.append(action_dict) - return actions - except json.JSONDecodeError as e: - return f"Failed to parse JSON: {e}" - else: - matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL) - if matches: - # Assuming there's only one match, parse the JSON string into a dictionary - try: - for match in matches: - action_dict = json.loads(match) - actions.append(action_dict) - return actions - except json.JSONDecodeError as e: - return f"Failed to parse JSON: {e}" - else: - try: - action_dict = json.loads(input_string) - return [action_dict] - except json.JSONDecodeError as e: - raise ValueError("Invalid response format: " + input_string) - - -def parse_code_from_string(input_string): - # This regular expression will match both ```code``` and ```python code``` - # and capture the `code` part. It uses a non-greedy match for the content inside. - pattern = r"```(?:\w+\s+)?(.*?)```" - # Find all non-overlapping matches in the string - matches = re.findall(pattern, input_string, re.DOTALL) - - # The regex above captures the content inside the triple backticks. - # The `re.DOTALL` flag allows the dot `.` to match newline characters as well, - # so the code inside backticks can span multiple lines. - - # matches now contains all the captured code snippets - - codes = [] - - for match in matches: - match = match.strip() - commands = ['WAIT', 'DONE', 'FAIL'] # fixme: updates this part when we have more commands - - if match in commands: - codes.append(match.strip()) - elif match.split('\n')[-1] in commands: - if len(match.split('\n')) > 1: - codes.append("\n".join(match.split('\n')[:-1])) - codes.append(match.split('\n')[-1]) - else: - codes.append(match) - - return codes - - -class GPT4_Agent: - def __init__(self, api_key, instruction, model="gpt-4-1106-preview", max_tokens=600, action_space="computer_13"): - self.instruction = instruction - self.model = model - self.max_tokens = max_tokens - self.action_space = action_space - - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}" - } - - self.trajectory = [ - { - "role": "system", - "content": [ - { - "type": "text", - "text": { - "computer_13": SYS_PROMPT_ACTION, - "pyautogui": SYS_PROMPT_CODE - }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction) - }, - ] - } - ] - - def predict(self, obs: Dict) -> List: - """ - Predict the next action(s) based on the current observation. - """ - accessibility_tree = obs["accessibility_tree"] - - leaf_nodes = find_leaf_nodes(accessibility_tree) - filtered_nodes = filter_nodes(leaf_nodes) - - linearized_accessibility_tree = "tag\ttext\tposition\tsize\n" - # Linearize the accessibility tree nodes into a table format - - for node in filtered_nodes: - linearized_accessibility_tree += node.tag + "\t" - linearized_accessibility_tree += node.attrib.get('name') + "\t" - linearized_accessibility_tree += node.attrib.get( - '{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t" - linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n" - - self.trajectory.append({ - "role": "user", - "content": [ - { - "type": "text", - "text": "Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( - linearized_accessibility_tree) - } - ] - }) - - # print( - # "Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( - # linearized_accessibility_tree) - # ) - - traj_to_show = [] - for i in range(len(self.trajectory)): - traj_to_show.append(self.trajectory[i]["content"][0]["text"]) - if len(self.trajectory[i]["content"]) > 1: - traj_to_show.append("screenshot_obs") - - payload = { - "model": self.model, - "messages": self.trajectory, - "max_tokens": self.max_tokens - } - - while True: - try: - response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, - json=payload) - break - except: - print("Failed to generate response, retrying...") - time.sleep(5) - pass - - try: - actions = self.parse_actions(response.json()['choices'][0]['message']['content']) - except: - print("Failed to parse action from response:", response.json()) - actions = None - - return actions - - def parse_actions(self, response: str): - # parse from the response - if self.action_space == "computer_13": - actions = parse_actions_from_string(response) - elif self.action_space == "pyautogui": - actions = parse_code_from_string(response) - else: - raise ValueError("Invalid action space: " + self.action_space) - - # add action into the trajectory - self.trajectory.append({ - "role": "assistant", - "content": [ - { - "type": "text", - "text": response - }, - ] - }) - - return actions diff --git a/mm_agents/gpt_4_prompt_action.py b/mm_agents/gpt_4_prompt_action.py deleted file mode 100644 index 3019074..0000000 --- a/mm_agents/gpt_4_prompt_action.py +++ /dev/null @@ -1,244 +0,0 @@ -SYS_PROMPT = """ -You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection. -For each step, you will get an observation of the desktop by the XML format of accessibility tree, which is based on AT-SPI library. And you will predict the action of the computer based on the accessibility tree. - -HERE is the description of the action space you need to predict, follow the format and choose the correct action type and parameters: -ACTION_SPACE = [ - { - "action_type": "MOVE_TO", - "note": "move the cursor to the specified position", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": False, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": False, - } - } - }, - { - "action_type": "CLICK", - "note": "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position", - "parameters": { - "button": { - "type": str, - "range": ["left", "right", "middle"], - "optional": True, - }, - "x": { - "type": float, - "range": [0, X_MAX], - "optional": True, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": True, - }, - "num_clicks": { - "type": int, - "range": [1, 2, 3], - "optional": True, - }, - } - }, - { - "action_type": "MOUSE_DOWN", - "note": "press the left button if the button not specified, otherwise press the specified button", - "parameters": { - "button": { - "type": str, - "range": ["left", "right", "middle"], - "optional": True, - } - } - }, - { - "action_type": "MOUSE_UP", - "note": "release the left button if the button not specified, otherwise release the specified button", - "parameters": { - "button": { - "type": str, - "range": ["left", "right", "middle"], - "optional": True, - } - } - }, - { - "action_type": "RIGHT_CLICK", - "note": "right click at the current position if x and y are not specified, otherwise right click at the specified position", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": True, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": True, - } - } - }, - { - "action_type": "DOUBLE_CLICK", - "note": "double click at the current position if x and y are not specified, otherwise double click at the specified position", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": True, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": True, - } - } - }, - { - "action_type": "DRAG_TO", - "note": "drag the cursor to the specified position with the left button pressed", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": False, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": False, - } - } - }, - { - "action_type": "SCROLL", - "note": "scroll the mouse wheel up or down", - "parameters": { - "dx": { - "type": int, - "range": None, - "optional": False, - }, - "dy": { - "type": int, - "range": None, - "optional": False, - } - } - }, - { - "action_type": "TYPING", - "note": "type the specified text", - "parameters": { - "text": { - "type": str, - "range": None, - "optional": False, - } - } - }, - { - "action_type": "PRESS", - "note": "press the specified key and release it", - "parameters": { - "key": { - "type": str, - "range": KEYBOARD_KEYS, - "optional": False, - } - } - }, - { - "action_type": "KEY_DOWN", - "note": "press the specified key", - "parameters": { - "key": { - "type": str, - "range": KEYBOARD_KEYS, - "optional": False, - } - } - }, - { - "action_type": "KEY_UP", - "note": "release the specified key", - "parameters": { - "key": { - "type": str, - "range": KEYBOARD_KEYS, - "optional": False, - } - } - }, - { - "action_type": "HOTKEY", - "note": "press the specified key combination", - "parameters": { - "keys": { - "type": list, - "range": [KEYBOARD_KEYS], - "optional": False, - } - } - }, - ############################################################################################################ - { - "action_type": "WAIT", - "note": "wait until the next action", - }, - { - "action_type": "FAIL", - "note": "decide the task can not be performed", - }, - { - "action_type": "DONE", - "note": "decide the task is done", - } -] -Firstly you need to predict the class of your action, then you need to predict the parameters of your action: -- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor, the left top corner of the screen is (0, 0), the right bottom corner of the screen is (1920, 1080) -for example, format as: -``` -{ - "action_type": "MOUSE_MOVE", - "x": 1319.11, - "y": 65.06 -} -``` -- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse: -for example, format as: -``` -{ - "action_type": "CLICK", - "click_type": "LEFT" -} -``` -- For [KEY, KEY_DOWN, KEY_UP], you need to choose a(multiple) key(s) from the keyboard -for example, format as: -``` -{ - "action_type": "KEY", - "key": "ctrl+c" -} -``` -- For TYPE, you need to specify the text you want to type -for example, format as: -``` -{ - "action_type": "TYPE", - "text": "hello world" -} -``` - -REMEMBER: -For every step, you should only RETURN ME THE action_type AND parameters I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. -You MUST wrap the dict with backticks (\`). -You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty. -You CAN predict multiple actions at one step, but you should only return one action for each step. -""" \ No newline at end of file diff --git a/mm_agents/gpt_4_prompt_code.py b/mm_agents/gpt_4_prompt_code.py deleted file mode 100644 index 25e4083..0000000 --- a/mm_agents/gpt_4_prompt_code.py +++ /dev/null @@ -1,18 +0,0 @@ -SYS_PROMPT = """ -You are an agent which follow my instruction and perform desktop computer tasks as instructed. -You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. -For each step, you will get an observation of the desktop by the XML format of accessibility tree, which is based on AT-SPI library. And you will predict the action of the computer based on the accessibility tree. - -You are required to use `pyautogui` to perform the action. -Return one line or multiple lines of python code to perform the action each time, be time efficient. -You ONLY need to return the code inside a code block, like this: -```python -# your code here -``` -Specially, it is also allowed to return the following special code: -When you think you have to wait for some time, return ```WAIT```; -When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task; -When you think the task is done, return ```DONE```. - -First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. -""" \ No newline at end of file diff --git a/mm_agents/gpt_4v_agent.py b/mm_agents/gpt_4v_agent.py index 0dc3cb1..6e2000c 100644 --- a/mm_agents/gpt_4v_agent.py +++ b/mm_agents/gpt_4v_agent.py @@ -1,14 +1,27 @@ import base64 import json +import os import re import time +import uuid from typing import Dict, List +import backoff import requests +from openai.error import ( + APIConnectionError, + APIError, + RateLimitError, + ServiceUnavailableError, + InvalidRequestError +) -from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes -from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION -from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE +from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes +from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ + SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ + SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ + SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ + SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT # Function to encode the image @@ -17,6 +30,35 @@ def encode_image(image_path): return base64.b64encode(image_file.read()).decode('utf-8') +def linearize_accessibility_tree(accessibility_tree): + leaf_nodes = find_leaf_nodes(accessibility_tree) + filtered_nodes = filter_nodes(leaf_nodes) + + linearized_accessibility_tree = "tag\ttext\tposition\tsize\n" + # Linearize the accessibility tree nodes into a table format + + for node in filtered_nodes: + linearized_accessibility_tree += node.tag + "\t" + linearized_accessibility_tree += node.attrib.get('name') + "\t" + linearized_accessibility_tree += node.attrib.get( + '{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t" + linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n" + + return linearized_accessibility_tree + + +def tag_screenshot(screenshot, accessibility_tree): + # Creat a tmp file to store the screenshot in random name + uuid_str = str(uuid.uuid4()) + os.makedirs("tmp/images", exist_ok=True) + tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") + nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) + # Make tag screenshot + marks = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) + + return marks, tagged_screenshot_file_path + + def parse_actions_from_string(input_string): # Search for a JSON string within the input string actions = [] @@ -61,124 +103,295 @@ def parse_code_from_string(input_string): # so the code inside backticks can span multiple lines. # matches now contains all the captured code snippets - return matches + + codes = [] + + for match in matches: + match = match.strip() + commands = ['WAIT', 'DONE', 'FAIL'] # fixme: updates this part when we have more commands + + if match in commands: + codes.append(match.strip()) + elif match.split('\n')[-1] in commands: + if len(match.split('\n')) > 1: + codes.append("\n".join(match.split('\n')[:-1])) + codes.append(match.split('\n')[-1]) + else: + codes.append(match) + + return codes + + +def parse_code_from_som_string(input_string, masks): + for i, mask in enumerate(masks): + x, y, w, h = mask + input_string = input_string.replace("tag#" + str(i), "{}, {}".format(int(x + w // 2), int(y + h // 2))) + + return parse_code_from_string(input_string) class GPT4v_Agent: - def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13", add_a11y_tree=False): + def __init__( + self, + api_key, + instruction, + model="gpt-4-vision-preview", + max_tokens=300, + action_space="computer_13", + exp="screenshot_a11y_tree" + # exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"] + ): + self.instruction = instruction self.model = model self.max_tokens = max_tokens self.action_space = action_space - self.add_a11y_tree = add_a11y_tree + self.exp = exp self.headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}" } - self.trajectory = [ - { - "role": "system", - "content": [ - { - "type": "text", - "text": { - "computer_13": SYS_PROMPT_ACTION, - "pyautogui": SYS_PROMPT_CODE - }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction) - }, - ] - } - ] + self.actions = [] + self.observations = [] + + if exp == "screenshot": + if action_space == "computer_13": + self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION + elif action_space == "pyautogui": + self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE + else: + raise ValueError("Invalid action space: " + action_space) + elif exp == "a11y_tree": + if action_space == "computer_13": + self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION + elif action_space == "pyautogui": + self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE + else: + raise ValueError("Invalid action space: " + action_space) + elif exp == "both": + if action_space == "computer_13": + self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION + elif action_space == "pyautogui": + self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE + else: + raise ValueError("Invalid action space: " + action_space) + elif exp == "som": + if action_space == "computer_13": + raise ValueError("Invalid action space: " + action_space) + elif action_space == "pyautogui": + self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG + else: + raise ValueError("Invalid action space: " + action_space) + elif exp == "seeact": + if action_space == "computer_13": + raise ValueError("Invalid action space: " + action_space) + elif action_space == "pyautogui": + self.system_message = SYS_PROMPT_SEEACT + else: + raise ValueError("Invalid action space: " + action_space) + else: + raise ValueError("Invalid experiment type: " + exp) + + self.system_message = (self.system_message + + "\nHere is the instruction for the task: {}".format(self.instruction)) def predict(self, obs: Dict) -> List: """ Predict the next action(s) based on the current observation. """ - base64_image = encode_image(obs["screenshot"]) - accessibility_tree = obs["accessibility_tree"] - leaf_nodes = find_leaf_nodes(accessibility_tree) - filtered_nodes = filter_nodes(leaf_nodes) + # Prepare the payload for the API call + messages = [] - linearized_accessibility_tree = "tag\ttext\tposition\tsize\n" - # Linearize the accessibility tree nodes into a table format + if len(self.actions) > 0: + system_message = self.system_message + "\nHere are the actions you have done so far:\n" + "\n->\n".join( + self.actions) + else: + system_message = self.system_message - for node in filtered_nodes: - linearized_accessibility_tree += node.tag + "\t" - linearized_accessibility_tree += node.attrib.get('name') + "\t" - linearized_accessibility_tree += node.attrib.get( - '{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t" - linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n" - - self.trajectory.append({ - "role": "user", + messages.append({ + "role": "system", "content": [ { "type": "text", - "text": "What's the next step that you will do to help with the task?" if not self.add_a11y_tree - else "And given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(linearized_accessibility_tree) + "text": system_message }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}", - "detail": "high" - } - } ] }) - traj_to_show = [] - for i in range(len(self.trajectory)): - traj_to_show.append(self.trajectory[i]["content"][0]["text"]) - if len(self.trajectory[i]["content"]) > 1: - traj_to_show.append("screenshot_obs") + masks = None - print("Trajectory:", traj_to_show) + if self.exp in ["screenshot", "both"]: + base64_image = encode_image(obs["screenshot"]) + linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "Given the screenshot as below. What's the next step that you will do to help with the task?" + if self.exp == "screenshot" + else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( + linearized_accessibility_tree) + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": "high" + } + } + ] + }) + elif self.exp == "a11y_tree": + linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( + linearized_accessibility_tree) + } + ] + }) + elif self.exp == "som": + # Add som to the screenshot + masks, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) - payload = { + base64_image = encode_image(tagged_screenshot) + linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) + + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "Given the info from the tagged screenshot as below:\n{}\nWhat's the next step that you will do to help with the task?".format( + linearized_accessibility_tree) + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": "high" + } + } + ] + }) + elif self.exp == "seeact": + # Add som to the screenshot + masks, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) + + base64_image = encode_image(tagged_screenshot) + linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) + + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree) + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": "high" + } + } + ] + }) + else: + raise ValueError("Invalid experiment type: " + self.exp) + + response = self.call_llm({ "model": self.model, - "messages": self.trajectory, + "messages": messages, "max_tokens": self.max_tokens - } + }) + if self.exp == "seeact": + messages.append({ + "role": "assistant", + "content": [ + { + "type": "text", + "text": response + } + ] + }) + + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "{}\n\nWhat's the next step that you will do to help with the task?".format( + ACTION_GROUNDING_PROMPT_SEEACT) + } + ] + }) + + response = self.call_llm({ + "model": self.model, + "messages": messages, + "max_tokens": self.max_tokens + }) + + try: + actions = self.parse_actions(response, masks) + except Exception as e: + print("Failed to parse action from response", e) + actions = None + + return actions + + @backoff.on_exception( + backoff.expo, + (APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError), + ) + def call_llm(self, payload): while True: try: - response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, - json=payload) + response = requests.post( + "https://api.openai.com/v1/chat/completions", + headers=self.headers, + json=payload + ) break except: print("Failed to generate response, retrying...") time.sleep(5) pass - try: - actions = self.parse_actions(response.json()['choices'][0]['message']['content']) - except: - print("Failed to parse action from response:", response.json()) - actions = None - return actions + return response.json()['choices'][0]['message']['content'] - def parse_actions(self, response: str): - # parse from the response - if self.action_space == "computer_13": - actions = parse_actions_from_string(response) - elif self.action_space == "pyautogui": - actions = parse_code_from_string(response) - else: - raise ValueError("Invalid action space: " + self.action_space) + def parse_actions(self, response: str, masks=None): - # add action into the trajectory - self.trajectory.append({ - "role": "assistant", - "content": [ - { - "type": "text", - "text": response - }, - ] - }) + if self.exp in ["screenshot", "a11y_tree", "both"]: + # parse from the response + if self.action_space == "computer_13": + actions = parse_actions_from_string(response) + elif self.action_space == "pyautogui": + actions = parse_code_from_string(response) + else: + raise ValueError("Invalid action space: " + self.action_space) - return actions + self.actions.append(actions) + + return actions + elif self.exp in ["som", "seeact"]: + # parse from the response + if self.action_space == "computer_13": + raise ValueError("Invalid action space: " + self.action_space) + elif self.action_space == "pyautogui": + actions = parse_code_from_som_string(response, masks) + else: + raise ValueError("Invalid action space: " + self.action_space) + + self.actions.append(actions) + + return actions diff --git a/mm_agents/gpt_4v_prompt_action.py b/mm_agents/gpt_4v_prompt_action.py deleted file mode 100644 index 4323df6..0000000 --- a/mm_agents/gpt_4v_prompt_action.py +++ /dev/null @@ -1,244 +0,0 @@ -SYS_PROMPT = """ -You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection. -For each step, you will get an observation of an image, which is the screenshot of the computer screen. And you will predict the action of the computer based on the image. - -HERE is the description of the action space you need to predict, follow the format and choose the correct action type and parameters: -ACTION_SPACE = [ - { - "action_type": "MOVE_TO", - "note": "move the cursor to the specified position", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": False, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": False, - } - } - }, - { - "action_type": "CLICK", - "note": "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position", - "parameters": { - "button": { - "type": str, - "range": ["left", "right", "middle"], - "optional": True, - }, - "x": { - "type": float, - "range": [0, X_MAX], - "optional": True, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": True, - }, - "num_clicks": { - "type": int, - "range": [1, 2, 3], - "optional": True, - }, - } - }, - { - "action_type": "MOUSE_DOWN", - "note": "press the left button if the button not specified, otherwise press the specified button", - "parameters": { - "button": { - "type": str, - "range": ["left", "right", "middle"], - "optional": True, - } - } - }, - { - "action_type": "MOUSE_UP", - "note": "release the left button if the button not specified, otherwise release the specified button", - "parameters": { - "button": { - "type": str, - "range": ["left", "right", "middle"], - "optional": True, - } - } - }, - { - "action_type": "RIGHT_CLICK", - "note": "right click at the current position if x and y are not specified, otherwise right click at the specified position", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": True, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": True, - } - } - }, - { - "action_type": "DOUBLE_CLICK", - "note": "double click at the current position if x and y are not specified, otherwise double click at the specified position", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": True, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": True, - } - } - }, - { - "action_type": "DRAG_TO", - "note": "drag the cursor to the specified position with the left button pressed", - "parameters": { - "x": { - "type": float, - "range": [0, X_MAX], - "optional": False, - }, - "y": { - "type": float, - "range": [0, Y_MAX], - "optional": False, - } - } - }, - { - "action_type": "SCROLL", - "note": "scroll the mouse wheel up or down", - "parameters": { - "dx": { - "type": int, - "range": None, - "optional": False, - }, - "dy": { - "type": int, - "range": None, - "optional": False, - } - } - }, - { - "action_type": "TYPING", - "note": "type the specified text", - "parameters": { - "text": { - "type": str, - "range": None, - "optional": False, - } - } - }, - { - "action_type": "PRESS", - "note": "press the specified key and release it", - "parameters": { - "key": { - "type": str, - "range": KEYBOARD_KEYS, - "optional": False, - } - } - }, - { - "action_type": "KEY_DOWN", - "note": "press the specified key", - "parameters": { - "key": { - "type": str, - "range": KEYBOARD_KEYS, - "optional": False, - } - } - }, - { - "action_type": "KEY_UP", - "note": "release the specified key", - "parameters": { - "key": { - "type": str, - "range": KEYBOARD_KEYS, - "optional": False, - } - } - }, - { - "action_type": "HOTKEY", - "note": "press the specified key combination", - "parameters": { - "keys": { - "type": list, - "range": [KEYBOARD_KEYS], - "optional": False, - } - } - }, - ############################################################################################################ - { - "action_type": "WAIT", - "note": "wait until the next action", - }, - { - "action_type": "FAIL", - "note": "decide the task can not be performed", - }, - { - "action_type": "DONE", - "note": "decide the task is done", - } -] -Firstly you need to predict the class of your action, then you need to predict the parameters of your action: -- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor, the left top corner of the screen is (0, 0), the right bottom corner of the screen is (1920, 1080) -for example, format as: -``` -{ - "action_type": "MOUSE_MOVE", - "x": 1319.11, - "y": 65.06 -} -``` -- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse: -for example, format as: -``` -{ - "action_type": "CLICK", - "click_type": "LEFT" -} -``` -- For [KEY, KEY_DOWN, KEY_UP], you need to choose a(multiple) key(s) from the keyboard -for example, format as: -``` -{ - "action_type": "KEY", - "key": "ctrl+c" -} -``` -- For TYPE, you need to specify the text you want to type -for example, format as: -``` -{ - "action_type": "TYPE", - "text": "hello world" -} -``` - -REMEMBER: -For every step, you should only RETURN ME THE action_type AND parameters I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. -You MUST wrap the dict with backticks (\`). -You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty. -You CAN predict multiple actions at one step, but you should only return one action for each step. -""" \ No newline at end of file diff --git a/mm_agents/gpt_4v_prompt_code.py b/mm_agents/gpt_4v_prompt_code.py deleted file mode 100644 index 8f256da..0000000 --- a/mm_agents/gpt_4v_prompt_code.py +++ /dev/null @@ -1,18 +0,0 @@ -SYS_PROMPT = """ -You are an agent which follow my instruction and perform desktop computer tasks as instructed. -You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. -For each step, you will get an observation of an image, which is the screenshot of the computer screen and you will predict the action of the computer based on the image. - -You are required to use `pyautogui` to perform the action. -Return one line or multiple lines of python code to perform the action each time, be time efficient. -You ONLY need to return the code inside a code block, like this: -```python -# your code here -``` -Specially, it is also allowed to return the following special code: -When you think you have to wait for some time, return ```WAIT```; -When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task; -When you think the task is done, return ```DONE```. - -First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. -""" \ No newline at end of file diff --git a/mm_agents/prompts.py b/mm_agents/prompts.py new file mode 100644 index 0000000..dcc9a85 --- /dev/null +++ b/mm_agents/prompts.py @@ -0,0 +1,862 @@ +SYS_PROMPT_IN_SCREENSHOT_OUT_CODE = """ +You are an agent which follow my instruction and perform desktop computer tasks as instructed. +You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. +For each step, you will get an observation of an image, which is the screenshot of the computer screen and you will predict the action of the computer based on the image. + +You are required to use `pyautogui` to perform the action. +Return one line or multiple lines of python code to perform the action each time, be time efficient. +You ONLY need to return the code inside a code block, like this: +```python +# your code here +``` +Specially, it is also allowed to return the following special code: +When you think you have to wait for some time, return ```WAIT```; +When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task; +When you think the task is done, return ```DONE```. + +First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +""".strip() + +SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION = """ +You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection. +For each step, you will get an observation of an image, which is the screenshot of the computer screen. And you will predict the action of the computer based on the image. + +HERE is the description of the action space you need to predict, follow the format and choose the correct action type and parameters: +ACTION_SPACE = [ + { + "action_type": "MOVE_TO", + "note": "move the cursor to the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "CLICK", + "note": "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + }, + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + }, + "num_clicks": { + "type": int, + "range": [1, 2, 3], + "optional": True, + }, + } + }, + { + "action_type": "MOUSE_DOWN", + "note": "press the left button if the button not specified, otherwise press the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "MOUSE_UP", + "note": "release the left button if the button not specified, otherwise release the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "RIGHT_CLICK", + "note": "right click at the current position if x and y are not specified, otherwise right click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DOUBLE_CLICK", + "note": "double click at the current position if x and y are not specified, otherwise double click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DRAG_TO", + "note": "drag the cursor to the specified position with the left button pressed", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "SCROLL", + "note": "scroll the mouse wheel up or down", + "parameters": { + "dx": { + "type": int, + "range": None, + "optional": False, + }, + "dy": { + "type": int, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "TYPING", + "note": "type the specified text", + "parameters": { + "text": { + "type": str, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "PRESS", + "note": "press the specified key and release it", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_DOWN", + "note": "press the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_UP", + "note": "release the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "HOTKEY", + "note": "press the specified key combination", + "parameters": { + "keys": { + "type": list, + "range": [KEYBOARD_KEYS], + "optional": False, + } + } + }, + ############################################################################################################ + { + "action_type": "WAIT", + "note": "wait until the next action", + }, + { + "action_type": "FAIL", + "note": "decide the task can not be performed", + }, + { + "action_type": "DONE", + "note": "decide the task is done", + } +] +Firstly you need to predict the class of your action, then you need to predict the parameters of your action: +- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor, the left top corner of the screen is (0, 0), the right bottom corner of the screen is (1920, 1080) +for example, format as: +``` +{ + "action_type": "MOUSE_MOVE", + "x": 1319.11, + "y": 65.06 +} +``` +- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse: +for example, format as: +``` +{ + "action_type": "CLICK", + "click_type": "LEFT" +} +``` +- For [KEY, KEY_DOWN, KEY_UP], you need to choose a(multiple) key(s) from the keyboard +for example, format as: +``` +{ + "action_type": "KEY", + "key": "ctrl+c" +} +``` +- For TYPE, you need to specify the text you want to type +for example, format as: +``` +{ + "action_type": "TYPE", + "text": "hello world" +} +``` + +REMEMBER: +For every step, you should only RETURN ME THE action_type AND parameters I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +You MUST wrap the dict with backticks (\`). +You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty. +You CAN predict multiple actions at one step, but you should only return one action for each step. +""".strip() + +SYS_PROMPT_IN_A11Y_OUT_CODE = """ +You are an agent which follow my instruction and perform desktop computer tasks as instructed. +You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. +For each step, you will get an observation of the desktop by accessibility tree, which is based on AT-SPI library. And you will predict the action of the computer based on the accessibility tree. + +You are required to use `pyautogui` to perform the action. +Return one line or multiple lines of python code to perform the action each time, be time efficient. +You ONLY need to return the code inside a code block, like this: +```python +# your code here +``` +Specially, it is also allowed to return the following special code: +When you think you have to wait for some time, return ```WAIT```; +When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task; +When you think the task is done, return ```DONE```. + +First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +""".strip() + +SYS_PROMPT_IN_A11Y_OUT_ACTION = """ +You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection. +For each step, you will get an observation of the desktop by accessibility tree, which is based on AT-SPI library. And you will predict the action of the computer based on the accessibility tree. + +HERE is the description of the action space you need to predict, follow the format and choose the correct action type and parameters: +ACTION_SPACE = [ + { + "action_type": "MOVE_TO", + "note": "move the cursor to the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "CLICK", + "note": "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + }, + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + }, + "num_clicks": { + "type": int, + "range": [1, 2, 3], + "optional": True, + }, + } + }, + { + "action_type": "MOUSE_DOWN", + "note": "press the left button if the button not specified, otherwise press the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "MOUSE_UP", + "note": "release the left button if the button not specified, otherwise release the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "RIGHT_CLICK", + "note": "right click at the current position if x and y are not specified, otherwise right click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DOUBLE_CLICK", + "note": "double click at the current position if x and y are not specified, otherwise double click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DRAG_TO", + "note": "drag the cursor to the specified position with the left button pressed", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "SCROLL", + "note": "scroll the mouse wheel up or down", + "parameters": { + "dx": { + "type": int, + "range": None, + "optional": False, + }, + "dy": { + "type": int, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "TYPING", + "note": "type the specified text", + "parameters": { + "text": { + "type": str, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "PRESS", + "note": "press the specified key and release it", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_DOWN", + "note": "press the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_UP", + "note": "release the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "HOTKEY", + "note": "press the specified key combination", + "parameters": { + "keys": { + "type": list, + "range": [KEYBOARD_KEYS], + "optional": False, + } + } + }, + ############################################################################################################ + { + "action_type": "WAIT", + "note": "wait until the next action", + }, + { + "action_type": "FAIL", + "note": "decide the task can not be performed", + }, + { + "action_type": "DONE", + "note": "decide the task is done", + } +] +Firstly you need to predict the class of your action, then you need to predict the parameters of your action: +- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor, the left top corner of the screen is (0, 0), the right bottom corner of the screen is (1920, 1080) +for example, format as: +``` +{ + "action_type": "MOUSE_MOVE", + "x": 1319.11, + "y": 65.06 +} +``` +- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse: +for example, format as: +``` +{ + "action_type": "CLICK", + "click_type": "LEFT" +} +``` +- For [KEY, KEY_DOWN, KEY_UP], you need to choose a(multiple) key(s) from the keyboard +for example, format as: +``` +{ + "action_type": "KEY", + "key": "ctrl+c" +} +``` +- For TYPE, you need to specify the text you want to type +for example, format as: +``` +{ + "action_type": "TYPE", + "text": "hello world" +} +``` + +REMEMBER: +For every step, you should only RETURN ME THE action_type AND parameters I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +You MUST wrap the dict with backticks (\`). +You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty. +You CAN predict multiple actions at one step, but you should only return one action for each step. +""".strip() + +SYS_PROMPT_IN_BOTH_OUT_CODE = """ +You are an agent which follow my instruction and perform desktop computer tasks as instructed. +You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. +For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library. +And you will predict the action of the computer based on the screenshot and accessibility tree. + +You are required to use `pyautogui` to perform the action. +Return one line or multiple lines of python code to perform the action each time, be time efficient. +You ONLY need to return the code inside a code block, like this: +```python +# your code here +``` +Specially, it is also allowed to return the following special code: +When you think you have to wait for some time, return ```WAIT```; +When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task; +When you think the task is done, return ```DONE```. + +First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +""".strip() + +SYS_PROMPT_IN_BOTH_OUT_ACTION = """ +You will act as an agent which follow my instruction and perform desktop computer tasks as instructed. You must have good knowledge of computer and good internet connection. +For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library. +And you will predict the action of the computer based on the screenshot and accessibility tree. + +HERE is the description of the action space you need to predict, follow the format and choose the correct action type and parameters: +ACTION_SPACE = [ + { + "action_type": "MOVE_TO", + "note": "move the cursor to the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "CLICK", + "note": "click the left button if the button not specified, otherwise click the specified button; click at the current position if x and y are not specified, otherwise click at the specified position", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + }, + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + }, + "num_clicks": { + "type": int, + "range": [1, 2, 3], + "optional": True, + }, + } + }, + { + "action_type": "MOUSE_DOWN", + "note": "press the left button if the button not specified, otherwise press the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "MOUSE_UP", + "note": "release the left button if the button not specified, otherwise release the specified button", + "parameters": { + "button": { + "type": str, + "range": ["left", "right", "middle"], + "optional": True, + } + } + }, + { + "action_type": "RIGHT_CLICK", + "note": "right click at the current position if x and y are not specified, otherwise right click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DOUBLE_CLICK", + "note": "double click at the current position if x and y are not specified, otherwise double click at the specified position", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": True, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": True, + } + } + }, + { + "action_type": "DRAG_TO", + "note": "drag the cursor to the specified position with the left button pressed", + "parameters": { + "x": { + "type": float, + "range": [0, X_MAX], + "optional": False, + }, + "y": { + "type": float, + "range": [0, Y_MAX], + "optional": False, + } + } + }, + { + "action_type": "SCROLL", + "note": "scroll the mouse wheel up or down", + "parameters": { + "dx": { + "type": int, + "range": None, + "optional": False, + }, + "dy": { + "type": int, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "TYPING", + "note": "type the specified text", + "parameters": { + "text": { + "type": str, + "range": None, + "optional": False, + } + } + }, + { + "action_type": "PRESS", + "note": "press the specified key and release it", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_DOWN", + "note": "press the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "KEY_UP", + "note": "release the specified key", + "parameters": { + "key": { + "type": str, + "range": KEYBOARD_KEYS, + "optional": False, + } + } + }, + { + "action_type": "HOTKEY", + "note": "press the specified key combination", + "parameters": { + "keys": { + "type": list, + "range": [KEYBOARD_KEYS], + "optional": False, + } + } + }, + ############################################################################################################ + { + "action_type": "WAIT", + "note": "wait until the next action", + }, + { + "action_type": "FAIL", + "note": "decide the task can not be performed", + }, + { + "action_type": "DONE", + "note": "decide the task is done", + } +] +Firstly you need to predict the class of your action, then you need to predict the parameters of your action: +- For MOUSE_MOVE, you need to predict the x and y coordinate of the mouse cursor, the left top corner of the screen is (0, 0), the right bottom corner of the screen is (1920, 1080) +for example, format as: +``` +{ + "action_type": "MOUSE_MOVE", + "x": 1319.11, + "y": 65.06 +} +``` +- For [CLICK, MOUSE_DOWN, MOUSE_UP], you need to specify the click_type as well, select from [LEFT, MIDDLE, RIGHT, WHEEL_UP, WHEEL_DOWN], which means you click the left button, middle button, right button, wheel up or wheel down of your mouse: +for example, format as: +``` +{ + "action_type": "CLICK", + "click_type": "LEFT" +} +``` +- For [KEY, KEY_DOWN, KEY_UP], you need to choose a(multiple) key(s) from the keyboard +for example, format as: +``` +{ + "action_type": "KEY", + "key": "ctrl+c" +} +``` +- For TYPE, you need to specify the text you want to type +for example, format as: +``` +{ + "action_type": "TYPE", + "text": "hello world" +} +``` + +REMEMBER: +For every step, you should only RETURN ME THE action_type AND parameters I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +You MUST wrap the dict with backticks (\`). +You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty. +You CAN predict multiple actions at one step, but you should only return one action for each step. +""".strip() + +SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """ +You are an agent which follow my instruction and perform desktop computer tasks as instructed. +You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. +For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library. + +You are required to use `pyautogui` to perform the action. But replace x, y in the code with the tag of the element you want to operate with. such as: +```python +pyautogui.moveTo(tag#3) +pyautogui.click(tag#2) +pyautogui.dragTo(tag#1, button='left') +``` +Return one line or multiple lines of python code to perform the action each time, be time efficient. +You ONLY need to return the code inside a code block, like this: +```python +# your code here +``` +Specially, it is also allowed to return the following special code: +When you think you have to wait for some time, return ```WAIT```; +When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task; +When you think the task is done, return ```DONE```. + +First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +""".strip() + +SYS_PROMPT_SEEACT = """ +You are an agent which follow my instruction and perform desktop computer tasks as instructed. +You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. +For each step, you will get an observation of an image, which is the screenshot of the computer screen and you will predict the action of the computer based on the image. +""".strip() + +ACTION_DESCRIPTION_PROMPT_SEEACT = """ +The text and image shown below is the observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library. +{} + +Follow the following guidance to think step by step before outlining the next action step at the current stage: + +(Current Screenshot Identification) +Firstly, think about what the current screenshot is. + +(Previous Action Analysis) +Secondly, combined with the screenshot, analyze each step of the previous action history and their intention one by one. Particularly, pay more attention to the last step, which may be more related to what you should do now as the next step. + +(Screenshot Details Analysis) +Closely examine the screenshot to check the status of every part of the webpage to understand what you can operate with and what has been set or completed. You should closely examine the screenshot details to see what steps have been completed by previous actions even though you are given the textual previous actions. Because the textual history may not clearly and sufficiently record some effects of previous actions, you should closely evaluate the status of every part of the webpage to understand what you have done. + +(Next Action Based on Screenshot and Analysis) +Then, based on your analysis, in conjunction with human desktop using habits and the logic of app GUI design, decide on the following action. And clearly outline which button in the screenshot users will operate with as the first next target element, its detailed location, and the corresponding operation. +""" + +ACTION_GROUNDING_PROMPT_SEEACT = """ +You are required to use `pyautogui` to perform the action. But replace x, y in the code with the tag of the element you want to operate with. such as: +```python +pyautogui.moveTo(tag#3) +pyautogui.click(tag#2) +pyautogui.dragTo(tag#1, button='left') +``` +Return one line or multiple lines of python code to perform the action each time, be time efficient. +You ONLY need to return the code inside a code block, like this: +```python +# your code here +``` +Specially, it is also allowed to return the following special code: +When you think you have to wait for some time, return ```WAIT```; +When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task; +When you think the task is done, return ```DONE```. + +First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE. +""" diff --git a/mm_agents/sam_test.py b/mm_agents/sam_test.py deleted file mode 100644 index 9d4ce44..0000000 --- a/mm_agents/sam_test.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch -from PIL import Image -import requests -from transformers import SamModel, SamProcessor -import numpy as np -import matplotlib.pyplot as plt -import os -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" - -def show_mask(mask, ax, random_color=False): - if random_color: - color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) - else: - color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) - h, w = mask.shape[-2:] - mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) - ax.imshow(mask_image) - - -def show_box(box, ax): - x0, y0 = box[0], box[1] - w, h = box[2] - box[0], box[3] - box[1] - ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)) - - -def show_boxes_on_image(raw_image, boxes): - plt.figure(figsize=(10, 10)) - plt.imshow(raw_image) - for box in boxes: - show_box(box, plt.gca()) - plt.axis('on') - plt.show() - - -def show_points_on_image(raw_image, input_points, input_labels=None): - plt.figure(figsize=(10, 10)) - plt.imshow(raw_image) - input_points = np.array(input_points) - if input_labels is None: - labels = np.ones_like(input_points[:, 0]) - else: - labels = np.array(input_labels) - show_points(input_points, labels, plt.gca()) - plt.axis('on') - plt.show() - - -def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None): - plt.figure(figsize=(10, 10)) - plt.imshow(raw_image) - input_points = np.array(input_points) - if input_labels is None: - labels = np.ones_like(input_points[:, 0]) - else: - labels = np.array(input_labels) - show_points(input_points, labels, plt.gca()) - for box in boxes: - show_box(box, plt.gca()) - plt.axis('on') - plt.show() - - -def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None): - plt.figure(figsize=(10, 10)) - plt.imshow(raw_image) - input_points = np.array(input_points) - if input_labels is None: - labels = np.ones_like(input_points[:, 0]) - else: - labels = np.array(input_labels) - show_points(input_points, labels, plt.gca()) - for box in boxes: - show_box(box, plt.gca()) - plt.axis('on') - plt.show() - - -def show_points(coords, labels, ax, marker_size=375): - pos_points = coords[labels == 1] - neg_points = coords[labels == 0] - ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', - linewidth=1.25) - ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', - linewidth=1.25) - - -def show_masks_on_image(raw_image, masks, scores): - if len(masks.shape) == 4: - masks = masks.squeeze() - if scores.shape[0] == 1: - scores = scores.squeeze() - - nb_predictions = scores.shape[-1] - fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15)) - - for i, (mask, score) in enumerate(zip(masks, scores)): - mask = mask.cpu().detach() - axes[i].imshow(np.array(raw_image)) - show_mask(mask, axes[i]) - axes[i].title.set_text(f"Mask {i + 1}, Score: {score.item():.3f}") - axes[i].axis("off") - plt.show() - - -device = "cuda" if torch.cuda.is_available() else "cpu" -model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device) -processor = SamProcessor.from_pretrained("facebook/sam-vit-huge") - -img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png" -raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB") - -plt.imshow(raw_image) - -inputs = processor(raw_image, return_tensors="pt").to(device) -with torch.no_grad(): - outputs = model(**inputs) - -masks = processor.image_processor.post_process_masks( - outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() -) - - -scores = outputs.iou_scores -show_masks_on_image(raw_image, masks[0], scores) diff --git a/mm_agents/visualizer.py b/mm_agents/visualizer.py new file mode 100644 index 0000000..bd78a98 --- /dev/null +++ b/mm_agents/visualizer.py @@ -0,0 +1,1405 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import colorsys +import logging +import math +import numpy as np +from enum import Enum, unique +import cv2 +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import pycocotools.mask as mask_util +import torch +from matplotlib.backends.backend_agg import FigureCanvasAgg +from PIL import Image + +from detectron2.data import MetadataCatalog +from detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes +from detectron2.utils.file_io import PathManager + +from detectron2.utils.colormap import random_color +import random + +logger = logging.getLogger(__name__) + +__all__ = ["ColorMode", "VisImage", "Visualizer"] + + +_SMALL_OBJECT_AREA_THRESH = 1000 +_LARGE_MASK_AREA_THRESH = 120000 +_OFF_WHITE = (1.0, 1.0, 240.0 / 255) +_BLACK = (0, 0, 0) +_RED = (1.0, 0, 0) + +_KEYPOINT_THRESHOLD = 0.05 + + +@unique +class ColorMode(Enum): + """ + Enum of different color modes to use for instance visualizations. + """ + + IMAGE = 0 + """ + Picks a random color for every instance and overlay segmentations with low opacity. + """ + SEGMENTATION = 1 + """ + Let instances of the same category have similar colors + (from metadata.thing_colors), and overlay them with + high opacity. This provides more attention on the quality of segmentation. + """ + IMAGE_BW = 2 + """ + Same as IMAGE, but convert all areas without masks to gray-scale. + Only available for drawing per-instance mask predictions. + """ + + +class GenericMask: + """ + Attribute: + polygons (list[ndarray]): list[ndarray]: polygons for this mask. + Each ndarray has format [x, y, x, y, ...] + mask (ndarray): a binary mask + """ + + def __init__(self, mask_or_polygons, height, width): + self._mask = self._polygons = self._has_holes = None + self.height = height + self.width = width + + m = mask_or_polygons + if isinstance(m, dict): + # RLEs + assert "counts" in m and "size" in m + if isinstance(m["counts"], list): # uncompressed RLEs + h, w = m["size"] + assert h == height and w == width + m = mask_util.frPyObjects(m, h, w) + self._mask = mask_util.decode(m)[:, :] + return + + if isinstance(m, list): # list[ndarray] + self._polygons = [np.asarray(x).reshape(-1) for x in m] + return + + if isinstance(m, np.ndarray): # assumed to be a binary mask + assert m.shape[1] != 2, m.shape + assert m.shape == ( + height, + width, + ), f"mask shape: {m.shape}, target dims: {height}, {width}" + self._mask = m.astype("uint8") + return + + raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m))) + + @property + def mask(self): + if self._mask is None: + self._mask = self.polygons_to_mask(self._polygons) + return self._mask + + @property + def polygons(self): + if self._polygons is None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + return self._polygons + + @property + def has_holes(self): + if self._has_holes is None: + if self._mask is not None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + else: + self._has_holes = False # if original format is polygon, does not have holes + return self._has_holes + + def mask_to_polygons(self, mask): + # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level + # hierarchy. External contours (boundary) of the object are placed in hierarchy-1. + # Internal contours (holes) are placed in hierarchy-2. + # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours. + mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr + res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + hierarchy = res[-1] + if hierarchy is None: # empty mask + return [], False + has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0 + res = res[-2] + res = [x.flatten() for x in res] + # These coordinates from OpenCV are integers in range [0, W-1 or H-1]. + # We add 0.5 to turn them into real-value coordinate space. A better solution + # would be to first +0.5 and then dilate the returned polygon by 0.5. + res = [x + 0.5 for x in res if len(x) >= 6] + return res, has_holes + + def polygons_to_mask(self, polygons): + rle = mask_util.frPyObjects(polygons, self.height, self.width) + rle = mask_util.merge(rle) + return mask_util.decode(rle)[:, :] + + def area(self): + return self.mask.sum() + + def bbox(self): + p = mask_util.frPyObjects(self.polygons, self.height, self.width) + p = mask_util.merge(p) + bbox = mask_util.toBbox(p) + bbox[2] += bbox[0] + bbox[3] += bbox[1] + return bbox + + +class _PanopticPrediction: + """ + Unify different panoptic annotation/prediction formats + """ + + def __init__(self, panoptic_seg, segments_info, metadata=None): + if segments_info is None: + assert metadata is not None + # If "segments_info" is None, we assume "panoptic_img" is a + # H*W int32 image storing the panoptic_id in the format of + # category_id * label_divisor + instance_id. We reserve -1 for + # VOID label. + label_divisor = metadata.label_divisor + segments_info = [] + for panoptic_label in np.unique(panoptic_seg.numpy()): + if panoptic_label == -1: + # VOID region. + continue + pred_class = panoptic_label // label_divisor + isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values() + segments_info.append( + { + "id": int(panoptic_label), + "category_id": int(pred_class), + "isthing": bool(isthing), + } + ) + del metadata + + self._seg = panoptic_seg + + self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info + segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True) + areas = areas.numpy() + sorted_idxs = np.argsort(-areas) + self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs] + self._seg_ids = self._seg_ids.tolist() + for sid, area in zip(self._seg_ids, self._seg_areas): + if sid in self._sinfo: + self._sinfo[sid]["area"] = float(area) + + def non_empty_mask(self): + """ + Returns: + (H, W) array, a mask for all pixels that have a prediction + """ + empty_ids = [] + for id in self._seg_ids: + if id not in self._sinfo: + empty_ids.append(id) + if len(empty_ids) == 0: + return np.zeros(self._seg.shape, dtype=np.uint8) + assert ( + len(empty_ids) == 1 + ), ">1 ids corresponds to no labels. This is currently not supported" + return (self._seg != empty_ids[0]).numpy().astype(np.bool) + + def semantic_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or sinfo["isthing"]: + # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions. + continue + yield (self._seg == sid).numpy().astype(np.bool), sinfo + + def instance_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or not sinfo["isthing"]: + continue + mask = (self._seg == sid).numpy().astype(np.bool) + if mask.sum() > 0: + yield mask, sinfo + + +def _create_text_labels(classes, scores, class_names, is_crowd=None): + """ + Args: + classes (list[int] or None): + scores (list[float] or None): + class_names (list[str] or None): + is_crowd (list[bool] or None): + + Returns: + list[str] or None + """ + labels = None + if classes is not None: + if class_names is not None and len(class_names) > 0: + labels = [class_names[i] for i in classes] + else: + labels = [str(i) for i in classes] + if scores is not None: + if labels is None: + labels = ["{:.0f}%".format(s * 100) for s in scores] + else: + labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)] + if labels is not None and is_crowd is not None: + labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)] + return labels + + +class VisImage: + def __init__(self, img, scale=1.0): + """ + Args: + img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255]. + scale (float): scale the input image + """ + self.img = img + self.scale = scale + self.width, self.height = img.shape[1], img.shape[0] + self._setup_figure(img) + + def _setup_figure(self, img): + """ + Args: + Same as in :meth:`__init__()`. + + Returns: + fig (matplotlib.pyplot.figure): top level container for all the image plot elements. + ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system. + """ + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + """ + Args: + img: same as in __init__ + """ + img = img.astype("uint8") + self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") + + def save(self, filepath): + """ + Args: + filepath (str): a string that contains the absolute path, including the file name, where + the visualized image will be saved. + """ + self.fig.savefig(filepath) + + def get_image(self): + """ + Returns: + ndarray: + the visualized image of shape (H, W, 3) (RGB) in uint8 type. + The shape is scaled w.r.t the input image using the given `scale` argument. + """ + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + # buf = io.BytesIO() # works for cairo backend + # canvas.print_rgba(buf) + # width, height = self.width, self.height + # s = buf.getvalue() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + """ + Visualizer that draws data about detection/segmentation on images. + + It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` + that draw primitive objects to images, as well as high-level wrappers like + `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` + that draw composite data in some pre-defined style. + + Note that the exact visualization style for the high-level wrappers are subject to change. + Style such as color, opacity, label contents, visibility of labels, or even the visibility + of objects themselves (e.g. when the object is too small) may change according + to different heuristics, as long as the results still look visually reasonable. + + To obtain a consistent style, you can implement custom drawing functions with the + abovementioned primitive methods instead. If you need more customized visualization + styles, you can process the data yourself following their format documented in + tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not + intend to satisfy everyone's preference on drawing styles. + + This visualizer focuses on high rendering quality rather than performance. It is not + designed to be used for real-time applications. + """ + + # TODO implement a fast, rasterized version using OpenCV + + def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE): + """ + Args: + img_rgb: a numpy array of shape (H, W, C), where H and W correspond to + the height and width of the image respectively. C is the number of + color channels. The image is required to be in RGB format since that + is a requirement of the Matplotlib library. The image is also expected + to be in the range [0, 255]. + metadata (Metadata): dataset metadata (e.g. class names and colors) + instance_mode (ColorMode): defines one of the pre-defined style for drawing + instances on an image. + """ + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + if metadata is None: + metadata = MetadataCatalog.get("__nonexist__") + self.metadata = metadata + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 9 + self._default_font_size = max( + np.sqrt(self.output.height * self.output.width) // 90, 10 // scale + ) + self._default_font_size = 18 + self._instance_mode = instance_mode + self.keypoint_threshold = _KEYPOINT_THRESHOLD + + import matplotlib.colors as mcolors + css4_colors = mcolors.CSS4_COLORS + self.color_proposals = [list(mcolors.hex2color(color)) for color in css4_colors.values()] + + def draw_instance_predictions(self, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. + """ + boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None + labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) + keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None + + keep = (scores > 0.5).cpu() + boxes = boxes[keep] + scores = scores[keep] + classes = np.array(classes) + classes = classes[np.array(keep)] + labels = np.array(labels) + labels = labels[np.array(keep)] + + if predictions.has("pred_masks"): + masks = np.asarray(predictions.pred_masks) + masks = masks[np.array(keep)] + masks = [GenericMask(x, self.output.height, self.output.width) for x in masks] + else: + masks = None + + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + # if self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes + ] + alpha = 0.4 + else: + colors = None + alpha = 0.4 + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image( + self._create_grayscale_image( + (predictions.pred_masks.any(dim=0) > 0).numpy() + if predictions.has("pred_masks") + else None + ) + ) + alpha = 0.3 + + self.overlay_instances( + masks=masks, + boxes=boxes, + labels=labels, + keypoints=keypoints, + assigned_colors=colors, + alpha=alpha, + ) + return self.output + + def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.7): + """ + Draw semantic segmentation predictions/labels. + + Args: + sem_seg (Tensor or ndarray): the segmentation of shape (H, W). + Each value is the integer label of the pixel. + area_threshold (int): segments with less than `area_threshold` are not drawn. + alpha (float): the larger it is, the more opaque the segmentations are. + + Returns: + output (VisImage): image object with visualizations. + """ + if isinstance(sem_seg, torch.Tensor): + sem_seg = sem_seg.numpy() + labels, areas = np.unique(sem_seg, return_counts=True) + sorted_idxs = np.argsort(-areas).tolist() + labels = labels[sorted_idxs] + for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels): + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] + except (AttributeError, IndexError): + mask_color = None + + binary_mask = (sem_seg == label).astype(np.uint8) + text = self.metadata.stuff_classes[label] + self.draw_binary_mask( + binary_mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + return self.output + + def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7): + """ + Draw panoptic prediction annotations or results. + + Args: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each + segment. + segments_info (list[dict] or None): Describe each segment in `panoptic_seg`. + If it is a ``list[dict]``, each dict contains keys "id", "category_id". + If None, category id of each pixel is computed by + ``pixel // metadata.label_divisor``. + area_threshold (int): stuff segments with less than `area_threshold` are not drawn. + + Returns: + output (VisImage): image object with visualizations. + """ + pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask())) + + # draw mask for all semantic segments first i.e. "stuff" + for mask, sinfo in pred.semantic_masks(): + category_idx = sinfo["category_id"] + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] + except AttributeError: + mask_color = None + + text = self.metadata.stuff_classes[category_idx].replace('-other','').replace('-merged','') + self.draw_binary_mask( + mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + + # draw mask for all instances second + all_instances = list(pred.instance_masks()) + if len(all_instances) == 0: + return self.output + masks, sinfo = list(zip(*all_instances)) + category_ids = [x["category_id"] for x in sinfo] + + try: + scores = [x["score"] for x in sinfo] + except KeyError: + scores = None + class_names = [name.replace('-other','').replace('-merged','') for name in self.metadata.thing_classes] + labels = _create_text_labels( + category_ids, scores, class_names, [x.get("iscrowd", 0) for x in sinfo] + ) + + try: + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids + ] + except AttributeError: + colors = None + self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha) + + return self.output + + draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility + + def draw_dataset_dict(self, dic): + """ + Draw annotations/segmentaions in Detectron2 Dataset format. + + Args: + dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format. + + Returns: + output (VisImage): image object with visualizations. + """ + annos = dic.get("annotations", None) + if annos: + if "segmentation" in annos[0]: + masks = [x["segmentation"] for x in annos] + else: + masks = None + if "keypoints" in annos[0]: + keypts = [x["keypoints"] for x in annos] + keypts = np.array(keypts).reshape(len(annos), -1, 3) + else: + keypts = None + + boxes = [ + BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS) + if len(x["bbox"]) == 4 + else x["bbox"] + for x in annos + ] + + colors = None + category_ids = [x["category_id"] for x in annos] + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in category_ids + ] + names = self.metadata.get("thing_classes", None) + labels = _create_text_labels( + category_ids, + scores=None, + class_names=names, + is_crowd=[x.get("iscrowd", 0) for x in annos], + ) + self.overlay_instances( + labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors + ) + + sem_seg = dic.get("sem_seg", None) + if sem_seg is None and "sem_seg_file_name" in dic: + with PathManager.open(dic["sem_seg_file_name"], "rb") as f: + sem_seg = Image.open(f) + sem_seg = np.asarray(sem_seg, dtype="uint8") + if sem_seg is not None: + self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.4) + + pan_seg = dic.get("pan_seg", None) + if pan_seg is None and "pan_seg_file_name" in dic: + with PathManager.open(dic["pan_seg_file_name"], "rb") as f: + pan_seg = Image.open(f) + pan_seg = np.asarray(pan_seg) + from panopticapi.utils import rgb2id + + pan_seg = rgb2id(pan_seg) + if pan_seg is not None: + segments_info = dic["segments_info"] + pan_seg = torch.tensor(pan_seg) + self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.7) + return self.output + + def overlay_instances( + self, + *, + boxes=None, + labels=None, + masks=None, + keypoints=None, + assigned_colors=None, + alpha=0.5, + ): + """ + Args: + boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`, + or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, + or a :class:`RotatedBoxes`, + or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image, + labels (list[str]): the text to be displayed for each instance. + masks (masks-like object): Supported types are: + + * :class:`detectron2.structures.PolygonMasks`, + :class:`detectron2.structures.BitMasks`. + * list[list[ndarray]]: contains the segmentation masks for all objects in one image. + The first level of the list corresponds to individual instances. The second + level to all the polygon that compose the instance, and the third level + to the polygon coordinates. The third level should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + * list[ndarray]: each ndarray is a binary mask of shape (H, W). + * list[dict]: each dict is a COCO-style RLE. + keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), + where the N is the number of instances and K is the number of keypoints. + The last dimension corresponds to (x, y, visibility or score). + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = 0 + if boxes is not None: + boxes = self._convert_boxes(boxes) + num_instances = len(boxes) + if masks is not None: + masks = self._convert_masks(masks) + if num_instances: + assert len(masks) == num_instances + else: + num_instances = len(masks) + if keypoints is not None: + if num_instances: + assert len(keypoints) == num_instances + else: + num_instances = len(keypoints) + keypoints = self._convert_keypoints(keypoints) + if labels is not None: + assert len(labels) == num_instances + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + if boxes is not None and boxes.shape[1] == 5: + return self.overlay_rotated_instances( + boxes=boxes, labels=labels, assigned_colors=assigned_colors + ) + + # Display in largest to smallest order to reduce occlusion. + areas = None + if boxes is not None: + areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) + elif masks is not None: + areas = np.asarray([x.area() for x in masks]) + + if areas is not None: + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] if boxes is not None else None + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None + assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] + keypoints = keypoints[sorted_idxs] if keypoints is not None else None + + for i in range(num_instances): + color = assigned_colors[i] + if boxes is not None: + self.draw_box(boxes[i], edge_color=color) + + if masks is not None: + for segment in masks[i].polygons: + self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) + + if labels is not None: + # first get a box + if boxes is not None: + x0, y0, x1, y1 = boxes[i] + text_pos = (x0, y0) # if drawing boxes, put text on the box corner. + horiz_align = "left" + elif masks is not None: + # skip small mask without polygon + if len(masks[i].polygons) == 0: + continue + + x0, y0, x1, y1 = masks[i].bbox() + + # draw text in the center (defined by median) when box is not drawn + # median is less sensitive to outliers. + text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1] + horiz_align = "center" + else: + continue # drawing the box confidence for keypoints isn't very useful. + # for small objects, draw text at the side to avoid occlusion + instance_area = (y1 - y0) * (x1 - x0) + if ( + instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale + or y1 - y0 < 40 * self.output.scale + ): + if y1 >= self.output.height - 5: + text_pos = (x1, y0) + else: + text_pos = (x0, y1) + + height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width) + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.5 + * self._default_font_size + ) + self.draw_text( + labels[i], + text_pos, + color=lighter_color, + horizontal_alignment=horiz_align, + font_size=font_size, + ) + + # draw keypoints + if keypoints is not None: + for keypoints_per_instance in keypoints: + self.draw_and_connect_keypoints(keypoints_per_instance) + + return self.output + + def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None): + """ + Args: + boxes (ndarray): an Nx5 numpy array of + (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image. + labels (list[str]): the text to be displayed for each instance. + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = len(boxes) + + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + + # Display in largest to smallest order to reduce occlusion. + if boxes is not None: + areas = boxes[:, 2] * boxes[:, 3] + + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + colors = [assigned_colors[idx] for idx in sorted_idxs] + + for i in range(num_instances): + self.draw_rotated_box_with_label( + boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None + ) + + return self.output + + def draw_and_connect_keypoints(self, keypoints): + """ + Draws keypoints of an instance and follows the rules for keypoint connections + to draw lines between appropriate keypoints. This follows color heuristics for + line color. + + Args: + keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints + and the last dimension corresponds to (x, y, probability). + + Returns: + output (VisImage): image object with visualizations. + """ + visible = {} + keypoint_names = self.metadata.get("keypoint_names") + for idx, keypoint in enumerate(keypoints): + + # draw keypoint + x, y, prob = keypoint + if prob > self.keypoint_threshold: + self.draw_circle((x, y), color=_RED) + if keypoint_names: + keypoint_name = keypoint_names[idx] + visible[keypoint_name] = (x, y) + + if self.metadata.get("keypoint_connection_rules"): + for kp0, kp1, color in self.metadata.keypoint_connection_rules: + if kp0 in visible and kp1 in visible: + x0, y0 = visible[kp0] + x1, y1 = visible[kp1] + color = tuple(x / 255.0 for x in color) + self.draw_line([x0, x1], [y0, y1], color=color) + + # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip + # Note that this strategy is specific to person keypoints. + # For other keypoints, it should just do nothing + try: + ls_x, ls_y = visible["left_shoulder"] + rs_x, rs_y = visible["right_shoulder"] + mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2 + except KeyError: + pass + else: + # draw line from nose to mid-shoulder + nose_x, nose_y = visible.get("nose", (None, None)) + if nose_x is not None: + self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED) + + try: + # draw line from mid-shoulder to mid-hip + lh_x, lh_y = visible["left_hip"] + rh_x, rh_y = visible["right_hip"] + except KeyError: + pass + else: + mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2 + self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED) + return self.output + + """ + Primitive drawing functions: + """ + + def draw_text( + self, + text, + position, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + """ + Args: + text (str): class label + position (tuple): a tuple of the x and y coordinates to place text on image. + font_size (int, optional): font of the text. If not provided, a font size + proportional to the image width is calculated and used. + color: color of the text. Refer to `matplotlib.colors` for full list + of formats that are accepted. + horizontal_alignment (str): see `matplotlib.text.Text` + rotation: rotation angle in degrees CCW + + Returns: + output (VisImage): image object with text drawn. + """ + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.15) + color[np.argmax(color)] = max(0.8, np.max(color)) + + def contrasting_color(rgb): + """Returns 'white' or 'black' depending on which color contrasts more with the given RGB value.""" + + # Decompose the RGB tuple + R, G, B = rgb + + # Calculate the Y value + Y = 0.299 * R + 0.587 * G + 0.114 * B + + # If Y value is greater than 128, it's closer to white so return black. Otherwise, return white. + return 'black' if Y > 128 else 'white' + + bbox_background = contrasting_color(color*255) + + x, y = position + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + family="sans-serif", + bbox={"facecolor": bbox_background, "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): + """ + Args: + box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0 + are the coordinates of the image's top left corner. x1 and y1 are the + coordinates of the image's bottom right corner. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + + Returns: + output (VisImage): image object with box drawn. + """ + x0, y0, x1, y1 = box_coord + width = x1 - x0 + height = y1 - y0 + + linewidth = max(self._default_font_size / 12, 1) + + self.output.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth * self.output.scale, + alpha=alpha, + linestyle=line_style, + ) + ) + return self.output + + def draw_rotated_box_with_label( + self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None + ): + """ + Draw a rotated box with label on its top-left corner. + + Args: + rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle), + where cnt_x and cnt_y are the center coordinates of the box. + w and h are the width and height of the box. angle represents how + many degrees the box is rotated CCW with regard to the 0-degree box. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + label (string): label for rotated box. It will not be rendered when set to None. + + Returns: + output (VisImage): image object with box drawn. + """ + cnt_x, cnt_y, w, h, angle = rotated_box + area = w * h + # use thinner lines when the box is small + linewidth = self._default_font_size / ( + 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3 + ) + + theta = angle * math.pi / 180.0 + c = math.cos(theta) + s = math.sin(theta) + rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)] + # x: left->right ; y: top->down + rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect] + for k in range(4): + j = (k + 1) % 4 + self.draw_line( + [rotated_rect[k][0], rotated_rect[j][0]], + [rotated_rect[k][1], rotated_rect[j][1]], + color=edge_color, + linestyle="--" if k == 1 else line_style, + linewidth=linewidth, + ) + + if label is not None: + text_pos = rotated_rect[1] # topleft corner + + height_ratio = h / np.sqrt(self.output.height * self.output.width) + label_color = self._change_color_brightness(edge_color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size + ) + self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle) + + return self.output + + def draw_circle(self, circle_coord, color, radius=3): + """ + Args: + circle_coord (list(int) or tuple(int)): contains the x and y coordinates + of the center of the circle. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + radius (int): radius of the circle. + + Returns: + output (VisImage): image object with box drawn. + """ + x, y = circle_coord + self.output.ax.add_patch( + mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color) + ) + return self.output + + def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None): + """ + Args: + x_data (list[int]): a list containing x values of all the points being drawn. + Length of list should match the length of y_data. + y_data (list[int]): a list containing y values of all the points being drawn. + Length of list should match the length of x_data. + color: color of the line. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + linestyle: style of the line. Refer to `matplotlib.lines.Line2D` + for a full list of formats that are accepted. + linewidth (float or None): width of the line. When it's None, + a default value will be computed and used. + + Returns: + output (VisImage): image object with line drawn. + """ + if linewidth is None: + linewidth = self._default_font_size / 3 + linewidth = max(linewidth, 1) + self.output.ax.add_line( + mpl.lines.Line2D( + x_data, + y_data, + linewidth=linewidth * self.output.scale, + color=color, + linestyle=linestyle, + ) + ) + return self.output + + def draw_binary_mask( + self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.7, area_threshold=10 + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + area_threshold (float): a connected component smaller than this area will not be shown. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + has_valid_segment = False + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1])) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha) + else: + # TODO: Use Path/PathPatch to draw vector graphics: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None and has_valid_segment: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_binary_mask_with_number( + self, binary_mask, color=None, *, edge_color=None, text=None, label_mode='1', alpha=0.1, anno_mode=['Mask'], area_threshold=10 + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + area_threshold (float): a connected component smaller than this area will not be shown. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + randint = random.randint(0, len(self.color_proposals)-1) + color = self.color_proposals[randint] + color = mplc.to_rgb(color) + + has_valid_segment = True + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + bbox = mask.bbox() + + if 'Mask' in anno_mode: + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1])) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha) + else: + # TODO: Use Path/PathPatch to draw vector graphics: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if 'Box' in anno_mode: + self.draw_box(bbox, edge_color=color, alpha=0.75) + + if 'Mark' in anno_mode: + has_valid_segment = True + else: + has_valid_segment = False + + if text is not None and has_valid_segment: + # lighter_color = tuple([x*0.2 for x in color]) + lighter_color = [1,1,1] # self._change_color_brightness(color, brightness_factor=0.7) + self._draw_number_in_mask(binary_mask, text, lighter_color, label_mode) + return self.output + + def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5): + """ + Args: + soft_mask (ndarray): float array of shape (H, W), each value in [0, 1]. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + shape2d = (soft_mask.shape[0], soft_mask.shape[1]) + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = soft_mask * alpha + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + binary_mask = (soft_mask > 0.5).astype("uint8") + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): + """ + Args: + segment: numpy array of shape Nx2, containing all the points in the polygon. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. If not provided, a darker shade + of the polygon color will be used instead. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with polygon drawn. + """ + if edge_color is None: + # make edge color darker than the polygon color + if alpha > 0.8: + edge_color = self._change_color_brightness(color, brightness_factor=-0.7) + else: + edge_color = color + edge_color = mplc.to_rgb(edge_color) + (1,) + + polygon = mpl.patches.Polygon( + segment, + fill=True, + facecolor=mplc.to_rgb(color) + (alpha,), + edgecolor=edge_color, + linewidth=max(self._default_font_size // 15 * self.output.scale, 1), + ) + self.output.ax.add_patch(polygon) + return self.output + + """ + Internal methods: + """ + + def _jitter(self, color): + """ + Randomly modifies given color to produce a slightly different color than the color given. + + Args: + color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color + picked. The values in the list are in the [0.0, 1.0] range. + + Returns: + jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the + color after being jittered. The values in the list are in the [0.0, 1.0] range. + """ + color = mplc.to_rgb(color) + # np.random.seed(0) + vec = np.random.rand(3) + # better to do it in another color space + vec = vec / np.linalg.norm(vec) * 0.5 + res = np.clip(vec + color, 0, 1) + return tuple(res) + + def _create_grayscale_image(self, mask=None): + """ + Create a grayscale version of the original image. + The colors in masked area, if given, will be kept. + """ + img_bw = self.img.astype("f4").mean(axis=2) + img_bw = np.stack([img_bw] * 3, axis=2) + if mask is not None: + img_bw[mask] = self.img[mask] + return img_bw + + def _change_color_brightness(self, color, brightness_factor): + """ + Depending on the brightness_factor, gives a lighter or darker color i.e. a color with + less or more saturation than the original color. + + Args: + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of + 0 will correspond to no change, a factor in [-1.0, 0) range will result in + a darker color and a factor in (0, 1.0] range will result in a lighter color. + + Returns: + modified_color (tuple[double]): a tuple containing the RGB values of the + modified color. Each value in the tuple is in the [0.0, 1.0] range. + """ + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) + return modified_color + + def _convert_boxes(self, boxes): + """ + Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension. + """ + if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes): + return boxes.tensor.detach().numpy() + else: + return np.asarray(boxes) + + def _convert_masks(self, masks_or_polygons): + """ + Convert different format of masks or polygons to a tuple of masks and polygons. + + Returns: + list[GenericMask]: + """ + + m = masks_or_polygons + if isinstance(m, PolygonMasks): + m = m.polygons + if isinstance(m, BitMasks): + m = m.tensor.numpy() + if isinstance(m, torch.Tensor): + m = m.numpy() + ret = [] + for x in m: + if isinstance(x, GenericMask): + ret.append(x) + else: + ret.append(GenericMask(x, self.output.height, self.output.width)) + return ret + + def _draw_number_in_mask(self, binary_mask, text, color, label_mode='1'): + """ + Find proper places to draw text given a binary mask. + """ + + def number_to_string(n): + chars = [] + while n: + n, remainder = divmod(n-1, 26) + chars.append(chr(97 + remainder)) + return ''.join(reversed(chars)) + + binary_mask = np.pad(binary_mask, ((1, 1), (1, 1)), 'constant') + mask_dt = cv2.distanceTransform(binary_mask, cv2.DIST_L2, 0) + mask_dt = mask_dt[1:-1, 1:-1] + max_dist = np.max(mask_dt) + coords_y, coords_x = np.where(mask_dt == max_dist) # coords is [y, x] + + if label_mode == 'a': + text = number_to_string(int(text)) + else: + text = text + + self.draw_text(text, (coords_x[len(coords_x)//2] + 2, coords_y[len(coords_y)//2] - 6), color=color) + + # TODO sometimes drawn on wrong objects. the heuristics here can improve. + # _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) + # if stats[1:, -1].size == 0: + # return + # largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # # draw text on the largest component, as well as other very large components. + # for cid in range(1, _num_cc): + # if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # # median is more stable than centroid + # # center = centroids[largest_component_id] + # center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + # # bottom=np.max((cc_labels == cid).nonzero(), axis=1)[::-1] + # # center[1]=bottom[1]+2 + # self.draw_text(text, center, color=color) + + def _draw_text_in_mask(self, binary_mask, text, color): + """ + Find proper places to draw text given a binary mask. + """ + # TODO sometimes drawn on wrong objects. the heuristics here can improve. + _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) + if stats[1:, -1].size == 0: + return + largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # draw text on the largest component, as well as other very large components. + for cid in range(1, _num_cc): + if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # median is more stable than centroid + # center = centroids[largest_component_id] + center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + bottom=np.max((cc_labels == cid).nonzero(), axis=1)[::-1] + center[1]=bottom[1]+2 + self.draw_text(text, center, color=color) + + def _convert_keypoints(self, keypoints): + if isinstance(keypoints, Keypoints): + keypoints = keypoints.tensor + keypoints = np.asarray(keypoints) + return keypoints + + def get_output(self): + """ + Returns: + output (VisImage): the image output containing the visualizations added + to the image. + """ + return self.output \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f2c4cc3..97019b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,3 +32,4 @@ librosa pymupdf chardet playwright +backoff