Merge branch 'main' of github.com:ztjhz/DesktopEnv

This commit is contained in:
Siheng Zhao
2024-03-20 22:42:01 +08:00
15 changed files with 125 additions and 100 deletions

3
.gitignore vendored
View File

@@ -2,6 +2,9 @@
*.pth *.pth
*.pt *.pt
# Credential files
evaluation_examples/settings/googledrive/credentials.json
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/
*.py[cod] *.py[cod]

View File

@@ -64,7 +64,7 @@ class PythonController:
It can be used to execute the pyautogui commands, or... any other python command. who knows? It can be used to execute the pyautogui commands, or... any other python command. who knows?
""" """
# command_list = ["python", "-c", self.pkgs_prefix.format(command=command)] # command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
command_list = ["python3", "-c", self.pkgs_prefix.format(command=command)] command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
payload = json.dumps({"command": command_list, "shell": False}) payload = json.dumps({"command": command_list, "shell": False})
headers = { headers = {
'Content-Type': 'application/json' 'Content-Type': 'application/json'

View File

@@ -58,7 +58,8 @@ class DesktopEnv(gym.Env):
tmp_dir: str = "tmp", tmp_dir: str = "tmp",
cache_dir: str = "cache", cache_dir: str = "cache",
screen_size: Tuple[int] = (1920, 1080), screen_size: Tuple[int] = (1920, 1080),
headless: bool = False headless: bool = False,
require_a11y_tree: bool = True,
): ):
""" """
Args: Args:
@@ -77,6 +78,7 @@ class DesktopEnv(gym.Env):
self.cache_dir_base: str = cache_dir self.cache_dir_base: str = cache_dir
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
self.headless = headless self.headless = headless
self.require_a11y_tree = require_a11y_tree
os.makedirs(self.tmp_dir_base, exist_ok=True) os.makedirs(self.tmp_dir_base, exist_ok=True)
@@ -248,7 +250,7 @@ class DesktopEnv(gym.Env):
observation = { observation = {
"screenshot": self._get_obs(), "screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree(), "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
} }
return observation return observation
@@ -284,7 +286,7 @@ class DesktopEnv(gym.Env):
observation = { observation = {
"screenshot": self._get_obs(), "screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree(), "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
# "terminal": self.controller.get_terminal_output(), # "terminal": self.controller.get_terminal_output(),
"instruction": self.instruction "instruction": self.instruction
} }

View File

