diff --git a/.vscode/launch.json b/.vscode/launch.json index bc0f472..cf0e7fc 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,8 +11,8 @@ "program": "${file}", "console": "integratedTerminal", "args": [ - "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx", - "--example_time_limit", "60" + "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx" + // "--example_time_limit", "60" ] } ] diff --git a/demo.py b/demo.py deleted file mode 100644 index 736adfe..0000000 --- a/demo.py +++ /dev/null @@ -1,16 +0,0 @@ -import signal -import time - -def handler(signo, frame): - raise RuntimeError("Timeout") - -signal.signal(signal.SIGALRM, handler) - -while True: - try: - signal.alarm(5) # seconds - time.sleep(10) - print("Working...") - except Exception as e : - print(e) - continue \ No newline at end of file diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index 60a4bb4..4159cde 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -263,16 +263,19 @@ class PythonController: """ Ends recording the screen. """ - response = requests.post(self.http_server + "/end_recording") - if response.status_code == 200: - logger.info("Recording stopped successfully") - with open(dest, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - else: - logger.error("Failed to stop recording. Status code: %d", response.status_code) - return None + try: + response = requests.post(self.http_server + "/end_recording") + if response.status_code == 200: + logger.info("Recording stopped successfully") + with open(dest, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + else: + logger.error("Failed to stop recording. Status code: %d", response.status_code) + return None + except Exception as e: + logger.error("An error occurred while trying to download the recording: %s", e) # Additional info def get_vm_platform(self): diff --git a/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json b/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json index 99e148b..fd85e1b 100644 --- a/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json +++ b/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json @@ -9,7 +9,7 @@ "parameters": { "files": [ { - "url": "https://drive.usercontent.google.com/download?id=104pg3yochKyH2Uvlp3BdvKmHgYmSIESu&export=download&authuser=0&confirm=t&uuid=d1926366-4e54-4a44-8dcd-fc49ed6524d7&at=APZUnTXcBFV9kcacsA0toU83lMKJ:1706505549057d", + "url": "https://drive.usercontent.google.com/download?id=1gqqY56robX1tb4YPa3Yk1d72T_k-Rgz3&export=download&authuser=0&confirm=t", "path": "/home/user/Desktop/15-MB-docx-file-download.docx" } ] diff --git a/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json b/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json index 283a3ad..015e3a6 100644 --- a/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json +++ b/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json @@ -1,7 +1,7 @@ { "id": "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f", "snapshot": "gimp", - "instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB. Resize if needed.", + "instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB as \"compressed.jpeg\" on the Desktop. Resize if needed.", "source": "", "config": [ { diff --git a/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json b/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json index ea08560..b591cfd 100644 --- a/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json +++ b/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json @@ -1,13 +1,17 @@ { "id": "e2392362-125e-4f76-a2ee-524b183a3412", "snapshot": "chrome", - "instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to “Test Account” and “Test@gmail.com”.", + "instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to \"Test Account\" and \"Test@gmail.com\".", "source": "authors", "config": [ { "type": "command", "parameters": { - "command": ["mkdir", "-p", "/home/user/Code/Website"] + "command": [ + "mkdir", + "-p", + "/home/user/Code/Website" + ] } }, { @@ -24,13 +28,22 @@ { "type": "execute", "parameters": { - "command": ["tar", "-xJvf", ".tmp.tar.xz", "-C", "/home/user/Code/Website/"] + "command": [ + "tar", + "-xJvf", + ".tmp.tar.xz", + "-C", + "/home/user/Code/Website/" + ] } }, { "type": "launch", "parameters": { - "command": ["google-chrome", "--remote-debugging-port=1337"] + "command": [ + "google-chrome", + "--remote-debugging-port=1337" + ] } }, { @@ -46,14 +59,20 @@ { "type": "chrome_open_tabs", "parameters": { - "urls_to_open": ["https://academicpages.github.io/"] + "urls_to_open": [ + "https://academicpages.github.io/" + ] } } ], "trajectory": "trajectories/e2392362-125e-4f76-a2ee-524b183a3412", - "related_apps": ["chrome", "os", "vscode"], + "related_apps": [ + "chrome", + "os", + "vscode" + ], "evaluator": { - "postconfig":[ + "postconfig": [ { "type": "execute", "parameters": { @@ -66,23 +85,33 @@ } ], "func": "check_json", - "options": {"is_yaml": true}, + "options": { + "is_yaml": true + }, "expected": { "type": "rule", "rules": { "expect": [ { - "key": ["name"], + "key": [ + "name" + ], "method": "eq", "ref": "Test Account" }, { - "key": ["author", "name"], + "key": [ + "author", + "name" + ], "method": "eq", "ref": "Test Account" }, { - "key": ["author", "email"], + "key": [ + "author", + "email" + ], "method": "eq", "ref": "Test@gmail.com" } @@ -95,4 +124,4 @@ "dest": "_config.yaml" } } -} +} \ No newline at end of file diff --git a/evaluation_examples/test_all.json b/evaluation_examples/test_all.json index 0514d47..7153d86 100644 --- a/evaluation_examples/test_all.json +++ b/evaluation_examples/test_all.json @@ -103,7 +103,6 @@ "1e8df695-bd1b-45b3-b557-e7d599cf7597", "ecb0df7a-4e8d-4a03-b162-053391d3afaf", "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14", - "7b802dad-6e0f-4204-9815-d4e3f57627d8", "a01fbce3-2793-461f-ab86-43680ccbae25", "0326d92d-d218-48a8-9ca1-981cd6d064c7", "0a2e43bf-b26c-4631-a966-af9dfa12c9e5", @@ -380,7 +379,6 @@ "9439a27b-18ae-42d8-9778-5f68f891805e", "ae506c68-352c-4094-9caa-ee9d42052317", "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae", - "c714dcee-cad3-4e12-8f3c-12bdcfcdb048", "930fdb3b-11a8-46fe-9bac-577332e2640e", "276cc624-87ea-4f08-ab93-f770e3790175", "9d425400-e9b2-4424-9a4b-d4c7abac4140", diff --git a/evaluation_examples/test_small.json b/evaluation_examples/test_small.json new file mode 100644 index 0000000..4c1feb7 --- /dev/null +++ b/evaluation_examples/test_small.json @@ -0,0 +1,102 @@ +{ + "chrome": [ + "bb5e4c0d-f964-439c-97b6-bdb9747de3f4", + "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3" + ], + "gimp": [ + "7a4deb26-d57d-4ea9-9a73-630f66a7b568", + "554785e9-4523-4e7a-b8e1-8016f565f56a" + ], + "libreoffice_calc": [ + "357ef137-7eeb-4c80-a3bb-0951f26a8aff", + "42e0a640-4f19-4b28-973d-729602b5a4a7" + ], + "libreoffice_impress": [ + "5d901039-a89c-4bfb-967b-bf66f4df075e", + "550ce7e7-747b-495f-b122-acdc4d0b8e54" + ], + "libreoffice_writer": [ + "0810415c-bde4-4443-9047-d5f70165a697", + "0a0faba3-5580-44df-965d-f562a99b291c" + ], + "multi_apps": [ + "2b9493d7-49b8-493a-a71b-56cd1f4d6908", + "46407397-a7d5-4c6b-92c6-dbe038b1457b", + "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc", + "510f64c8-9bcc-4be1-8d30-638705850618", + "897e3b53-5d4d-444b-85cb-2cdc8a97d903", + "c867c42d-a52d-4a24-8ae3-f75d256b5618", + "e135df7c-7687-4ac0-a5f0-76b74438b53e", + "f7dfbef3-7697-431c-883a-db8583a4e4f9", + "6d72aad6-187a-4392-a4c4-ed87269c51cf", + "f918266a-b3e0-4914-865d-4faa564f1aef", + "da52d699-e8d2-4dc5-9191-a2199e0b6a9b", + "74d5859f-ed66-4d3e-aa0e-93d7a592ce41", + "b5062e3e-641c-4e3a-907b-ac864d2e7652", + "48d05431-6cd5-4e76-82eb-12b60d823f7d", + "eb303e01-261e-4972-8c07-c9b4e7a4922a", + "d1acdb87-bb67-4f30-84aa-990e56a09c92", + "deec51c9-3b1e-4b9e-993c-4776f20e8bb2", + "8e116af7-7db7-4e35-a68b-b0939c066c78", + "185f29bd-5da0-40a6-b69c-ba7f4e0324ef", + "2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e", + "3a93cae4-ad3e-403e-8c12-65303b271818", + "1f18aa87-af6f-41ef-9853-cdb8f32ebdea", + "26150609-0da3-4a7d-8868-0faf9c5f01bb", + "7e287123-70ca-47b9-8521-47db09b69b14", + "e2392362-125e-4f76-a2ee-524b183a3412", + "26660ad1-6ebb-4f59-8cba-a8432dfe8d38", + "a82b78bb-7fde-4cb3-94a4-035baf10bcf0", + "36037439-2044-4b50-b9d1-875b5a332143", + "716a6079-22da-47f1-ba73-c9d58f986a38", + "a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a", + "6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a", + "da922383-bfa4-4cd3-bbad-6bebab3d7742", + "2373b66a-092d-44cb-bfd7-82e86e7a3b4d", + "81c425f5-78f3-4771-afd6-3d2973825947", + "227d2f97-562b-4ccb-ae47-a5ec9e142fbb", + "20236825-b5df-46e7-89bf-62e1d640a897", + "02ce9a50-7af2-47ed-8596-af0c230501f8", + "4c26e3f3-3a14-4d86-b44a-d3cedebbb487", + "09a37c51-e625-49f4-a514-20a773797a8a", + "3e3fc409-bff3-4905-bf16-c968eee3f807", + "415ef462-bed3-493a-ac36-ca8c6d23bf1b", + "9f3bb592-209d-43bc-bb47-d77d9df56504", + "dd60633f-2c72-42ba-8547-6f2c8cb0fdb0", + "3f05f3b9-29ba-4b6b-95aa-2204697ffc06", + "f8369178-fafe-40c2-adc4-b9b08a125456", + "778efd0a-153f-4842-9214-f05fc176b877", + "47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5", + "c2751594-0cd5-4088-be1b-b5f2f9ec97c4", + "48c46dc7-fe04-4505-ade7-723cba1aa6f6", + "42d25c08-fb87-4927-8b65-93631280a26f", + "bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108", + "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f", + "d68204bf-11c1-4b13-b48b-d303c73d4bf6", + "91190194-f406-4cd6-b3f9-c43fac942b22", + "7f35355e-02a6-45b5-b140-f0be698bcf85", + "98e8e339-5f91-4ed2-b2b2-12647cb134f4", + "df67aebb-fb3a-44fd-b75b-51b6012df509", + "5df7b33a-9f77-4101-823e-02f863e1c1ae", + "22a4636f-8179-4357-8e87-d1743ece1f81", + "236833a3-5704-47fc-888c-4f298f09f799" + ], + "os": [ + "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57", + "5812b315-e7bd-4265-b51f-863c02174c28", + "43c2d64c-bab5-4dcb-a30c-b888321c319a", + "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82" + ], + "thunderbird": [ + "bb5e4c0d-f964-439c-97b6-bdb9747de3f4", + "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3" + ], + "vlc": [ + "59f21cfb-0120-4326-b255-a5b827b38967", + "8f080098-ddb1-424c-b438-4e96e5e4786e" + ], + "vs_code": [ + "0ed39f63-6049-43d4-ba4d-5fa2fe04a951", + "53ad5833-3455-407b-bbc6-45b4c79ab8fb" + ] +} \ No newline at end of file diff --git a/lib_run_single.py b/lib_run_single.py new file mode 100644 index 0000000..ff9972d --- /dev/null +++ b/lib_run_single.py @@ -0,0 +1,72 @@ +import datetime +import json +import logging +import os +import wandb + +from wrapt_timeout_decorator import * + +logger = logging.getLogger("desktopenv.experiment") + +# Open the JSON file +with open("./settings.json", "r") as file: + # Load the JSON data from the file + data = json.load(file) +time_limit = data["time_limit"] + +@timeout(time_limit, use_signals=False) +def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): + agent.reset() + obs = env.reset(task_config=example) + done = False + step_idx = 0 + env.controller.start_recording() + str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"]) + while not done and step_idx < max_steps: + response, actions = agent.predict( + instruction, + obs + ) + for action in actions: + # Capture the timestamp before executing the action + action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + logger.info("Step %d: %s", step_idx + 1, action) + obs, reward, done, info = env.step(action, args.sleep_after_execution) + + logger.info("Reward: %.2f", reward) + logger.info("Done: %s", done) + # Save screenshot and trajectory information + with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), + "wb") as _f: + with open(obs['screenshot'], "rb") as __f: + screenshot = __f.read() + _f.write(screenshot) + # get a11tree and save to wandb + thisrun_a11tree = env.controller.get_accessibility_tree() + str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"), + thisrun_a11tree, + response, action, action_timestamp, done) + wandb.log({"Reward": reward}) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write(json.dumps({ + "step_num": step_idx + 1, + "action_timestamp": action_timestamp, + "action": action, + "reward": reward, + "done": done, + "info": info, + "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" + })) + f.write("\n") + if done: + logger.info("The episode is done.") + break + step_idx += 1 + wandb.log({"str_trajectory": str_table}) + result = env.evaluate() + logger.info("Result: %.2f", result) + scores.append(result) + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") + env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + wandb.log({"Result": result}) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 039eda8..263e5ee 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -5,19 +5,21 @@ import os import re import time import uuid -import openai import xml.etree.ElementTree as ET from http import HTTPStatus from io import BytesIO from typing import Dict, List -from google.api_core.exceptions import InvalidArgument + import backoff import dashscope import google.generativeai as genai +import openai import requests +import wandb from PIL import Image +from google.api_core.exceptions import InvalidArgument -from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes +from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ @@ -422,7 +424,6 @@ class PromptAgent: # with open("messages.json", "w") as f: # f.write(json.dumps(messages, indent=4)) - logger.info("Generating content with GPT model: %s", self.model) response = self.call_llm({ "model": self.model, "messages": messages, @@ -441,7 +442,7 @@ class PromptAgent: actions = None self.thoughts.append("") - return actions + return response, actions @backoff.on_exception( backoff.expo, @@ -461,7 +462,7 @@ class PromptAgent: "Content-Type": "application/json", "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" } - # logger.info("Generating content with GPT model: %s", self.model) + logger.info("Generating content with GPT model: %s", self.model) response = requests.post( "https://api.openai.com/v1/chat/completions", headers=headers, @@ -488,55 +489,162 @@ class PromptAgent: else: return response.json()['choices'][0]['message']['content'] - # elif self.model.startswith("mistral"): - # print("Call mistral") - # messages = payload["messages"] - # max_tokens = payload["max_tokens"] - # - # misrtal_messages = [] - # - # for i, message in enumerate(messages): - # mistral_message = { - # "role": message["role"], - # "content": [] - # } - # - # for part in message["content"]: - # mistral_message['content'] = part['text'] if part['type'] == "text" else None - # - # misrtal_messages.append(mistral_message) - # - # # the mistral not support system message in our endpoint, so we concatenate it at the first user message - # if misrtal_messages[0]['role'] == "system": - # misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content'] - # misrtal_messages.pop(0) - # - # # openai.api_base = "http://localhost:8000/v1" - # # openai.api_key = "test" - # # response = openai.ChatCompletion.create( - # # messages=misrtal_messages, - # # model="Mixtral-8x7B-Instruct-v0.1" - # # ) - # - # from openai import OpenAI - # TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2" - # - # client = OpenAI(api_key=TOGETHER_API_KEY, - # base_url='https://api.together.xyz', - # ) - # logger.info("Generating content with Mistral model: %s", self.model) - # response = client.chat.completions.create( - # messages=misrtal_messages, - # model="mistralai/Mixtral-8x7B-Instruct-v0.1", - # max_tokens=1024 - # ) - # - # try: - # # return response['choices'][0]['message']['content'] - # return response.choices[0].message.content - # except Exception as e: - # print("Failed to call LLM: " + str(e)) - # return "" + elif self.model.startswith("claude"): + messages = payload["messages"] + max_tokens = payload["max_tokens"] + top_p = payload["top_p"] + temperature = payload["temperature"] + + claude_messages = [] + + for i, message in enumerate(messages): + claude_message = { + "role": message["role"], + "content": [] + } + assert len(message["content"]) in [1, 2], "One text, or one text with one image" + for part in message["content"]: + + if part['type'] == "image_url": + image_source = {} + image_source["type"] = "base64" + image_source["media_type"] = "image/png" + image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "") + claude_message['content'].append({"type": "image", "source": image_source}) + + if part['type'] == "text": + claude_message['content'].append({"type": "text", "text": part['text']}) + + claude_messages.append(claude_message) + + # the claude not support system message in our endpoint, so we concatenate it at the first user message + if claude_messages[0]['role'] == "system": + claude_system_message_item = claude_messages[0]['content'][0] + claude_messages[1]['content'].insert(0, claude_system_message_item) + claude_messages.pop(0) + + headers = { + "x-api-key": os.environ["ANTHROPIC_API_KEY"], + "anthropic-version": "2023-06-01", + "content-type": "application/json" + } + + payload = { + "model": self.model, + "max_tokens": max_tokens, + "messages": claude_messages + } + + response = requests.post( + "https://api.anthropic.com/v1/messages", + headers=headers, + json=payload + ) + + if response.status_code != 200: + + logger.error("Failed to call LLM: " + response.text) + time.sleep(5) + return "" + else: + return response.json()['content'][0]['text'] + + + elif self.model.startswith("mistral"): + print("Call mistral") + messages = payload["messages"] + max_tokens = payload["max_tokens"] + top_p = payload["top_p"] + temperature = payload["temperature"] + + misrtal_messages = [] + + for i, message in enumerate(messages): + mistral_message = { + "role": message["role"], + "content": "" + } + + for part in message["content"]: + mistral_message['content'] = part['text'] if part['type'] == "text" else "" + + + misrtal_messages.append(mistral_message) + + + # openai.api_base = "http://localhost:8000/v1" + # response = openai.ChatCompletion.create( + # messages=misrtal_messages, + # model="Mixtral-8x7B-Instruct-v0.1" + # ) + + from openai import OpenAI + + client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"], + base_url='https://api.together.xyz', + ) + logger.info("Generating content with Mistral model: %s", self.model) + + response = client.chat.completions.create( + messages=misrtal_messages, + model=self.model, + max_tokens=max_tokens + ) + + try: + return response.choices[0].message.content + except Exception as e: + print("Failed to call LLM: " + str(e)) + return "" + + elif self.model.startswith("THUDM"): + # THUDM/cogagent-chat-hf + print("Call CogAgent") + messages = payload["messages"] + max_tokens = payload["max_tokens"] + top_p = payload["top_p"] + temperature = payload["temperature"] + + cog_messages = [] + + for i, message in enumerate(messages): + cog_message = { + "role": message["role"], + "content": [] + } + + for part in message["content"]: + if part['type'] == "image_url": + cog_message['content'].append({"type": "image_url", "image_url": {"url": part['image_url']['url'] } }) + + if part['type'] == "text": + cog_message['content'].append({"type": "text", "text": part['text']}) + + cog_messages.append(cog_message) + + # the cogagent not support system message in our endpoint, so we concatenate it at the first user message + if cog_messages[0]['role'] == "system": + cog_system_message_item = cog_messages[0]['content'][0] + cog_messages[1]['content'].insert(0, cog_system_message_item) + cog_messages.pop(0) + + payload = { + "model": self.model, + "max_tokens": max_tokens, + "messages": cog_messages + } + + base_url = "http://127.0.0.1:8000" + + response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=False) + if response.status_code == 200: + decoded_line = response.json() + content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "") + return content + else: + print("Failed to call LLM: ", response.status_code) + return "" + elif self.model.startswith("gemini"): def encoded_img_to_pil_img(data_str): @@ -612,6 +720,7 @@ class PromptAgent: try: return response.text except Exception as e: + logger.error("Meet exception when calling Gemini API, " + str(e)) return "" elif self.model.startswith("qwen"): messages = payload["messages"] diff --git a/run.py b/run.py index c56e142..4284169 100644 --- a/run.py +++ b/run.py @@ -6,13 +6,17 @@ import datetime import json import logging import os -import signal +import random import sys +import wandb +from tqdm import tqdm + +import lib_run_single from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent -# Logger Configs {{{ # +# Logger Configs {{{ # logger = logging.getLogger() logger.setLevel(logging.DEBUG) @@ -46,13 +50,10 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") - -# make sure each example won't exceed the time limit -def handler(signo, frame): - raise RuntimeError("Time limit exceeded!") - - -signal.signal(signal.SIGALRM, handler) +# wandb config +### set your wandb api key here +os.environ["WANDB_API_KEY"] = "" +wandb.login(key=os.environ["WANDB_API_KEY"]) def config() -> argparse.Namespace: @@ -75,7 +76,7 @@ def config() -> argparse.Namespace: "screenshot_a11y_tree", "som" ], - default="som", + default="a11y_tree", help="Observation type", ) parser.add_argument("--screen_width", type=int, default=1920) @@ -86,10 +87,9 @@ def config() -> argparse.Namespace: # agent config parser.add_argument("--max_trajectory_length", type=int, default=3) parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples") - parser.add_argument("--example_time_limit", type=int, default=600) # lm config - parser.add_argument("--model", type=str, default="gpt-4-vision-preview") + parser.add_argument("--model", type=str, default="gpt-4-0125-preview") parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--max_tokens", type=int, default=1500) @@ -108,10 +108,28 @@ def test( ) -> None: scores = [] max_steps = args.max_steps - time_limit = args.example_time_limit # log args logger.info("Args: %s", args) + # set wandb project + cfg_args = \ + { + "path_to_vm": args.path_to_vm, + "headless": args.headless, + "action_space": args.action_space, + "observation_type": args.observation_type, + "screen_width": args.screen_width, + "screen_height": args.screen_height, + "sleep_after_execution": args.sleep_after_execution, + "max_steps": args.max_steps, + "max_trajectory_length": args.max_trajectory_length, + "model": args.model, + "temperature": args.temperature, + "top_p": args.top_p, + "max_tokens": args.max_tokens, + "stop_token": args.stop_token, + "result_dir": args.result_dir + } agent = PromptAgent( model=args.model, @@ -128,8 +146,10 @@ def test( headless=args.headless, ) - for domain in test_all_meta: - for example_id in test_all_meta[domain]: + for domain in tqdm(test_all_meta, desc="Domain"): + for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): + wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}", + name=f"{example_id}") # example setting config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") with open(config_file, "r", encoding="utf-8") as f: @@ -141,6 +161,10 @@ def test( instruction = example["instruction"] logger.info(f"[Instruction]: {instruction}") + # wandb each example config settings + cfg_args["instruction"] = instruction + cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S") + wandb.config.update(cfg_args) example_result_dir = os.path.join( args.result_dir, @@ -151,79 +175,26 @@ def test( example_id ) os.makedirs(example_result_dir, exist_ok=True) - # example start running try: - signal.alarm(time_limit) - agent.reset() - obs = env.reset(task_config=example) - done = False - step_idx = 0 - env.controller.start_recording() - - while not done and step_idx < max_steps: - actions = agent.predict( - instruction, - obs - ) - for action in actions: - # Capture the timestamp before executing the action - action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") - logger.info("Step %d: %s", step_idx + 1, action) - - obs, reward, done, info = env.step(action, args.sleep_after_execution) - - logger.info("Reward: %.2f", reward) - logger.info("Done: %s", done) - logger.info("Info: %s", info) - - # Save screenshot and trajectory information - with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), - "wb") as _f: - with open(obs['screenshot'], "rb") as __f: - screenshot = __f.read() - _f.write(screenshot) - - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write(json.dumps({ - "step_num": step_idx + 1, - "action_timestamp": action_timestamp, - "action": action, - "reward": reward, - "done": done, - "info": info, - "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" - })) - f.write("\n") - - if done: - logger.info("The episode is done.") - break - step_idx += 1 - - result = env.evaluate() - logger.info("Result: %.2f", result) - scores.append(result) + lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, + scores) + except Exception as e: + logger.error(f"Exception in {domain}/{example_id}: {e}") + wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])}) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) - except RuntimeError as e: - logger.error(f"Error in example {domain}/{example_id}: {e}") - # save info of this example and then continue - try: - env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write(json.dumps({ - "Error": f"Error in example {domain}/{example_id}: {e}", - "step": step_idx + 1, - })) - f.write("\n") - except Exception as new_e: - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write(json.dumps({ - "Error": f"Error in example {domain}/{example_id}: {e} and {new_e}", - "step": "before start recording", - })) - f.write("\n") - continue + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write(json.dumps({ + "Error": f"Time limit exceeded in {domain}/{example_id}" + })) + f.write("\n") + # wandb settings + os.mkdir(os.path.join(wandb.run.dir, "results/")) + for file in os.listdir(example_result_dir): + # move file to just under the root dir + os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}")) + wandb.finish() + env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") @@ -236,9 +207,18 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_ finished = {} for domain in os.listdir(target_dir): + finished[domain] = [] domain_path = os.path.join(target_dir, domain) if os.path.isdir(domain_path): - finished[domain] = os.listdir(domain_path) + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) if not finished: return total_file_json @@ -250,6 +230,35 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_ return total_file_json +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read())) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + if __name__ == '__main__': ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -270,4 +279,10 @@ if __name__ == '__main__': left_info += f"{domain}: {len(test_file_list[domain])}\n" logger.info(f"Left tasks:\n{left_info}") - test(args, test_all_meta) + # get_result(args.action_space, + # args.model, + # args.observation_type, + # args.result_dir, + # test_all_meta + # ) + test(args, test_file_list) diff --git a/settings.json b/settings.json new file mode 100644 index 0000000..23bab77 --- /dev/null +++ b/settings.json @@ -0,0 +1,3 @@ +{ + "time_limit": "10" +} \ No newline at end of file