diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..cf0e7fc --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,19 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File with Arguments", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "args": [ + "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx" + // "--example_time_limit", "60" + ] + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 8eb867f..6262044 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,12 @@ Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw) 2. Install the environment package, download the examples and the virtual machine image. +For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands: ```bash pip install desktop-env gdown xxxx -gdown xxxx +vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui +vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state" ``` ## Quick Start diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index 60a4bb4..4159cde 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -263,16 +263,19 @@ class PythonController: """ Ends recording the screen. """ - response = requests.post(self.http_server + "/end_recording") - if response.status_code == 200: - logger.info("Recording stopped successfully") - with open(dest, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - else: - logger.error("Failed to stop recording. Status code: %d", response.status_code) - return None + try: + response = requests.post(self.http_server + "/end_recording") + if response.status_code == 200: + logger.info("Recording stopped successfully") + with open(dest, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + else: + logger.error("Failed to stop recording. Status code: %d", response.status_code) + return None + except Exception as e: + logger.error("An error occurred while trying to download the recording: %s", e) # Additional info def get_vm_platform(self): diff --git a/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json b/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json index 99e148b..fd85e1b 100644 --- a/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json +++ b/evaluation_examples/examples/multi_apps/2b9493d7-49b8-493a-a71b-56cd1f4d6908.json @@ -9,7 +9,7 @@ "parameters": { "files": [ { - "url": "https://drive.usercontent.google.com/download?id=104pg3yochKyH2Uvlp3BdvKmHgYmSIESu&export=download&authuser=0&confirm=t&uuid=d1926366-4e54-4a44-8dcd-fc49ed6524d7&at=APZUnTXcBFV9kcacsA0toU83lMKJ:1706505549057d", + "url": "https://drive.usercontent.google.com/download?id=1gqqY56robX1tb4YPa3Yk1d72T_k-Rgz3&export=download&authuser=0&confirm=t", "path": "/home/user/Desktop/15-MB-docx-file-download.docx" } ] diff --git a/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json b/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json index 283a3ad..015e3a6 100644 --- a/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json +++ b/evaluation_examples/examples/multi_apps/3c8f201a-009d-4bbe-8b65-a6f8b35bb57f.json @@ -1,7 +1,7 @@ { "id": "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f", "snapshot": "gimp", - "instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB. Resize if needed.", + "instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB as \"compressed.jpeg\" on the Desktop. Resize if needed.", "source": "", "config": [ { diff --git a/evaluation_examples/examples/multi_apps/demo.py b/evaluation_examples/examples/multi_apps/demo.py new file mode 100644 index 0000000..ffa2b85 --- /dev/null +++ b/evaluation_examples/examples/multi_apps/demo.py @@ -0,0 +1,19 @@ +import pandas as pd + +file_path = "/Users/lxc/Downloads/Speedtest.csv" +# 找到csv第二行的第二个数据格里的值 +# with open(file_path, "r") as f: +# for i, line in enumerate(f): +# if i == 1: +# data = line.split(",")[1] +# break +# print(data) + +with open(file_path, "r") as f: + reader = pd.read_csv(f, sep=',', header=None) + # for column in reader.columns: + # if column.startswith("TEST_DATE"): + # data_col = column + # break + for data in reader['TEST_DATE']: + print(data) \ No newline at end of file diff --git a/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json b/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json index ea08560..b591cfd 100644 --- a/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json +++ b/evaluation_examples/examples/multi_apps/e2392362-125e-4f76-a2ee-524b183a3412.json @@ -1,13 +1,17 @@ { "id": "e2392362-125e-4f76-a2ee-524b183a3412", "snapshot": "chrome", - "instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to “Test Account” and “Test@gmail.com”.", + "instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to \"Test Account\" and \"Test@gmail.com\".", "source": "authors", "config": [ { "type": "command", "parameters": { - "command": ["mkdir", "-p", "/home/user/Code/Website"] + "command": [ + "mkdir", + "-p", + "/home/user/Code/Website" + ] } }, { @@ -24,13 +28,22 @@ { "type": "execute", "parameters": { - "command": ["tar", "-xJvf", ".tmp.tar.xz", "-C", "/home/user/Code/Website/"] + "command": [ + "tar", + "-xJvf", + ".tmp.tar.xz", + "-C", + "/home/user/Code/Website/" + ] } }, { "type": "launch", "parameters": { - "command": ["google-chrome", "--remote-debugging-port=1337"] + "command": [ + "google-chrome", + "--remote-debugging-port=1337" + ] } }, { @@ -46,14 +59,20 @@ { "type": "chrome_open_tabs", "parameters": { - "urls_to_open": ["https://academicpages.github.io/"] + "urls_to_open": [ + "https://academicpages.github.io/" + ] } } ], "trajectory": "trajectories/e2392362-125e-4f76-a2ee-524b183a3412", - "related_apps": ["chrome", "os", "vscode"], + "related_apps": [ + "chrome", + "os", + "vscode" + ], "evaluator": { - "postconfig":[ + "postconfig": [ { "type": "execute", "parameters": { @@ -66,23 +85,33 @@ } ], "func": "check_json", - "options": {"is_yaml": true}, + "options": { + "is_yaml": true + }, "expected": { "type": "rule", "rules": { "expect": [ { - "key": ["name"], + "key": [ + "name" + ], "method": "eq", "ref": "Test Account" }, { - "key": ["author", "name"], + "key": [ + "author", + "name" + ], "method": "eq", "ref": "Test Account" }, { - "key": ["author", "email"], + "key": [ + "author", + "email" + ], "method": "eq", "ref": "Test@gmail.com" } @@ -95,4 +124,4 @@ "dest": "_config.yaml" } } -} +} \ No newline at end of file diff --git a/evaluation_examples/test_all.json b/evaluation_examples/test_all.json index 0514d47..7153d86 100644 --- a/evaluation_examples/test_all.json +++ b/evaluation_examples/test_all.json @@ -103,7 +103,6 @@ "1e8df695-bd1b-45b3-b557-e7d599cf7597", "ecb0df7a-4e8d-4a03-b162-053391d3afaf", "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14", - "7b802dad-6e0f-4204-9815-d4e3f57627d8", "a01fbce3-2793-461f-ab86-43680ccbae25", "0326d92d-d218-48a8-9ca1-981cd6d064c7", "0a2e43bf-b26c-4631-a966-af9dfa12c9e5", @@ -380,7 +379,6 @@ "9439a27b-18ae-42d8-9778-5f68f891805e", "ae506c68-352c-4094-9caa-ee9d42052317", "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae", - "c714dcee-cad3-4e12-8f3c-12bdcfcdb048", "930fdb3b-11a8-46fe-9bac-577332e2640e", "276cc624-87ea-4f08-ab93-f770e3790175", "9d425400-e9b2-4424-9a4b-d4c7abac4140", diff --git a/evaluation_examples/test_small.json b/evaluation_examples/test_small.json new file mode 100644 index 0000000..4c1feb7 --- /dev/null +++ b/evaluation_examples/test_small.json @@ -0,0 +1,102 @@ +{ + "chrome": [ + "bb5e4c0d-f964-439c-97b6-bdb9747de3f4", + "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3" + ], + "gimp": [ + "7a4deb26-d57d-4ea9-9a73-630f66a7b568", + "554785e9-4523-4e7a-b8e1-8016f565f56a" + ], + "libreoffice_calc": [ + "357ef137-7eeb-4c80-a3bb-0951f26a8aff", + "42e0a640-4f19-4b28-973d-729602b5a4a7" + ], + "libreoffice_impress": [ + "5d901039-a89c-4bfb-967b-bf66f4df075e", + "550ce7e7-747b-495f-b122-acdc4d0b8e54" + ], + "libreoffice_writer": [ + "0810415c-bde4-4443-9047-d5f70165a697", + "0a0faba3-5580-44df-965d-f562a99b291c" + ], + "multi_apps": [ + "2b9493d7-49b8-493a-a71b-56cd1f4d6908", + "46407397-a7d5-4c6b-92c6-dbe038b1457b", + "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc", + "510f64c8-9bcc-4be1-8d30-638705850618", + "897e3b53-5d4d-444b-85cb-2cdc8a97d903", + "c867c42d-a52d-4a24-8ae3-f75d256b5618", + "e135df7c-7687-4ac0-a5f0-76b74438b53e", + "f7dfbef3-7697-431c-883a-db8583a4e4f9", + "6d72aad6-187a-4392-a4c4-ed87269c51cf", + "f918266a-b3e0-4914-865d-4faa564f1aef", + "da52d699-e8d2-4dc5-9191-a2199e0b6a9b", + "74d5859f-ed66-4d3e-aa0e-93d7a592ce41", + "b5062e3e-641c-4e3a-907b-ac864d2e7652", + "48d05431-6cd5-4e76-82eb-12b60d823f7d", + "eb303e01-261e-4972-8c07-c9b4e7a4922a", + "d1acdb87-bb67-4f30-84aa-990e56a09c92", + "deec51c9-3b1e-4b9e-993c-4776f20e8bb2", + "8e116af7-7db7-4e35-a68b-b0939c066c78", + "185f29bd-5da0-40a6-b69c-ba7f4e0324ef", + "2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e", + "3a93cae4-ad3e-403e-8c12-65303b271818", + "1f18aa87-af6f-41ef-9853-cdb8f32ebdea", + "26150609-0da3-4a7d-8868-0faf9c5f01bb", + "7e287123-70ca-47b9-8521-47db09b69b14", + "e2392362-125e-4f76-a2ee-524b183a3412", + "26660ad1-6ebb-4f59-8cba-a8432dfe8d38", + "a82b78bb-7fde-4cb3-94a4-035baf10bcf0", + "36037439-2044-4b50-b9d1-875b5a332143", + "716a6079-22da-47f1-ba73-c9d58f986a38", + "a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a", + "6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a", + "da922383-bfa4-4cd3-bbad-6bebab3d7742", + "2373b66a-092d-44cb-bfd7-82e86e7a3b4d", + "81c425f5-78f3-4771-afd6-3d2973825947", + "227d2f97-562b-4ccb-ae47-a5ec9e142fbb", + "20236825-b5df-46e7-89bf-62e1d640a897", + "02ce9a50-7af2-47ed-8596-af0c230501f8", + "4c26e3f3-3a14-4d86-b44a-d3cedebbb487", + "09a37c51-e625-49f4-a514-20a773797a8a", + "3e3fc409-bff3-4905-bf16-c968eee3f807", + "415ef462-bed3-493a-ac36-ca8c6d23bf1b", + "9f3bb592-209d-43bc-bb47-d77d9df56504", + "dd60633f-2c72-42ba-8547-6f2c8cb0fdb0", + "3f05f3b9-29ba-4b6b-95aa-2204697ffc06", + "f8369178-fafe-40c2-adc4-b9b08a125456", + "778efd0a-153f-4842-9214-f05fc176b877", + "47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5", + "c2751594-0cd5-4088-be1b-b5f2f9ec97c4", + "48c46dc7-fe04-4505-ade7-723cba1aa6f6", + "42d25c08-fb87-4927-8b65-93631280a26f", + "bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108", + "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f", + "d68204bf-11c1-4b13-b48b-d303c73d4bf6", + "91190194-f406-4cd6-b3f9-c43fac942b22", + "7f35355e-02a6-45b5-b140-f0be698bcf85", + "98e8e339-5f91-4ed2-b2b2-12647cb134f4", + "df67aebb-fb3a-44fd-b75b-51b6012df509", + "5df7b33a-9f77-4101-823e-02f863e1c1ae", + "22a4636f-8179-4357-8e87-d1743ece1f81", + "236833a3-5704-47fc-888c-4f298f09f799" + ], + "os": [ + "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57", + "5812b315-e7bd-4265-b51f-863c02174c28", + "43c2d64c-bab5-4dcb-a30c-b888321c319a", + "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82" + ], + "thunderbird": [ + "bb5e4c0d-f964-439c-97b6-bdb9747de3f4", + "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3" + ], + "vlc": [ + "59f21cfb-0120-4326-b255-a5b827b38967", + "8f080098-ddb1-424c-b438-4e96e5e4786e" + ], + "vs_code": [ + "0ed39f63-6049-43d4-ba4d-5fa2fe04a951", + "53ad5833-3455-407b-bbc6-45b4c79ab8fb" + ] +} \ No newline at end of file diff --git a/lib_run_single.py b/lib_run_single.py new file mode 100644 index 0000000..d60fd7a --- /dev/null +++ b/lib_run_single.py @@ -0,0 +1,72 @@ +import datetime +import json +import logging +import os +import wandb + +from wrapt_timeout_decorator import * + +logger = logging.getLogger("desktopenv.experiment") + +# Open the JSON file +with open("./settings.json", "r") as file: + # Load the JSON data from the file + data = json.load(file) +time_limit = data["time_limit"] + +@timeout(time_limit, use_signals=False) +def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores, run): + agent.reset() + obs = env.reset(task_config=example) + done = False + step_idx = 0 + env.controller.start_recording() + str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"]) + while not done and step_idx < max_steps: + response, actions = agent.predict( + instruction, + obs + ) + for action in actions: + # Capture the timestamp before executing the action + action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + logger.info("Step %d: %s", step_idx + 1, action) + obs, reward, done, info = env.step(action, args.sleep_after_execution) + + logger.info("Reward: %.2f", reward) + logger.info("Done: %s", done) + # Save screenshot and trajectory information + with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), + "wb") as _f: + with open(obs['screenshot'], "rb") as __f: + screenshot = __f.read() + _f.write(screenshot) + # get a11tree and save to wandb + thisrun_a11tree = env.controller.get_accessibility_tree() + str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"), + thisrun_a11tree, + response, action, action_timestamp, done) + run.log({"Reward": reward}) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write(json.dumps({ + "step_num": step_idx + 1, + "action_timestamp": action_timestamp, + "action": action, + "reward": reward, + "done": done, + "info": info, + "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" + })) + f.write("\n") + if done: + logger.info("The episode is done.") + break + step_idx += 1 + run.log({"str_trajectory": str_table}) + result = env.evaluate() + logger.info("Result: %.2f", result) + scores.append(result) + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") + env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + run.log({"Result": result}) diff --git a/main.py b/main.py index bdb2e6a..06debec 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ import logging import os import sys import time - +import argparse from desktop_env.envs.desktop_env import DesktopEnv # Logger Configs {{{ # @@ -46,18 +46,29 @@ def human_agent(): """ Runs the Gym environment with human input. """ + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.") + parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.") + parser.add_argument('-e', '--example', type=str, help="Path to the example json file.") + args = parser.parse_args(sys.argv[1:]) - with open("evaluation_examples/examples/multi_apps/b5062e3e-641c-4e3a-907b-ac864d2e7652.json", "r", encoding="utf-8") as f: + example_path = args.example if args.example is not None and os.path.exists(args.example) else \ + 'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json' + with open(example_path, "r", encoding="utf-8") as f: example = json.load(f) - example["snapshot"] = "Snapshot 35" + if args.snapshot is not None: + example['snapshot'] = args.snapshot - env = DesktopEnv( path_to_vm=r"/mnt/data1/david/os-images/Ubuntu-1218/Ubuntu.vmx" - , snapshot_name="Snapshot 35" - , action_space="computer_13" - ) + assert os.path.exists(args.path), "The specified path to the .vmx file does not exist." + env = DesktopEnv( + path_to_vm=args.path, + snapshot_name=args.snapshot, + action_space="computer_13" + ) # reset the environment to certain snapshot observation = env.reset(task_config=example) done = False + logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"]) trajectory = [ { diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 85db78b..ff92673 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -13,22 +13,16 @@ from typing import Dict, List import backoff import dashscope import google.generativeai as genai +import openai import requests from PIL import Image -from vertexai.preview.generative_models import ( - HarmBlockThreshold, - HarmCategory, - Image, -) +from google.api_core.exceptions import InvalidArgument -from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes +from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ - SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ - SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT - -# todo: cross-check with visualwebarena + SYS_PROMPT_IN_SOM_OUT_TAG logger = logging.getLogger("desktopenv.agent") @@ -43,7 +37,7 @@ def linearize_accessibility_tree(accessibility_tree): # leaf_nodes = find_leaf_nodes(accessibility_tree) filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree)) - linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n" + linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n" # Linearize the accessibility tree nodes into a table format for node in filtered_nodes: @@ -71,7 +65,8 @@ def tag_screenshot(screenshot, accessibility_tree): uuid_str = str(uuid.uuid4()) os.makedirs("tmp/images", exist_ok=True) tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") - nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) + # nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) + nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True) # Make tag screenshot marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) @@ -176,7 +171,7 @@ class PromptAgent: temperature=0.5, action_space="computer_13", observation_type="screenshot_a11y_tree", - # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"] + # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"] max_trajectory_length=3 ): self.model = model @@ -205,7 +200,7 @@ class PromptAgent: self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) - elif observation_type == "both": + elif observation_type == "screenshot_a11y_tree": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION elif action_space == "pyautogui": @@ -216,14 +211,7 @@ class PromptAgent: if action_space == "computer_13": raise ValueError("Invalid action space: " + action_space) elif action_space == "pyautogui": - self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG - else: - raise ValueError("Invalid action space: " + action_space) - elif observation_type == "seeact": - if action_space == "computer_13": - raise ValueError("Invalid action space: " + action_space) - elif action_space == "pyautogui": - self.system_message = SYS_PROMPT_SEEACT + self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG else: raise ValueError("Invalid action space: " + action_space) else: @@ -233,8 +221,7 @@ class PromptAgent: """ Predict the next action(s) based on the current observation. """ - self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format( - instruction) + system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction) # Prepare the payload for the API call messages = [] @@ -245,7 +232,7 @@ class PromptAgent: "content": [ { "type": "text", - "text": self.system_message + "text": system_message }, ] }) @@ -266,7 +253,7 @@ class PromptAgent: for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts): # {{{1 - if self.observation_type == "both": + if self.observation_type == "screenshot_a11y_tree": _screenshot = previous_obs["screenshot"] _linearized_accessibility_tree = previous_obs["accessibility_tree"] logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) @@ -288,18 +275,15 @@ class PromptAgent: } ] }) - elif self.observation_type in ["som", "seeact"]: + elif self.observation_type in ["som"]: _screenshot = previous_obs["screenshot"] - _linearized_accessibility_tree = previous_obs["accessibility_tree"] - logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) messages.append({ "role": "user", "content": [ { "type": "text", - "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( - _linearized_accessibility_tree) + "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?" }, { "type": "image_url", @@ -356,11 +340,11 @@ class PromptAgent: }) # {{{1 - if self.observation_type in ["screenshot", "both"]: + if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: base64_image = encode_image(obs["screenshot"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) - if self.observation_type == "both": + if self.observation_type == "screenshot_a11y_tree": self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree @@ -412,11 +396,9 @@ class PromptAgent: # Add som to the screenshot masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) base64_image = encode_image(tagged_screenshot) - linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ - "screenshot": base64_image, - "accessibility_tree": linearized_accessibility_tree + "screenshot": base64_image }) messages.append({ @@ -424,35 +406,7 @@ class PromptAgent: "content": [ { "type": "text", - "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( - linearized_accessibility_tree) - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{base64_image}", - "detail": "high" - } - } - ] - }) - elif self.observation_type == "seeact": - # Add som to the screenshot - masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) - base64_image = encode_image(tagged_screenshot) - linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) - - self.observations.append({ - "screenshot": base64_image, - "accessibility_tree": linearized_accessibility_tree - }) - - messages.append({ - "role": "user", - "content": [ - { - "type": "text", - "text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree) + "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?" }, { "type": "image_url", @@ -469,60 +423,35 @@ class PromptAgent: # with open("messages.json", "w") as f: # f.write(json.dumps(messages, indent=4)) - logger.info("Generating content with GPT model: %s", self.model) response = self.call_llm({ "model": self.model, "messages": messages, - "max_tokens": self.max_tokens + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature }) logger.info("RESPONSE: %s", response) - if self.observation_type == "seeact": - messages.append({ - "role": "assistant", - "content": [ - { - "type": "text", - "text": response - } - ] - }) - - messages.append({ - "role": "user", - "content": [ - { - "type": "text", - "text": "{}\n\nWhat's the next step that you will do to help with the task?".format( - ACTION_GROUNDING_PROMPT_SEEACT) - } - ] - }) - - logger.info("Generating content with GPT model: %s", self.model) - response = self.call_llm({ - "model": self.model, - "messages": messages, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "temperature": self.temperature - }) - logger.info("RESPONSE: %s", response) - try: actions = self.parse_actions(response, masks) self.thoughts.append(response) - except Exception as e: + except ValueError as e: print("Failed to parse action from response", e) actions = None self.thoughts.append("") - return actions + return response, actions @backoff.on_exception( backoff.expo, - (Exception), + # here you should add more model exceptions as you want, + # but you are forbidden to add "Exception", that is, a common type of exception + # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit + (openai.RateLimitError, + openai.BadRequestError, + openai.InternalServerError, + InvalidArgument), max_tries=5 ) def call_llm(self, payload): @@ -542,14 +471,15 @@ class PromptAgent: if response.status_code != 200: if response.json()['error']['code'] == "context_length_exceeded": logger.error("Context length exceeded. Retrying with a smaller context.") - payload["messages"] = payload["messages"][-1:] + payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:] retry_response = requests.post( "https://api.openai.com/v1/chat/completions", headers=headers, json=payload ) if retry_response.status_code != 200: - logger.error("Failed to call LLM: " + retry_response.text) + logger.error( + "Failed to call LLM even after attempt on shortening the history: " + retry_response.text) return "" logger.error("Failed to call LLM: " + response.text) @@ -558,55 +488,173 @@ class PromptAgent: else: return response.json()['choices'][0]['message']['content'] - # elif self.model.startswith("mistral"): - # print("Call mistral") - # messages = payload["messages"] - # max_tokens = payload["max_tokens"] - # - # misrtal_messages = [] - # - # for i, message in enumerate(messages): - # mistral_message = { - # "role": message["role"], - # "content": [] - # } - # - # for part in message["content"]: - # mistral_message['content'] = part['text'] if part['type'] == "text" else None - # - # misrtal_messages.append(mistral_message) - # - # # the mistral not support system message in our endpoint, so we concatenate it at the first user message - # if misrtal_messages[0]['role'] == "system": - # misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content'] - # misrtal_messages.pop(0) - # - # # openai.api_base = "http://localhost:8000/v1" - # # openai.api_key = "test" - # # response = openai.ChatCompletion.create( - # # messages=misrtal_messages, - # # model="Mixtral-8x7B-Instruct-v0.1" - # # ) - # - # from openai import OpenAI - # TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2" - # - # client = OpenAI(api_key=TOGETHER_API_KEY, - # base_url='https://api.together.xyz', - # ) - # logger.info("Generating content with Mistral model: %s", self.model) - # response = client.chat.completions.create( - # messages=misrtal_messages, - # model="mistralai/Mixtral-8x7B-Instruct-v0.1", - # max_tokens=1024 - # ) - # - # try: - # # return response['choices'][0]['message']['content'] - # return response.choices[0].message.content - # except Exception as e: - # print("Failed to call LLM: " + str(e)) - # return "" + elif self.model.startswith("claude"): + messages = payload["messages"] + max_tokens = payload["max_tokens"] + top_p = payload["top_p"] + temperature = payload["temperature"] + + claude_messages = [] + + for i, message in enumerate(messages): + claude_message = { + "role": message["role"], + "content": [] + } + assert len(message["content"]) in [1, 2], "One text, or one text with one image" + for part in message["content"]: + + if part['type'] == "image_url": + image_source = {} + image_source["type"] = "base64" + image_source["media_type"] = "image/png" + image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "") + claude_message['content'].append({"type": "image", "source": image_source}) + + if part['type'] == "text": + claude_message['content'].append({"type": "text", "text": part['text']}) + + claude_messages.append(claude_message) + + # the claude not support system message in our endpoint, so we concatenate it at the first user message + if claude_messages[0]['role'] == "system": + claude_system_message_item = claude_messages[0]['content'][0] + claude_messages[1]['content'].insert(0, claude_system_message_item) + claude_messages.pop(0) + + # headers = { + # "x-api-key": os.environ["ANTHROPIC_API_KEY"], + # "anthropic-version": "2023-06-01", + # "content-type": "application/json" + # } + + headers = { + "Accept": "application / json", + "Authorization": "Bearer " + os.environ["ANTHROPIC_API_KEY"], + "User-Agent": "Apifox/1.0.0 (https://apifox.com)", + "Content-Type": "application/json" + } + + payload = { + "model": self.model, + "max_tokens": max_tokens, + "messages": claude_messages, + "temperature": temperature, + "top_p": top_p + } + + response = requests.post( + # "https://chat.claude.com/v1/chat/completions", + "https://api.aigcbest.top/v1/chat/completions", + headers=headers, + json=payload + ) + + if response.status_code != 200: + + logger.error("Failed to call LLM: " + response.text) + time.sleep(5) + return "" + # else: + # return response.json()['content'][0]['text'] + else: + return response.json()['choices'][0]['message']['content'] + + + elif self.model.startswith("mistral"): + print("Call mistral") + messages = payload["messages"] + max_tokens = payload["max_tokens"] + top_p = payload["top_p"] + temperature = payload["temperature"] + + misrtal_messages = [] + + for i, message in enumerate(messages): + mistral_message = { + "role": message["role"], + "content": "" + } + + for part in message["content"]: + mistral_message['content'] = part['text'] if part['type'] == "text" else "" + + misrtal_messages.append(mistral_message) + + # openai.api_base = "http://localhost:8000/v1" + # response = openai.ChatCompletion.create( + # messages=misrtal_messages, + # model="Mixtral-8x7B-Instruct-v0.1" + # ) + + from openai import OpenAI + + client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"], + base_url='https://api.together.xyz', + ) + logger.info("Generating content with Mistral model: %s", self.model) + + response = client.chat.completions.create( + messages=misrtal_messages, + model=self.model, + max_tokens=max_tokens + ) + + try: + return response.choices[0].message.content + except Exception as e: + print("Failed to call LLM: " + str(e)) + return "" + + elif self.model.startswith("THUDM"): + # THUDM/cogagent-chat-hf + print("Call CogAgent") + messages = payload["messages"] + max_tokens = payload["max_tokens"] + top_p = payload["top_p"] + temperature = payload["temperature"] + + cog_messages = [] + + for i, message in enumerate(messages): + cog_message = { + "role": message["role"], + "content": [] + } + + for part in message["content"]: + if part['type'] == "image_url": + cog_message['content'].append( + {"type": "image_url", "image_url": {"url": part['image_url']['url']}}) + + if part['type'] == "text": + cog_message['content'].append({"type": "text", "text": part['text']}) + + cog_messages.append(cog_message) + + # the cogagent not support system message in our endpoint, so we concatenate it at the first user message + if cog_messages[0]['role'] == "system": + cog_system_message_item = cog_messages[0]['content'][0] + cog_messages[1]['content'].insert(0, cog_system_message_item) + cog_messages.pop(0) + + payload = { + "model": self.model, + "max_tokens": max_tokens, + "messages": cog_messages + } + + base_url = "http://127.0.0.1:8000" + + response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=False) + if response.status_code == 200: + decoded_line = response.json() + content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "") + return content + else: + print("Failed to call LLM: ", response.status_code) + return "" + elif self.model.startswith("gemini"): def encoded_img_to_pil_img(data_str): @@ -656,8 +704,9 @@ class PromptAgent: for message in gemini_messages: message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n" gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}] + # gemini_messages[-1]['parts'][1].save("output.png", "PNG") - print(gemini_messages) + # print(gemini_messages) api_key = os.environ.get("GENAI_API_KEY") assert api_key is not None, "Please set the GENAI_API_KEY environment variable" genai.configure(api_key=api_key) @@ -671,17 +720,17 @@ class PromptAgent: "temperature": temperature }, safety_settings={ - HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + "harassment": "block_none", + "hate": "block_none", + "sex": "block_none", + "danger": "block_none" } ) try: return response.text except Exception as e: + logger.error("Meet exception when calling Gemini API, " + str(e)) return "" elif self.model.startswith("qwen"): messages = payload["messages"] @@ -726,7 +775,7 @@ class PromptAgent: def parse_actions(self, response: str, masks=None): - if self.observation_type in ["screenshot", "a11y_tree", "both"]: + if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]: # parse from the response if self.action_space == "computer_13": actions = parse_actions_from_string(response) @@ -738,7 +787,7 @@ class PromptAgent: self.actions.append(actions) return actions - elif self.observation_type in ["som", "seeact"]: + elif self.observation_type in ["som"]: # parse from the response if self.action_space == "computer_13": raise ValueError("Invalid action space: " + self.action_space) diff --git a/mm_agents/llm_server/CogAgent/CogAgent.py b/mm_agents/llm_server/CogAgent/CogAgent.py new file mode 100644 index 0000000..1b4cd53 --- /dev/null +++ b/mm_agents/llm_server/CogAgent/CogAgent.py @@ -0,0 +1,405 @@ +import os +import gc +import time +import base64 + +from contextlib import asynccontextmanager +from typing import List, Literal, Union, Tuple, Optional +import torch +import uvicorn +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from loguru import logger +from pydantic import BaseModel, Field +from sse_starlette.sse import EventSourceResponse +from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedModel, PreTrainedTokenizer, \ + TextIteratorStreamer +from PIL import Image +from io import BytesIO + +MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/cogvlm-chat-hf') +TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", 'lmsys/vicuna-7b-v1.5') +DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +if os.environ.get('QUANT_ENABLED'): + QUANT_ENABLED = True +else: + with torch.cuda.device(DEVICE): + __, total_bytes = torch.cuda.mem_get_info() + total_gb = total_bytes / (1 << 30) + if total_gb < 40: + QUANT_ENABLED = True + else: + QUANT_ENABLED = False + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + An asynchronous context manager for managing the lifecycle of the FastAPI app. + It ensures that GPU memory is cleared after the app's lifecycle ends, which is essential for efficient resource management in GPU environments. + """ + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +class ModelCard(BaseModel): + """ + A Pydantic model representing a model card, which provides metadata about a machine learning model. + It includes fields like model ID, owner, and creation time. + """ + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "owner" + root: Optional[str] = None + parent: Optional[str] = None + permission: Optional[list] = None + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = [] + + +class ImageUrl(BaseModel): + url: str + + +class TextContent(BaseModel): + type: Literal["text"] + text: str + + +class ImageUrlContent(BaseModel): + type: Literal["image_url"] + image_url: ImageUrl + + +ContentItem = Union[TextContent, ImageUrlContent] + + +class ChatMessageInput(BaseModel): + role: Literal["user", "assistant", "system"] + content: Union[str, List[ContentItem]] + name: Optional[str] = None + + +class ChatMessageResponse(BaseModel): + role: Literal["assistant"] + content: str = None + name: Optional[str] = None + + +class DeltaMessage(BaseModel): + role: Optional[Literal["user", "assistant", "system"]] = None + content: Optional[str] = None + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[ChatMessageInput] + temperature: Optional[float] = 0.8 + top_p: Optional[float] = 0.8 + max_tokens: Optional[int] = None + stream: Optional[bool] = False + # Additional parameters + repetition_penalty: Optional[float] = 1.0 + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessageResponse + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ChatCompletionResponse(BaseModel): + model: str + object: Literal["chat.completion", "chat.completion.chunk"] + choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] + created: Optional[int] = Field(default_factory=lambda: int(time.time())) + usage: Optional[UsageInfo] = None + + +@app.get("/v1/models", response_model=ModelList) +async def list_models(): + """ + An endpoint to list available models. It returns a list of model cards. + This is useful for clients to query and understand what models are available for use. + """ + model_card = ModelCard(id="cogvlm-chat-17b") # can be replaced by your model id like cogagent-chat-18b + return ModelList(data=[model_card]) + + +@app.post("/v1/chat/completions", response_model=ChatCompletionResponse) +async def create_chat_completion(request: ChatCompletionRequest): + global model, tokenizer + + if len(request.messages) < 1 or request.messages[-1].role == "assistant": + raise HTTPException(status_code=400, detail="Invalid request") + + gen_params = dict( + messages=request.messages, + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_tokens or 1024, + echo=False, + stream=request.stream, + ) + + if request.stream: + generate = predict(request.model, gen_params) + return EventSourceResponse(generate, media_type="text/event-stream") + response = generate_cogvlm(model, tokenizer, gen_params) + + usage = UsageInfo() + + message = ChatMessageResponse( + role="assistant", + content=response["text"], + ) + logger.debug(f"==== message ====\n{message}") + choice_data = ChatCompletionResponseChoice( + index=0, + message=message, + ) + task_usage = UsageInfo.model_validate(response["usage"]) + for usage_key, usage_value in task_usage.model_dump().items(): + setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) + return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage) + + +async def predict(model_id: str, params: dict): + """ + Handle streaming predictions. It continuously generates responses for a given input stream. + This is particularly useful for real-time, continuous interactions with the model. + """ + + global model, tokenizer + + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(role="assistant"), + finish_reason=None + ) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + previous_text = "" + for new_response in generate_stream_cogvlm(model, tokenizer, params): + decoded_unicode = new_response["text"] + delta_text = decoded_unicode[len(previous_text):] + previous_text = decoded_unicode + delta = DeltaMessage( + content=delta_text, + role="assistant", + ) + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=delta, + ) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + choice_data = ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(), + ) + chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") + yield "{}".format(chunk.model_dump_json(exclude_unset=True)) + + +def generate_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict): + """ + Generates a response using the CogVLM model. It processes the chat history and image data, if any, + and then invokes the model to generate a response. + """ + + for response in generate_stream_cogvlm(model, tokenizer, params): + pass + return response + + +def process_history_and_images(messages: List[ChatMessageInput]) -> Tuple[ + Optional[str], Optional[List[Tuple[str, str]]], Optional[List[Image.Image]]]: + """ + Process history messages to extract text, identify the last user query, + and convert base64 encoded image URLs to PIL images. + + Args: + messages(List[ChatMessageInput]): List of ChatMessageInput objects. + return: A tuple of three elements: + - The last user query as a string. + - Text history formatted as a list of tuples for the model. + - List of PIL Image objects extracted from the messages. + """ + formatted_history = [] + image_list = [] + last_user_query = '' + + for i, message in enumerate(messages): + role = message.role + content = message.content + + if isinstance(content, list): # text + text_content = ' '.join(item.text for item in content if isinstance(item, TextContent)) + else: + text_content = content + + if isinstance(content, list): # image + for item in content: + if isinstance(item, ImageUrlContent): + image_url = item.image_url.url + if image_url.startswith("data:image/jpeg;base64,"): + base64_encoded_image = image_url.split("data:image/jpeg;base64,")[1] + image_data = base64.b64decode(base64_encoded_image) + image = Image.open(BytesIO(image_data)).convert('RGB') + image_list.append(image) + elif image_url.startswith("data:image/png;base64,"): + base64_encoded_image = image_url.split("data:image/png;base64,")[1] + image_data = base64.b64decode(base64_encoded_image) + image = Image.open(BytesIO(image_data)).convert('RGB') + image_list.append(image) + + if role == 'user': + if i == len(messages) - 1: # 最后一条用户消息 + last_user_query = text_content + else: + formatted_history.append((text_content, '')) + elif role == 'assistant': + if formatted_history: + if formatted_history[-1][1] != '': + assert False, f"the last query is answered. answer again. {formatted_history[-1][0]}, {formatted_history[-1][1]}, {text_content}" + formatted_history[-1] = (formatted_history[-1][0], text_content) + else: + assert False, f"assistant reply before user" + else: + assert False, f"unrecognized role: {role}" + + return last_user_query, formatted_history, image_list + + +@torch.inference_mode() +def generate_stream_cogvlm(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict): + """ + Generates a stream of responses using the CogVLM model in inference mode. + It's optimized to handle continuous input-output interactions with the model in a streaming manner. + """ + messages = params["messages"] + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = int(params.get("max_tokens", 256)) + query, history, image_list = process_history_and_images(messages) + + logger.debug(f"==== request ====\n{query}") + + input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, + images=[image_list[-1]]) + inputs = { + 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE), + 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE), + 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE), + 'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]], + } + if 'cross_images' in input_by_model and input_by_model['cross_images']: + inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]] + + input_echo_len = len(inputs["input_ids"][0]) + streamer = TextIteratorStreamer( + tokenizer=tokenizer, + timeout=60.0, + skip_prompt=True, + skip_special_tokens=True +) + gen_kwargs = { + "repetition_penalty": repetition_penalty, + "max_new_tokens": max_new_tokens, + "do_sample": True if temperature > 1e-5 else False, + "top_p": top_p if temperature > 1e-5 else 0, + 'streamer': streamer, + } + if temperature > 1e-5: + gen_kwargs["temperature"] = temperature + + total_len = 0 + generated_text = "" + with torch.no_grad(): + model.generate(**inputs, **gen_kwargs) + for next_text in streamer: + generated_text += next_text + yield { + "text": generated_text, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + } + ret = { + "text": generated_text, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": total_len - input_echo_len, + "total_tokens": total_len, + }, + } + yield ret + + +gc.collect() +torch.cuda.empty_cache() + +if __name__ == "__main__": + tokenizer = LlamaTokenizer.from_pretrained( + TOKENIZER_PATH, + trust_remote_code=True) + + if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: + torch_type = torch.bfloat16 + else: + torch_type = torch.float16 + + print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE)) + + if 'cuda' in DEVICE: + if QUANT_ENABLED: + model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + load_in_4bit=True, + trust_remote_code=True, + torch_dtype=torch_type, + low_cpu_mem_usage=True + ).eval() + else: + model = AutoModelForCausalLM.from_pretrained( + MODEL_PATH, + load_in_4bit=False, + trust_remote_code=True, + torch_dtype=torch_type, + low_cpu_mem_usage=True + ).to(DEVICE).eval() + + else: + model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, trust_remote_code=True).float().to(DEVICE).eval() + uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) diff --git a/mm_agents/llm_server/CogAgent/README.md b/mm_agents/llm_server/CogAgent/README.md new file mode 100644 index 0000000..b6f61d2 --- /dev/null +++ b/mm_agents/llm_server/CogAgent/README.md @@ -0,0 +1,7 @@ +## Deploy CogAgent as server + +``` +python CogAgent.py +``` + +The CogAgent LLM will be deployed on http://127.0.0.1:8000 \ No newline at end of file diff --git a/mm_agents/prompts.py b/mm_agents/prompts.py index 15aefeb..462aac7 100644 --- a/mm_agents/prompts.py +++ b/mm_agents/prompts.py @@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti You CAN predict multiple actions at one step, but you should only return one action for each step. """.strip() -SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """ +SYS_PROMPT_IN_SOM_OUT_TAG = """ You are an agent which follow my instruction and perform desktop computer tasks as instructed. You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. -For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library. +For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image. You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot. You can replace x, y in the code with the tag of the element you want to operate with. such as: diff --git a/run.py b/run.py index 908d479..5212bc0 100644 --- a/run.py +++ b/run.py @@ -6,12 +6,17 @@ import datetime import json import logging import os +import random import sys +import wandb +from tqdm import tqdm + +import lib_run_single from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent -# Logger Configs {{{ # +# Logger Configs {{{ # logger = logging.getLogger() logger.setLevel(logging.DEBUG) @@ -45,6 +50,10 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") +# wandb config +### set your wandb api key here +wandb.login(key=os.environ.get("WANDB_API_KEY", None)) + def config() -> argparse.Namespace: parser = argparse.ArgumentParser( @@ -79,7 +88,7 @@ def config() -> argparse.Namespace: parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples") # lm config - parser.add_argument("--model", type=str, default="gpt-4-vision-preview") + parser.add_argument("--model", type=str, default="gpt-4-0125-preview") parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--max_tokens", type=int, default=1500) @@ -101,6 +110,25 @@ def test( # log args logger.info("Args: %s", args) + # set wandb project + cfg_args = \ + { + "path_to_vm": args.path_to_vm, + "headless": args.headless, + "action_space": args.action_space, + "observation_type": args.observation_type, + "screen_width": args.screen_width, + "screen_height": args.screen_height, + "sleep_after_execution": args.sleep_after_execution, + "max_steps": args.max_steps, + "max_trajectory_length": args.max_trajectory_length, + "model": args.model, + "temperature": args.temperature, + "top_p": args.top_p, + "max_tokens": args.max_tokens, + "stop_token": args.stop_token, + "result_dir": args.result_dir + } agent = PromptAgent( model=args.model, @@ -117,8 +145,11 @@ def test( headless=args.headless, ) - for domain in test_all_meta: - for example_id in test_all_meta[domain]: + for domain in tqdm(test_all_meta, desc="Domain"): + for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): + run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}", + name=f"{example_id}") + # example setting config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") with open(config_file, "r", encoding="utf-8") as f: example = json.load(f) @@ -129,6 +160,10 @@ def test( instruction = example["instruction"] logger.info(f"[Instruction]: {instruction}") + # wandb each example config settings + cfg_args["instruction"] = instruction + cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S") + run.config.update(cfg_args) example_result_dir = os.path.join( args.result_dir, @@ -139,69 +174,88 @@ def test( example_id ) os.makedirs(example_result_dir, exist_ok=True) - - agent.reset() - obs = env.reset(task_config=example) - done = False - step_idx = 0 - env.controller.start_recording() - - while not done and step_idx < max_steps: - actions = agent.predict( - instruction, - obs - ) - - for action in actions: - step_idx += 1 - # Capture the timestamp before executing the action - action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") - logger.info("Step %d: %s", step_idx + 1, action) - - observation, reward, done, info = env.step(action, args.sleep_after_execution) - - logger.info("Reward: %.2f", reward) - logger.info("Done: %s", done) - logger.info("Info: %s", info) - - # Save screenshot and trajectory information - with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), - "wb") as _f: - with open(observation['screenshot'], "rb") as __f: - screenshot = __f.read() - _f.write(screenshot) - - with open(os.path.join(example_result_dir, "traj.json"), "a") as f: - f.write(json.dumps({ - "step_num": step_idx + 1, - "action_timestamp": action_timestamp, - "action": action, - "reward": reward, - "done": done, - "info": info, - "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" - })) - f.write("\n") - - if done: - logger.info("The episode is done.") - break - - result = env.evaluate() - logger.info("Result: %.2f", result) - scores.append(result) - env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + # example start running + try: + lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, + scores, run) + except Exception as e: + logger.error(f"Exception in {domain}/{example_id}: {e}") + wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])}) + env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write(json.dumps({ + "Error": f"Time limit exceeded in {domain}/{example_id}" + })) + f.write("\n") + # wandb settings + os.mkdir(os.path.join(wandb.run.dir, "results/")) + for file in os.listdir(example_result_dir): + # move file to just under the root dir + os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}")) + wandb.finish() env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") -def get_unfinished(test_file_list, result_dir): - finished = [] - for domain in os.listdir(result_dir): - for example_id in os.listdir(os.path.join(result_dir, domain)): - finished.append(f"{domain}/{example_id}") - return [x for x in test_file_list if x not in finished] +def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read())) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result if __name__ == '__main__': @@ -209,10 +263,25 @@ if __name__ == '__main__': os.environ["TOKENIZERS_PARALLELISM"] = "false" args = config() - # test_file_list = get_unfinished(args.test, args.result_dir) - # logger.info(f"Total {len(test_file_list)} tasks left") - with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: test_all_meta = json.load(f) - test(args, test_all_meta) + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") + + get_result(args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta + ) + test(args, test_file_list) diff --git a/settings.json b/settings.json new file mode 100644 index 0000000..7ee7a21 --- /dev/null +++ b/settings.json @@ -0,0 +1,3 @@ +{ + "time_limit": "600" +} \ No newline at end of file