@@ -77,6 +77,7 @@ from .general import (
literal_match literal_match
) )
from .gimp import ( from .gimp import (
check_structure_sim_resized,
check_brightness_decrease_and_structure_sim, check_brightness_decrease_and_structure_sim,
check_contrast_increase_and_structure_sim, check_contrast_increase_and_structure_sim,
check_saturation_increase_and_structure_sim, check_saturation_increase_and_structure_sim,

View File

@@ -414,9 +414,18 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str
attribute_dict: Dict[str, Any] = {"name": node.element_info.name} attribute_dict: Dict[str, Any] = {"name": node.element_info.name}
# States {{{ # # States {{{ #
attribute_dict["{{{:}}}enabled".format(_accessibility_ns_map["st"])] = str(node.is_enabled()).lower() try:
attribute_dict["{{{:}}}visible".format(_accessibility_ns_map["st"])] = str(node.is_visible()).lower() attribute_dict["{{{:}}}enabled".format(_accessibility_ns_map["st"])] = str(node.is_enabled()).lower()
attribute_dict["{{{:}}}active".format(_accessibility_ns_map["st"])] = str(node.is_active()).lower() except:
pass
try:
attribute_dict["{{{:}}}visible".format(_accessibility_ns_map["st"])] = str(node.is_visible()).lower()
except:
pass
try:
attribute_dict["{{{:}}}active".format(_accessibility_ns_map["st"])] = str(node.is_active()).lower()
except:
pass
if hasattr(node, "is_minimized"): if hasattr(node, "is_minimized"):
try: try:
@@ -603,9 +612,14 @@ def get_accessibility_tree():
@app.route('/screen_size', methods=['POST']) @app.route('/screen_size', methods=['POST'])
def get_screen_size(): def get_screen_size():
d = display.Display() if platform_name=="Linux":
screen_width = d.screen().width_in_pixels d = display.Display()
screen_height = d.screen().height_in_pixels screen_width = d.screen().width_in_pixels
screen_height = d.screen().height_in_pixels
elif platform_name=="Windows":
user32 = ctypes.windll.user32
screen_width: int = user32.GetSystemMetrics(0)
screen_height: int = user32.GetSystemMetrics(1)
return jsonify( return jsonify(
{ {
"width": screen_width, "width": screen_width,

View File

@@ -21,16 +21,18 @@
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": [ "command": [
"C:\Program Files\Google\Chrome\Application\chrome.exe", "C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe",
"--remote-debugging-port=1337" "--remote-debugging-port=1337"
] ]
} }
}, },
{ {
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": "nc -l -p 9222 |nc 127.0.0.1 1337", "command": [
"shell": true "ncat.exe", "-k", "-l", "0.0.0.0", "9222",
"--sh-exec", "ncat.exe 127.0.0.1 1337"
]
} }
}, },
{ {
@@ -54,8 +56,8 @@
"parameters": { "parameters": {
"files": [ "files": [
{ {
"url": "https://drive.usercontent.google.com/download?id=18jdi0OanMtAQenm4ODTivsxTSzdj4HUV&export=download&authuser=0&confirm=t&uuid=e858d3cc-4535-4419-a651-8856ac517d19&at=APZUnTW7g4ygfrkKTPBWCO13twRj:1706611460571", "url": "https://drive.google.com/uc?id=1Yy-ZrkMq4pIQq1Y75bD2WVJXxHMTaMqE&export=download",
"path": "/home/user/thunderbird-profile.tar.gz" "path": "C:\\Users\\chenj\\thunderbird-profile.7z"
} }
] ]
} }
@@ -64,21 +66,30 @@
"type": "execute", "type": "execute",
"parameters": { "parameters": {
"command": [ "command": [
"tar", "C:\\Program Files\\7-Zip\\7z.exe",
"-xz", "x", "C:\\Users\\chenj\\thunderbird-profile.7z"
"--recursive-unlink",
"-f",
"/home/user/thunderbird-profile.tar.gz",
"-C",
"/home/user/"
] ]
} }
}, },
{
"type": "execute",
"parameters": {
"command": "rm -r C:\\Users\\chenj\\AppData\\Roaming\\Thunderbird",
"shell": true
}
},
{
"type": "execute",
"parameters": {
"command": "mv C:\\Users\\chenj\\Thunderbird C:\\Users\\chenj\\AppData\\Roaming\\Thunderbird",
"shell": true
}
},
{ {
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": [ "command": [
"/usr/bin/thunderbird" "C:\\Program Files\\Mozilla Thunderbird\\thunderbird.exe"
] ]
} }
} }

View File

@@ -21,7 +21,7 @@
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": [ "command": [
"google-chrome", "C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe",
"--remote-debugging-port=1337" "--remote-debugging-port=1337"
] ]
} }
@@ -30,9 +30,8 @@
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": [ "command": [
"socat", "ncat.exe", "-k", "-l", "0.0.0.0", "9222",
"tcp-listen:9222,fork", "--sh-exec", "ncat.exe 127.0.0.1 1337"
"tcp:localhost:1337"
] ]
} }
}, },
@@ -60,7 +59,7 @@
"files": [ "files": [
{ {
"url": "https://drive.usercontent.google.com/download?id=18TvzE8jnULU2g9XJsT-TaPEKcLGNVfu0&export=download&authuser=0&confirm=t&uuid=d914e031-9aa6-431b-81c0-73fcb87af027&at=APZUnTUx56WM_I3gnhHo-eZX__kx:1706158167271", "url": "https://drive.usercontent.google.com/download?id=18TvzE8jnULU2g9XJsT-TaPEKcLGNVfu0&export=download&authuser=0&confirm=t&uuid=d914e031-9aa6-431b-81c0-73fcb87af027&at=APZUnTUx56WM_I3gnhHo-eZX__kx:1706158167271",
"path": "/home/user/Desktop/form.docx" "path": "C:\\Users\\chenj\\Desktop\\form.docx"
} }
] ]
} }
@@ -68,7 +67,7 @@
{ {
"type": "open", "type": "open",
"parameters": { "parameters": {
"path": "/home/user/Desktop/form.docx" "path": "C:\\Users\\chenj\\Desktop\\form.docx"
} }
} }
], ],

