Merge branch 'main' of github.com:ztjhz/DesktopEnv
.vscode/launch.json (vendored) | 4
@@ -11,8 +11,8 @@
         "program": "${file}",
         "console": "integratedTerminal",
         "args": [
-            "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx",
-            "--example_time_limit", "60"
+            "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx"
+            // "--example_time_limit", "60"
         ]
     }
 ]
demo.py (deleted) | 16
@@ -1,16 +0,0 @@
-import signal
-import time
-
-def handler(signo, frame):
-    raise RuntimeError("Timeout")
-
-signal.signal(signal.SIGALRM, handler)
-
-while True:
-    try:
-        signal.alarm(5)  # seconds
-        time.sleep(10)
-        print("Working...")
-    except Exception as e:
-        print(e)
-        continue
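Note: demo.py was a scratch test of SIGALRM-based timeouts; signal.alarm only works on the main thread of Unix processes, which is presumably why the commit moves to wrapt_timeout_decorator (see lib_run_single.py below). A minimal sketch of the decorator-based pattern, assuming the library's default behavior of raising TimeoutError on expiry:

    import time
    from wrapt_timeout_decorator import timeout

    @timeout(5, use_signals=False)  # no SIGALRM; the call is run outside the main thread's signal machinery
    def slow_task():
        time.sleep(10)
        return "done"

    try:
        slow_task()
    except TimeoutError as e:
        print(e)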
@@ -263,16 +263,19 @@ class PythonController:
         """
         Ends recording the screen.
         """
-        response = requests.post(self.http_server + "/end_recording")
-        if response.status_code == 200:
-            logger.info("Recording stopped successfully")
-            with open(dest, 'wb') as f:
-                for chunk in response.iter_content(chunk_size=8192):
-                    if chunk:
-                        f.write(chunk)
-        else:
-            logger.error("Failed to stop recording. Status code: %d", response.status_code)
-            return None
+        try:
+            response = requests.post(self.http_server + "/end_recording")
+            if response.status_code == 200:
+                logger.info("Recording stopped successfully")
+                with open(dest, 'wb') as f:
+                    for chunk in response.iter_content(chunk_size=8192):
+                        if chunk:
+                            f.write(chunk)
+            else:
+                logger.error("Failed to stop recording. Status code: %d", response.status_code)
+                return None
+        except Exception as e:
+            logger.error("An error occurred while trying to download the recording: %s", e)

     # Additional info
     def get_vm_platform(self):
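Note: the new try/except keeps a failed download from crashing the run, but requests.post is still called without a timeout, so a hung server can block indefinitely. A sketch of a stricter variant (the timeout value and helper name are assumptions, not part of the commit):

    import requests

    def download_recording(http_server: str, dest: str) -> None:
        # stream=True avoids buffering the whole video in memory;
        # timeout (seconds) bounds both connect and read
        response = requests.post(http_server + "/end_recording", stream=True, timeout=60)
        response.raise_for_status()  # turn HTTP errors into exceptions
        with open(dest, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)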
@@ -9,7 +9,7 @@
     "parameters": {
       "files": [
         {
-          "url": "https://drive.usercontent.google.com/download?id=104pg3yochKyH2Uvlp3BdvKmHgYmSIESu&export=download&authuser=0&confirm=t&uuid=d1926366-4e54-4a44-8dcd-fc49ed6524d7&at=APZUnTXcBFV9kcacsA0toU83lMKJ:1706505549057d",
+          "url": "https://drive.usercontent.google.com/download?id=1gqqY56robX1tb4YPa3Yk1d72T_k-Rgz3&export=download&authuser=0&confirm=t",
           "path": "/home/user/Desktop/15-MB-docx-file-download.docx"
         }
       ]
@@ -1,7 +1,7 @@
 {
   "id": "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
   "snapshot": "gimp",
-  "instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB. Resize if needed.",
+  "instruction": "Download the image from \"https://drive.google.com/uc?export=download&id=1i8j5dGS57sA07jEuPNAlQW-sn5uqUnuK\", and then use GIMP to compress it to under 600KB as \"compressed.jpeg\" on the Desktop. Resize if needed.",
   "source": "",
   "config": [
     {
@@ -1,13 +1,17 @@
 {
   "id": "e2392362-125e-4f76-a2ee-524b183a3412",
   "snapshot": "chrome",
-  "instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to “Test Account” and “Test@gmail.com”.",
+  "instruction": "I recently started using the famous personal academic homepage template from academicpages.github.io to build my own personal homepage, and I have cloned it to my local ~/Code/Website folder. According to an online tutorial, I can configure my name and contact information in the _config.yaml file. However, I am not familiar with the YAML file format. Please help me find the sections related to the name and contact information in this file and change them to \"Test Account\" and \"Test@gmail.com\".",
   "source": "authors",
   "config": [
     {
       "type": "command",
       "parameters": {
-        "command": ["mkdir", "-p", "/home/user/Code/Website"]
+        "command": [
+          "mkdir",
+          "-p",
+          "/home/user/Code/Website"
+        ]
       }
     },
     {
@@ -24,13 +28,22 @@
     {
       "type": "execute",
       "parameters": {
-        "command": ["tar", "-xJvf", ".tmp.tar.xz", "-C", "/home/user/Code/Website/"]
+        "command": [
+          "tar",
+          "-xJvf",
+          ".tmp.tar.xz",
+          "-C",
+          "/home/user/Code/Website/"
+        ]
       }
     },
     {
       "type": "launch",
       "parameters": {
-        "command": ["google-chrome", "--remote-debugging-port=1337"]
+        "command": [
+          "google-chrome",
+          "--remote-debugging-port=1337"
+        ]
       }
     },
     {
@@ -46,14 +59,20 @@
     {
       "type": "chrome_open_tabs",
       "parameters": {
-        "urls_to_open": ["https://academicpages.github.io/"]
+        "urls_to_open": [
+          "https://academicpages.github.io/"
+        ]
       }
     }
   ],
   "trajectory": "trajectories/e2392362-125e-4f76-a2ee-524b183a3412",
-  "related_apps": ["chrome", "os", "vscode"],
+  "related_apps": [
+    "chrome",
+    "os",
+    "vscode"
+  ],
   "evaluator": {
-    "postconfig":[
+    "postconfig": [
       {
         "type": "execute",
         "parameters": {
@@ -66,23 +85,33 @@
         }
       ],
       "func": "check_json",
-      "options": {"is_yaml": true},
+      "options": {
+        "is_yaml": true
+      },
       "expected": {
         "type": "rule",
         "rules": {
           "expect": [
             {
-              "key": ["name"],
+              "key": [
+                "name"
+              ],
               "method": "eq",
               "ref": "Test Account"
             },
             {
-              "key": ["author", "name"],
+              "key": [
+                "author",
+                "name"
+              ],
               "method": "eq",
               "ref": "Test Account"
             },
             {
-              "key": ["author", "email"],
+              "key": [
+                "author",
+                "email"
+              ],
               "method": "eq",
               "ref": "Test@gmail.com"
             }
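Note: each expect rule walks a nested key path into the parsed YAML and compares the value with the given method. A minimal sketch of how "method": "eq" rules like the above could be evaluated (an illustration, not the repository's check_json implementation):

    from functools import reduce

    def check_rule(data: dict, rule: dict) -> bool:
        # follow the key path, e.g. ["author", "email"] -> data["author"]["email"]
        value = reduce(lambda d, k: d[k], rule["key"], data)
        if rule["method"] == "eq":
            return value == rule["ref"]
        raise ValueError(f"unknown method: {rule['method']}")

    config = {"name": "Test Account", "author": {"name": "Test Account", "email": "Test@gmail.com"}}
    assert check_rule(config, {"key": ["author", "email"], "method": "eq", "ref": "Test@gmail.com"})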
@@ -95,4 +124,4 @@
         "dest": "_config.yaml"
       }
     }
   }
 }
@@ -103,7 +103,6 @@
 "1e8df695-bd1b-45b3-b557-e7d599cf7597",
 "ecb0df7a-4e8d-4a03-b162-053391d3afaf",
 "8b1ce5f2-59d2-4dcc-b0b0-666a714b9a14",
 "7b802dad-6e0f-4204-9815-d4e3f57627d8",
 "a01fbce3-2793-461f-ab86-43680ccbae25",
 "0326d92d-d218-48a8-9ca1-981cd6d064c7",
 "0a2e43bf-b26c-4631-a966-af9dfa12c9e5",
@@ -380,7 +379,6 @@
 "9439a27b-18ae-42d8-9778-5f68f891805e",
 "ae506c68-352c-4094-9caa-ee9d42052317",
 "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
 "c714dcee-cad3-4e12-8f3c-12bdcfcdb048",
 "930fdb3b-11a8-46fe-9bac-577332e2640e",
 "276cc624-87ea-4f08-ab93-f770e3790175",
 "9d425400-e9b2-4424-9a4b-d4c7abac4140",
evaluation_examples/test_small.json (new file) | 102
@@ -0,0 +1,102 @@
+{
+  "chrome": [
+    "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
+    "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
+  ],
+  "gimp": [
+    "7a4deb26-d57d-4ea9-9a73-630f66a7b568",
+    "554785e9-4523-4e7a-b8e1-8016f565f56a"
+  ],
+  "libreoffice_calc": [
+    "357ef137-7eeb-4c80-a3bb-0951f26a8aff",
+    "42e0a640-4f19-4b28-973d-729602b5a4a7"
+  ],
+  "libreoffice_impress": [
+    "5d901039-a89c-4bfb-967b-bf66f4df075e",
+    "550ce7e7-747b-495f-b122-acdc4d0b8e54"
+  ],
+  "libreoffice_writer": [
+    "0810415c-bde4-4443-9047-d5f70165a697",
+    "0a0faba3-5580-44df-965d-f562a99b291c"
+  ],
+  "multi_apps": [
+    "2b9493d7-49b8-493a-a71b-56cd1f4d6908",
+    "46407397-a7d5-4c6b-92c6-dbe038b1457b",
+    "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
+    "510f64c8-9bcc-4be1-8d30-638705850618",
+    "897e3b53-5d4d-444b-85cb-2cdc8a97d903",
+    "c867c42d-a52d-4a24-8ae3-f75d256b5618",
+    "e135df7c-7687-4ac0-a5f0-76b74438b53e",
+    "f7dfbef3-7697-431c-883a-db8583a4e4f9",
+    "6d72aad6-187a-4392-a4c4-ed87269c51cf",
+    "f918266a-b3e0-4914-865d-4faa564f1aef",
+    "da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
+    "74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
+    "b5062e3e-641c-4e3a-907b-ac864d2e7652",
+    "48d05431-6cd5-4e76-82eb-12b60d823f7d",
+    "eb303e01-261e-4972-8c07-c9b4e7a4922a",
+    "d1acdb87-bb67-4f30-84aa-990e56a09c92",
+    "deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
+    "8e116af7-7db7-4e35-a68b-b0939c066c78",
+    "185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
+    "2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
+    "3a93cae4-ad3e-403e-8c12-65303b271818",
+    "1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
+    "26150609-0da3-4a7d-8868-0faf9c5f01bb",
+    "7e287123-70ca-47b9-8521-47db09b69b14",
+    "e2392362-125e-4f76-a2ee-524b183a3412",
+    "26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
+    "a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
+    "36037439-2044-4b50-b9d1-875b5a332143",
+    "716a6079-22da-47f1-ba73-c9d58f986a38",
+    "a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
+    "6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
+    "da922383-bfa4-4cd3-bbad-6bebab3d7742",
+    "2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
+    "81c425f5-78f3-4771-afd6-3d2973825947",
+    "227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
+    "20236825-b5df-46e7-89bf-62e1d640a897",
+    "02ce9a50-7af2-47ed-8596-af0c230501f8",
+    "4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
+    "09a37c51-e625-49f4-a514-20a773797a8a",
+    "3e3fc409-bff3-4905-bf16-c968eee3f807",
+    "415ef462-bed3-493a-ac36-ca8c6d23bf1b",
+    "9f3bb592-209d-43bc-bb47-d77d9df56504",
+    "dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
+    "3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
+    "f8369178-fafe-40c2-adc4-b9b08a125456",
+    "778efd0a-153f-4842-9214-f05fc176b877",
+    "47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
+    "c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
+    "48c46dc7-fe04-4505-ade7-723cba1aa6f6",
+    "42d25c08-fb87-4927-8b65-93631280a26f",
+    "bb7db4c2-30b5-4be7-8dd7-b8c4ec7d3108",
+    "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
+    "d68204bf-11c1-4b13-b48b-d303c73d4bf6",
+    "91190194-f406-4cd6-b3f9-c43fac942b22",
+    "7f35355e-02a6-45b5-b140-f0be698bcf85",
+    "98e8e339-5f91-4ed2-b2b2-12647cb134f4",
+    "df67aebb-fb3a-44fd-b75b-51b6012df509",
+    "5df7b33a-9f77-4101-823e-02f863e1c1ae",
+    "22a4636f-8179-4357-8e87-d1743ece1f81",
+    "236833a3-5704-47fc-888c-4f298f09f799"
+  ],
+  "os": [
+    "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
+    "5812b315-e7bd-4265-b51f-863c02174c28",
+    "43c2d64c-bab5-4dcb-a30c-b888321c319a",
+    "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
+  ],
+  "thunderbird": [
+    "bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
+    "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
+  ],
+  "vlc": [
+    "59f21cfb-0120-4326-b255-a5b827b38967",
+    "8f080098-ddb1-424c-b438-4e96e5e4786e"
+  ],
+  "vs_code": [
+    "0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
+    "53ad5833-3455-407b-bbc6-45b4c79ab8fb"
+  ]
+}
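Note: test_small.json maps each app domain to a list of example UUIDs, mirroring the shape passed to test(args, test_file_list) in run.py. A quick sketch for inspecting the subset:

    import json

    with open("evaluation_examples/test_small.json") as f:
        test_small = json.load(f)
    for domain, example_ids in test_small.items():
        print(f"{domain}: {len(example_ids)} examples")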
lib_run_single.py (new file) | 72
@@ -0,0 +1,72 @@
+import datetime
+import json
+import logging
+import os
+import wandb
+
+from wrapt_timeout_decorator import *
+
+logger = logging.getLogger("desktopenv.experiment")
+
+# Open the JSON file
+with open("./settings.json", "r") as file:
+    # Load the JSON data from the file
+    data = json.load(file)
+    time_limit = data["time_limit"]
+
+@timeout(time_limit, use_signals=False)
+def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
+    agent.reset()
+    obs = env.reset(task_config=example)
+    done = False
+    step_idx = 0
+    env.controller.start_recording()
+    str_table = wandb.Table(columns=["Screenshot", "A11T", "Model Response", "Action", "Action timestamp", "Done"])
+    while not done and step_idx < max_steps:
+        response, actions = agent.predict(
+            instruction,
+            obs
+        )
+        for action in actions:
+            # Capture the timestamp before executing the action
+            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+            logger.info("Step %d: %s", step_idx + 1, action)
+            obs, reward, done, info = env.step(action, args.sleep_after_execution)
+
+            logger.info("Reward: %.2f", reward)
+            logger.info("Done: %s", done)
+            # Save screenshot and trajectory information
+            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
+                      "wb") as _f:
+                with open(obs['screenshot'], "rb") as __f:
+                    screenshot = __f.read()
+                _f.write(screenshot)
+            # get a11tree and save to wandb
+            thisrun_a11tree = env.controller.get_accessibility_tree()
+            str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"),
+                               thisrun_a11tree,
+                               response, action, action_timestamp, done)
+            wandb.log({"Reward": reward})
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                f.write(json.dumps({
+                    "step_num": step_idx + 1,
+                    "action_timestamp": action_timestamp,
+                    "action": action,
+                    "reward": reward,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
+                }))
+                f.write("\n")
+            if done:
+                logger.info("The episode is done.")
+                break
+        step_idx += 1
+    wandb.log({"str_trajectory": str_table})
+    result = env.evaluate()
+    logger.info("Result: %.2f", result)
+    scores.append(result)
+    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{result}\n")
+    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+    wandb.log({"Result": result})
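Note: time_limit is read from settings.json at import time, so the @timeout value is fixed the moment lib_run_single is first imported; editing settings.json afterwards has no effect on a running process. If a per-run limit were needed, the decorator could instead be applied at call time, a sketch under the assumption that run_single_example were left undecorated:

    from wrapt_timeout_decorator import timeout

    def run_with_limit(seconds, *args, **kwargs):
        # wrap the plain function with a freshly chosen limit for this call only
        return timeout(seconds, use_signals=False)(run_single_example)(*args, **kwargs)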
@@ -5,19 +5,21 @@ import os
 import re
 import time
 import uuid
-import openai
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
 from typing import Dict, List
-from google.api_core.exceptions import InvalidArgument

 import backoff
 import dashscope
 import google.generativeai as genai
+import openai
 import requests
+import wandb
 from PIL import Image
+from google.api_core.exceptions import InvalidArgument

-from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
+from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
 from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
     SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
     SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
@@ -422,7 +424,6 @@ class PromptAgent:
         # with open("messages.json", "w") as f:
         #     f.write(json.dumps(messages, indent=4))

-        logger.info("Generating content with GPT model: %s", self.model)
         response = self.call_llm({
             "model": self.model,
             "messages": messages,
@@ -441,7 +442,7 @@ class PromptAgent:
             actions = None
             self.thoughts.append("")

-        return actions
+        return response, actions

     @backoff.on_exception(
         backoff.expo,
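Note: predict now returns a (response, actions) pair instead of just actions, so every caller has to unpack both values, as lib_run_single.py does:

    response, actions = agent.predict(instruction, obs)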
@@ -461,7 +462,7 @@ class PromptAgent:
             "Content-Type": "application/json",
             "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
         }
-        # logger.info("Generating content with GPT model: %s", self.model)
+        logger.info("Generating content with GPT model: %s", self.model)
         response = requests.post(
             "https://api.openai.com/v1/chat/completions",
             headers=headers,
@@ -488,55 +489,162 @@ class PromptAgent:
         else:
             return response.json()['choices'][0]['message']['content']

-        # elif self.model.startswith("mistral"):
-        #     print("Call mistral")
-        #     messages = payload["messages"]
-        #     max_tokens = payload["max_tokens"]
-        #
-        #     misrtal_messages = []
-        #
-        #     for i, message in enumerate(messages):
-        #         mistral_message = {
-        #             "role": message["role"],
-        #             "content": []
-        #         }
-        #
-        #         for part in message["content"]:
-        #             mistral_message['content'] = part['text'] if part['type'] == "text" else None
-        #
-        #         misrtal_messages.append(mistral_message)
-        #
-        #     # the mistral not support system message in our endpoint, so we concatenate it at the first user message
-        #     if misrtal_messages[0]['role'] == "system":
-        #         misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
-        #         misrtal_messages.pop(0)
-        #
-        #     # openai.api_base = "http://localhost:8000/v1"
-        #     # openai.api_key = "test"
-        #     # response = openai.ChatCompletion.create(
-        #     #     messages=misrtal_messages,
-        #     #     model="Mixtral-8x7B-Instruct-v0.1"
-        #     # )
-        #
-        #     from openai import OpenAI
-        #     TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
-        #
-        #     client = OpenAI(api_key=TOGETHER_API_KEY,
-        #                     base_url='https://api.together.xyz',
-        #                     )
-        #     logger.info("Generating content with Mistral model: %s", self.model)
-        #     response = client.chat.completions.create(
-        #         messages=misrtal_messages,
-        #         model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-        #         max_tokens=1024
-        #     )
-        #
-        #     try:
-        #         # return response['choices'][0]['message']['content']
-        #         return response.choices[0].message.content
-        #     except Exception as e:
-        #         print("Failed to call LLM: " + str(e))
-        #         return ""
+        elif self.model.startswith("claude"):
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            claude_messages = []
+
+            for i, message in enumerate(messages):
+                claude_message = {
+                    "role": message["role"],
+                    "content": []
+                }
+                assert len(message["content"]) in [1, 2], "One text, or one text with one image"
+                for part in message["content"]:
+
+                    if part['type'] == "image_url":
+                        image_source = {}
+                        image_source["type"] = "base64"
+                        image_source["media_type"] = "image/png"
+                        image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
+                        claude_message['content'].append({"type": "image", "source": image_source})
+
+                    if part['type'] == "text":
+                        claude_message['content'].append({"type": "text", "text": part['text']})
+
+                claude_messages.append(claude_message)
+
+            # our Claude endpoint does not support a system message, so we concatenate it into the first user message
+            if claude_messages[0]['role'] == "system":
+                claude_system_message_item = claude_messages[0]['content'][0]
+                claude_messages[1]['content'].insert(0, claude_system_message_item)
+                claude_messages.pop(0)
+
+            headers = {
+                "x-api-key": os.environ["ANTHROPIC_API_KEY"],
+                "anthropic-version": "2023-06-01",
+                "content-type": "application/json"
+            }
+
+            payload = {
+                "model": self.model,
+                "max_tokens": max_tokens,
+                "messages": claude_messages
+            }
+
+            response = requests.post(
+                "https://api.anthropic.com/v1/messages",
+                headers=headers,
+                json=payload
+            )
+
+            if response.status_code != 200:
+
+                logger.error("Failed to call LLM: " + response.text)
+                time.sleep(5)
+                return ""
+            else:
+                return response.json()['content'][0]['text']
+
+
+        elif self.model.startswith("mistral"):
+            print("Call mistral")
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            mistral_messages = []
+
+            for i, message in enumerate(messages):
+                mistral_message = {
+                    "role": message["role"],
+                    "content": ""
+                }
+
+                for part in message["content"]:
+                    mistral_message['content'] = part['text'] if part['type'] == "text" else ""
+
+                mistral_messages.append(mistral_message)
+
+            # openai.api_base = "http://localhost:8000/v1"
+            # response = openai.ChatCompletion.create(
+            #     messages=mistral_messages,
+            #     model="Mixtral-8x7B-Instruct-v0.1"
+            # )
+
+            from openai import OpenAI
+
+            client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
+                            base_url='https://api.together.xyz',
+                            )
+            logger.info("Generating content with Mistral model: %s", self.model)
+
+            response = client.chat.completions.create(
+                messages=mistral_messages,
+                model=self.model,
+                max_tokens=max_tokens
+            )
+
+            try:
+                return response.choices[0].message.content
+            except Exception as e:
+                print("Failed to call LLM: " + str(e))
+                return ""
+
+        elif self.model.startswith("THUDM"):
+            # THUDM/cogagent-chat-hf
+            print("Call CogAgent")
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            cog_messages = []
+
+            for i, message in enumerate(messages):
+                cog_message = {
+                    "role": message["role"],
+                    "content": []
+                }
+
+                for part in message["content"]:
+                    if part['type'] == "image_url":
+                        cog_message['content'].append({"type": "image_url", "image_url": {"url": part['image_url']['url']}})
+
+                    if part['type'] == "text":
+                        cog_message['content'].append({"type": "text", "text": part['text']})
+
+                cog_messages.append(cog_message)
+
+            # our CogAgent endpoint does not support a system message, so we concatenate it into the first user message
+            if cog_messages[0]['role'] == "system":
+                cog_system_message_item = cog_messages[0]['content'][0]
+                cog_messages[1]['content'].insert(0, cog_system_message_item)
+                cog_messages.pop(0)
+
+            payload = {
+                "model": self.model,
+                "max_tokens": max_tokens,
+                "messages": cog_messages
+            }
+
+            base_url = "http://127.0.0.1:8000"
+
+            response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=False)
+            if response.status_code == 200:
+                decoded_line = response.json()
+                content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
+                return content
+            else:
+                print("Failed to call LLM: ", response.status_code)
+                return ""
+

         elif self.model.startswith("gemini"):
             def encoded_img_to_pil_img(data_str):
@@ -612,6 +720,7 @@ class PromptAgent:
             try:
                 return response.text
             except Exception as e:
                 logger.error("Meet exception when calling Gemini API, " + str(e))
                 return ""
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
run.py | 189
@@ -6,13 +6,17 @@ import datetime
 import json
 import logging
 import os
-import signal
 import random
 import sys
+import wandb
+
+from tqdm import tqdm

+import lib_run_single
 from desktop_env.envs.desktop_env import DesktopEnv
 from mm_agents.agent import PromptAgent

 # Logger Configs {{{ #
 logger = logging.getLogger()
 logger.setLevel(logging.DEBUG)

@@ -46,13 +50,10 @@ logger.addHandler(sdebug_handler)
 logger = logging.getLogger("desktopenv.experiment")


-# make sure each example won't exceed the time limit
-def handler(signo, frame):
-    raise RuntimeError("Time limit exceeded!")
-
-
-signal.signal(signal.SIGALRM, handler)
+# wandb config
+### set your wandb api key here
+os.environ["WANDB_API_KEY"] = ""
+wandb.login(key=os.environ["WANDB_API_KEY"])


 def config() -> argparse.Namespace:
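Note: the API key is left as an empty string here, so wandb.login will fail unless the placeholder is filled in or a key is already present in the environment. A sketch that avoids clobbering an externally set key (the fallback logic is an assumption, not part of the commit):

    import os
    import wandb

    api_key = os.environ.get("WANDB_API_KEY")  # respect a key set outside the script
    if api_key:
        wandb.login(key=api_key)
    else:
        wandb.login()  # fall back to stored credentials or an interactive prompt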
@@ -75,7 +76,7 @@ def config() -> argparse.Namespace:
             "screenshot_a11y_tree",
             "som"
         ],
-        default="som",
+        default="a11y_tree",
         help="Observation type",
     )
     parser.add_argument("--screen_width", type=int, default=1920)
@@ -86,10 +87,9 @@ def config() -> argparse.Namespace:
     # agent config
     parser.add_argument("--max_trajectory_length", type=int, default=3)
     parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
-    parser.add_argument("--example_time_limit", type=int, default=600)

     # lm config
-    parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
+    parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_p", type=float, default=0.9)
     parser.add_argument("--max_tokens", type=int, default=1500)
@@ -108,10 +108,28 @@ def test(
 ) -> None:
     scores = []
     max_steps = args.max_steps
-    time_limit = args.example_time_limit

     # log args
     logger.info("Args: %s", args)
+    # set wandb project
+    cfg_args = \
+        {
+            "path_to_vm": args.path_to_vm,
+            "headless": args.headless,
+            "action_space": args.action_space,
+            "observation_type": args.observation_type,
+            "screen_width": args.screen_width,
+            "screen_height": args.screen_height,
+            "sleep_after_execution": args.sleep_after_execution,
+            "max_steps": args.max_steps,
+            "max_trajectory_length": args.max_trajectory_length,
+            "model": args.model,
+            "temperature": args.temperature,
+            "top_p": args.top_p,
+            "max_tokens": args.max_tokens,
+            "stop_token": args.stop_token,
+            "result_dir": args.result_dir
+        }

     agent = PromptAgent(
         model=args.model,
@@ -128,8 +146,10 @@ def test(
         headless=args.headless,
     )

-    for domain in test_all_meta:
-        for example_id in test_all_meta[domain]:
+    for domain in tqdm(test_all_meta, desc="Domain"):
+        for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
+            wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
+                       name=f"{example_id}")
             # example setting
             config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
             with open(config_file, "r", encoding="utf-8") as f:
@@ -141,6 +161,10 @@ def test(
             instruction = example["instruction"]

             logger.info(f"[Instruction]: {instruction}")
+            # wandb each example config settings
+            cfg_args["instruction"] = instruction
+            cfg_args["start_time"] = datetime.datetime.now().strftime("%Y:%m:%d-%H:%M:%S")
+            wandb.config.update(cfg_args)

             example_result_dir = os.path.join(
                 args.result_dir,
@@ -151,79 +175,26 @@ def test(
                 example_id
             )
             os.makedirs(example_result_dir, exist_ok=True)

             # example start running
             try:
-                signal.alarm(time_limit)
-                agent.reset()
-                obs = env.reset(task_config=example)
-                done = False
-                step_idx = 0
-                env.controller.start_recording()
-
-                while not done and step_idx < max_steps:
-                    actions = agent.predict(
-                        instruction,
-                        obs
-                    )
-                    for action in actions:
-                        # Capture the timestamp before executing the action
-                        action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
-                        logger.info("Step %d: %s", step_idx + 1, action)
-
-                        obs, reward, done, info = env.step(action, args.sleep_after_execution)
-
-                        logger.info("Reward: %.2f", reward)
-                        logger.info("Done: %s", done)
-                        logger.info("Info: %s", info)
-
-                        # Save screenshot and trajectory information
-                        with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
-                                  "wb") as _f:
-                            with open(obs['screenshot'], "rb") as __f:
-                                screenshot = __f.read()
-                            _f.write(screenshot)
-
-                        with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                            f.write(json.dumps({
-                                "step_num": step_idx + 1,
-                                "action_timestamp": action_timestamp,
-                                "action": action,
-                                "reward": reward,
-                                "done": done,
-                                "info": info,
-                                "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
-                            }))
-                            f.write("\n")
-
-                        if done:
-                            logger.info("The episode is done.")
-                            break
-                    step_idx += 1
-
-                result = env.evaluate()
-                logger.info("Result: %.2f", result)
-                scores.append(result)
+                lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
+                                                  scores)
+            except Exception as e:
+                logger.error(f"Exception in {domain}/{example_id}: {e}")
+                wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])})
+                env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
-            except RuntimeError as e:
-                logger.error(f"Error in example {domain}/{example_id}: {e}")
-                # save info of this example and then continue
-                try:
-                    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
-                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                        f.write(json.dumps({
-                            "Error": f"Error in example {domain}/{example_id}: {e}",
-                            "step": step_idx + 1,
-                        }))
-                        f.write("\n")
-                except Exception as new_e:
-                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                        f.write(json.dumps({
-                            "Error": f"Error in example {domain}/{example_id}: {e} and {new_e}",
-                            "step": "before start recording",
-                        }))
-                        f.write("\n")
-                    continue
-                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                    f.write(json.dumps({
-                        "Error": f"Time limit exceeded in {domain}/{example_id}"
-                    }))
-                    f.write("\n")
+            # wandb settings
+            os.mkdir(os.path.join(wandb.run.dir, "results/"))
+            for file in os.listdir(example_result_dir):
+                # move file to just under the root dir
+                os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}"))
+            wandb.finish()

     env.close()
     logger.info(f"Average score: {sum(scores) / len(scores)}")
@@ -236,9 +207,18 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
     finished = {}
     for domain in os.listdir(target_dir):
+        finished[domain] = []
         domain_path = os.path.join(target_dir, domain)
         if os.path.isdir(domain_path):
-            finished[domain] = os.listdir(domain_path)
+            for example_id in os.listdir(domain_path):
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" not in os.listdir(example_path):
+                        # empty all files under example_id
+                        for file in os.listdir(example_path):
+                            os.remove(os.path.join(example_path, file))
+                    else:
+                        finished[domain].append(example_id)

     if not finished:
         return total_file_json
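Note: the rewritten loop treats an example directory without result.txt as unfinished: its partial files are deleted so the example re-runs cleanly, while completed examples are collected in finished. The filtering step itself is outside this hunk; a sketch of the usual pattern (illustrative, not the exact code below the fold):

    # drop already-finished example ids from the work list
    for domain, done_ids in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in done_ids
            ]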
@@ -250,6 +230,35 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
         return total_file_json


+def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+    if not os.path.exists(target_dir):
+        print("New experiment, no result yet.")
+        return None
+
+    all_result = []
+
+    for domain in os.listdir(target_dir):
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" in os.listdir(example_path):
+                        # read the recorded result for this example
+                        try:
+                            all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
+                        except:
+                            all_result.append(0.0)
+
+    if not all_result:
+        print("New experiment, no result yet.")
+        return None
+    else:
+        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
+        return all_result
+
+
 if __name__ == '__main__':
     ####### The complete version of the list of examples #######
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
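Note: get_result only reads result.txt files and never deletes anything (unlike get_unfinished), so it is safe to call while an experiment is in progress. Typical usage mirrors the commented-out call at the bottom of run.py:

    results = get_result(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta)
    if results is not None:
        print(f"{len(results)} finished examples")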
@@ -270,4 +279,10 @@ if __name__ == '__main__':
         left_info += f"{domain}: {len(test_file_list[domain])}\n"
     logger.info(f"Left tasks:\n{left_info}")

-    test(args, test_all_meta)
+    # get_result(args.action_space,
+    #            args.model,
+    #            args.observation_type,
+    #            args.result_dir,
+    #            test_all_meta
+    #            )
+    test(args, test_file_list)
settings.json (new file) | 3
@@ -0,0 +1,3 @@
+{
+    "time_limit": "10"
+}
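Note: time_limit is stored as the string "10" rather than the number 10, and lib_run_single.py passes the value straight into @timeout(...); a numeric value is presumably intended. A defensive load might coerce it (an assumption, not part of the commit):

    import json

    with open("./settings.json") as f:
        time_limit = int(json.load(f)["time_limit"])  # coerce "10" -> 10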