View File

@@ -8,8 +8,7 @@
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": [ "command": [
"libreoffice", "C:\\Program Files\\Microsoft Office\\root\\Office16\\EXCEL.EXE"
"--calc"
] ]
} }
}, },
@@ -18,8 +17,8 @@
"parameters": { "parameters": {
"files": [ "files": [
{ {
"url": "https://drive.usercontent.google.com/download?id=1wKXmJ14dnxSzdy9ZF_ePWU7zpevY6Dry&export=download&authuser=0&confirm=t&uuid=9b476c95-8eee-4a9a-8cee-c3620d5ce250&at=APZUnTUzDeeeMNr34DB1vEnBK6N7:1706719624132", "url": "https://drive.google.com/uc?id=1njAaNiujlh1DZzGK7nL5iZsppsNAMkH7&export=download",
"path": "/home/user/thunderbird-profile.tar.gz" "path": "C:\\Users\\chenj\\thunderbird-profile.7z"
} }
] ]
} }
@@ -28,21 +27,30 @@
"type": "execute", "type": "execute",
"parameters": { "parameters": {
"command": [ "command": [
"tar", "C:\\Program Files\\7-Zip\\7z.exe",
"--recursive-unlink", "x", "C:\\Users\\chenj\\thunderbird-profile.7z"
"-xz",
"-f",
"/home/user/thunderbird-profile.tar.gz",
"-C",
"/home/user/"
] ]
} }
}, },
{
"type": "execute",
"parameters": {
"command": "rm -r C:\\Users\\chenj\\AppData\\Roaming\\Thunderbird",
"shell": true
}
},
{
"type": "execute",
"parameters": {
"command": "mv C:\\Users\\chenj\\Thunderbird C:\\Users\\chenj\\AppData\\Roaming\\Thunderbird",
"shell": true
}
},
{ {
"type": "launch", "type": "launch",
"parameters": { "parameters": {
"command": [ "command": [
"/usr/bin/thunderbird" "C:\\Program Files\\Mozilla Thunderbird\\thunderbird.exe"
] ]
} }
} }
@@ -61,12 +69,12 @@
"result": [ "result": [
{ {
"type": "vm_file", "type": "vm_file",
"path": "/home/user/Desktop/contacts.csv", "path": "C:\\Users\\chenj\\Desktop\\contacts.csv",
"dest": "contacts.csv" "dest": "contacts.csv"
}, },
{ {
"type": "vm_file", "type": "vm_file",
"path": "/home/user/Desktop/contacts.xlsx", "path": "C:\\Users\\chenj\\Desktop\\contacts.xlsx",
"dest": "contacts.xlsx" "dest": "contacts.xlsx"
} }
], ],

View File

@@ -10,11 +10,11 @@
"files": [ "files": [
{ {
"url": "https://drive.usercontent.google.com/download?id=1JGZNCShtmpu7A8Z8lkjc8hdFEAMXZVvh&export=download&authuser=0&confirm=t&uuid=67063da6-2a72-4ed2-92b2-ade508439ce4&at=APZUnTUgS17YjX-D0oSvALwnPosB:1709368886960", "url": "https://drive.usercontent.google.com/download?id=1JGZNCShtmpu7A8Z8lkjc8hdFEAMXZVvh&export=download&authuser=0&confirm=t&uuid=67063da6-2a72-4ed2-92b2-ade508439ce4&at=APZUnTUgS17YjX-D0oSvALwnPosB:1709368886960",
"path": "/home/user/Desktop/2023_validation_Book_Reading_Rate.xlsx" "path": "C:\\Users\\chenj\\Desktop\\2023_validation_Book_Reading_Rate.xlsx"
}, },
{ {
"url": "https://drive.usercontent.google.com/download?id=1iySmK8zvTzgmERH7KQuESP05NBsMunhV&export=download&authuser=0&confirm=t&uuid=130f6cee-0f9a-4f2e-a84d-89a3b302f350&at=APZUnTXugQOTOApe1_zxUbafo2Sp:1709369519349", "url": "https://drive.usercontent.google.com/download?id=1iySmK8zvTzgmERH7KQuESP05NBsMunhV&export=download&authuser=0&confirm=t&uuid=130f6cee-0f9a-4f2e-a84d-89a3b302f350&at=APZUnTXugQOTOApe1_zxUbafo2Sp:1709369519349",
"path": "/home/user/Desktop/book_list_result.docx" "path": "C:\\Users\\chenj\\Desktop\\book_list_result.docx"
} }
] ]
} }
@@ -22,7 +22,7 @@
{ {
"type": "open", "type": "open",
"parameters": { "parameters": {
"path": "/home/user/Desktop/2023_validation_Book_Reading_Rate.xlsx" "path": "C:\\Users\\chenj\\Desktop\\2023_validation_Book_Reading_Rate.xlsx"
} }
} }
], ],
@@ -38,7 +38,7 @@
{ {
"type": "activate_window", "type": "activate_window",
"parameters": { "parameters": {
"window_name": "book_list_result.docx - LibreOffice Writer", "window_name": "book_list_result - Word",
"strict": true "strict": true
} }
}, },
@@ -54,10 +54,16 @@
"command": [ "command": [
"python", "python",
"-c", "-c",
"import pyautogui; import time; pyautogui.hotkey('ctrl', 's'); time.sleep(0.5); " "import pyautogui; import time; pyautogui.hotkey(\"ctrl\", \"s\"); time.sleep(0.5); pyautogui.press(\"enter\");"
] ]
} }
} },
{
"type": "sleep",
"parameters": {
"seconds": 0.5
}
}
], ],
"expected": { "expected": {
"type": "cloud_file", "type": "cloud_file",
@@ -66,7 +72,7 @@
}, },
"result": { "result": {
"type": "vm_file", "type": "vm_file",
"path": "/home/user/Desktop/book_list_result.docx", "path": "C:\\Users\\chenj\\Desktop\\book_list_result.docx",
"dest": "book_list_result.docx" "dest": "book_list_result.docx"
} }
} }

View File

@@ -10,21 +10,27 @@
"files": [ "files": [
{ {
"url": "https://drive.usercontent.google.com/download?id=1l09TnSiXo-qOK2UazcIdrT_M6JwTfzq7&export=download&authuser=0&confirm=t&uuid=80bd550f-f3a6-4b69-ae0f-221c12b11fd9&at=APZUnTWgUlKuIDJZmkr0Q9Bze3w_:1709784652645", "url": "https://drive.usercontent.google.com/download?id=1l09TnSiXo-qOK2UazcIdrT_M6JwTfzq7&export=download&authuser=0&confirm=t&uuid=80bd550f-f3a6-4b69-ae0f-221c12b11fd9&at=APZUnTWgUlKuIDJZmkr0Q9Bze3w_:1709784652645",
"path": "/home/user/Desktop/calculator.zip" "path": "C:\\Users\\chenj\\Desktop\\calculator.zip"
} }
] ]
} }
}, },
{ {
"type": "execute", "type": "execute",
"parameters": { "parameters": {
"command": [ "command": [
"/bin/bash", "C:\\Program Files\\7-Zip\\7z.exe",
"-c", "C:\\Users\\chenj\\Desktop\\calculator.zip"
"unzip /home/user/Desktop/calculator.zip -d /home/user/Desktop/ && rm -rf /home/user/Desktop/calculator.zip" ]
] }
} },
} {
"type": "execute",
"parameters": {
"command": "rm C:\\Users\\chenj\\Desktop\\calculator.zip",
"shell": true
}
}
], ],
"trajectory": "trajectories/f918266a-b3e0-4914-865d-4faa564f1aef", "trajectory": "trajectories/f918266a-b3e0-4914-865d-4faa564f1aef",
"related_apps": [ "related_apps": [

View File

@@ -1,4 +1,4 @@
{ {
"email": "xlang2024anonym@gmail.com", "email": "xlang2024anonym@gmail.com",
"password": "q]wN~0iD>H:6" "password": "Evt5LLj!VJ6Y!C$B"
} }

View File

@@ -1 +0,0 @@
{"access_token": "ya29.a0Ad52N382_JIl2nZBNpJCgoU3HXk2Kz7CArVYn_PGI8pXFucAozry1Vmp5QolzGrnl4UChZswJDOgcdPm5Ew-NbdHPX95wxknoG1oJKqjWYtjl3mw433hiGtriuKWKnXcz1NUf8ewqqq458tJLLDhbbZFW7eZRQrdJzmrGAaCgYKAZ4SARISFQHGX2Mik2MQ5qx0goIypVyzbcUmYw0173", "client_id": "786888752612-rgng5v9hcq4as7pn0b40gt9r5lekmht9.apps.googleusercontent.com", "client_secret": "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", "refresh_token": "1//0eVpYfdSAjvbCCgYIARAAGA4SNwF-L9IrAgL6KVceiEVTjtQdmPki2I3m8ejP3lzTLL2Wa3-rdrYfU7eYeKDVCS5KRxa_xCE_pPY", "token_expiry": "2024-03-13T10:09:01Z", "token_uri": "https://oauth2.googleapis.com/token", "user_agent": null, "revoke_uri": "https://oauth2.googleapis.com/revoke", "id_token": null, "id_token_jwt": null, "token_response": {"access_token": "ya29.a0Ad52N382_JIl2nZBNpJCgoU3HXk2Kz7CArVYn_PGI8pXFucAozry1Vmp5QolzGrnl4UChZswJDOgcdPm5Ew-NbdHPX95wxknoG1oJKqjWYtjl3mw433hiGtriuKWKnXcz1NUf8ewqqq458tJLLDhbbZFW7eZRQrdJzmrGAaCgYKAZ4SARISFQHGX2Mik2MQ5qx0goIypVyzbcUmYw0173", "expires_in": 3599, "scope": "https://www.googleapis.com/auth/drive", "token_type": "Bearer"}, "scopes": ["https://www.googleapis.com/auth/drive"], "token_info_uri": "https://oauth2.googleapis.com/tokeninfo", "invalid": false, "_class": "OAuth2Credentials", "_module": "oauth2client.client"}

32
main.py
View File

@@ -70,38 +70,6 @@ def human_agent():
done = False done = False
logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"]) logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"])
trajectory = [
{
"action_type": "MOVE_TO", #
"parameters": {
"x": 754,
"y": 1057
}
},
{"action_type": "CLICK", "parameters": {"button": "right", "num_clicks": 1}}
]
for i in range(len(trajectory)):
# action = get_human_action()
# action = {
# "action_type": 0,
# "click_type": 3,
# }
logger.info(trajectory[i])
observation, reward, done, info = env.step(trajectory[i])
observation.pop("accessibility_tree")
logger.info("Observation: %s", observation)
logger.info("Reward: %.2f", reward)
logger.info("Info: %s", info)
logger.info("================================\n")
if done:
logger.info("The episode is done.")
break
input("Press Enter to start human operation...") input("Press Enter to start human operation...")
human_start_time = time.time() human_start_time = time.time()
input("Press Enter to finish human operation.") input("Press Enter to finish human operation.")

View File

@@ -360,7 +360,7 @@ class PromptAgent:
# {{{1 # {{{1
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = encode_image(obs["screenshot"]) base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
logger.debug("LINEAR AT: %s", linearized_accessibility_tree) logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
if self.observation_type == "screenshot_a11y_tree": if self.observation_type == "screenshot_a11y_tree":

10
run.py
View File

@@ -95,6 +95,10 @@ def config() -> argparse.Namespace:
parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--stop_token", type=str, default=None) parser.add_argument("--stop_token", type=str, default=None)
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json")
# logging related # logging related
parser.add_argument("--result_dir", type=str, default="./results") parser.add_argument("--result_dir", type=str, default="./results")
args = parser.parse_args() args = parser.parse_args()
@@ -144,6 +148,7 @@ def test(
action_space=agent.action_space, action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height), screen_size=(args.screen_width, args.screen_height),
headless=args.headless, headless=args.headless,
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
) )
for domain in tqdm(test_all_meta, desc="Domain"): for domain in tqdm(test_all_meta, desc="Domain"):
@@ -264,9 +269,12 @@ if __name__ == '__main__':
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config() args = config()
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f) test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished( test_file_list = get_unfinished(
args.action_space, args.action_space,
args.model, args.model,