Clean code; Refactor environment to pass screenshot content instead of path

This commit is contained in:
Timothyxxx
2024-04-13 23:34:01 +08:00
parent b9ae9b72b2
commit 9c75df5dce
10 changed files with 144 additions and 213 deletions

View File

@@ -2,6 +2,7 @@ import json
import logging import logging
import os import os
import os.path import os.path
import shutil
import sqlite3 import sqlite3
import tempfile import tempfile
import time import time
@@ -11,7 +12,6 @@ from datetime import datetime, timedelta
from typing import Any, Union, Optional from typing import Any, Union, Optional
from typing import Dict, List from typing import Dict, List
import shutil
import requests import requests
from playwright.sync_api import sync_playwright, TimeoutError from playwright.sync_api import sync_playwright, TimeoutError
from pydrive.auth import GoogleAuth from pydrive.auth import GoogleAuth
@@ -25,6 +25,7 @@ logger = logging.getLogger("desktopenv.setup")
FILE_PATH = os.path.dirname(os.path.abspath(__file__)) FILE_PATH = os.path.dirname(os.path.abspath(__file__))
class SetupController: class SetupController:
def __init__(self, vm_ip: str, cache_dir: str): def __init__(self, vm_ip: str, cache_dir: str):
self.vm_ip: str = vm_ip self.vm_ip: str = vm_ip
@@ -60,39 +61,6 @@ class SetupController:
logger.info("SETUP: %s(%s)", setup_function, str(parameters)) logger.info("SETUP: %s(%s)", setup_function, str(parameters))
# self._download_setup(config)
# self._change_wallpaper(config)
# self._tidy_desktop(config) todo: implement this
# self._open_setup(config)
# can add other setup steps
# ZDY_COMMENT: merged with launch
# def _command_setup(self, command: str):
# """
# Directly send a command into the virtual machine os for setting up.
# """
# payload = json.dumps({"command": command})
# headers = {
# 'Content-Type': 'application/json'
# }
# timeout = 5
# timout_whitelist = ["vlc"]
#
# try:
#
# response = requests.post(self.http_server + "/execute", headers=headers, data=payload, timeout=timeout)
# if response.status_code == 200:
# print("Command executed successfully:", response.text)
# else:
# print("Failed to execute command. Status code:", response.status_code)
# except requests.exceptions.Timeout as e:
# if command in timout_whitelist:
# print("Command executed successfully:", command)
# else:
# print("An error occurred while trying to execute the command:", e)
# except requests.exceptions.RequestException as e:
# print("An error occurred while trying to execute the command:", e)
def _download_setup(self, files: List[Dict[str, str]]): def _download_setup(self, files: List[Dict[str, str]]):
""" """
Args: Args:
@@ -140,11 +108,6 @@ class SetupController:
if not downloaded: if not downloaded:
raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}") raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}")
# payload = json.dumps({"url": url, "path": path})
# headers = {
# 'Content-Type': 'application/json'
# }
form = MultipartEncoder({ form = MultipartEncoder({
"file_path": path, "file_path": path,
"file_data": (os.path.basename(path), open(cache_path, "rb")) "file_data": (os.path.basename(path), open(cache_path, "rb"))
@@ -163,6 +126,41 @@ class SetupController:
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.error("An error occurred while trying to send the request: %s", e) logger.error("An error occurred while trying to send the request: %s", e)
def _upload_file_setup(self, files: List[Dict[str, str]]):
"""
Args:
files (List[Dict[str, str]]): files to download. lisf of dict like
{
"local_path": str, the local path to the file to upload
"path": str, the path on the VM to store the downloaded file
}
"""
for f in files:
local_path: str = f["local_path"]
path: str = f["path"]
if not os.path.exists(local_path):
logger.error(f"Setup Upload - Invalid local path ({local_path}).")
return
form = MultipartEncoder({
"file_path": path,
"file_data": (os.path.basename(path), open(local_path, "rb"))
})
headers = {"Content-Type": form.content_type}
logger.debug(form.content_type)
# send request to server to upload file
try:
logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload")
response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form)
if response.status_code == 200:
logger.info("Command executed successfully: %s", response.text)
else:
logger.error("Failed to upload file. Status code: %s", response.text)
except requests.exceptions.RequestException as e:
logger.error("An error occurred while trying to send the request: %s", e)
def _change_wallpaper_setup(self, path: str): def _change_wallpaper_setup(self, path: str):
# if not config: # if not config:
# return # return

View File

@@ -3,22 +3,16 @@ from __future__ import annotations
import logging import logging
import os import os
import subprocess import subprocess
import tempfile
import time import time
from typing import Callable, Any, Optional, Tuple from typing import Callable, Any, Optional, Tuple
# import uuid
# import platform
from typing import List, Dict, Union from typing import List, Dict, Union
import gymnasium as gym import gymnasium as gym
from desktop_env.controllers.python import PythonController from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController from desktop_env.controllers.setup import SetupController
# from desktop_env.evaluators import eval_funcs
from desktop_env.evaluators import metrics, getters from desktop_env.evaluators import metrics, getters
# import requests
logger = logging.getLogger("desktopenv.env") logger = logging.getLogger("desktopenv.env")
Metric = Callable[[Any, Any], float] Metric = Callable[[Any, Any], float]
@@ -46,8 +40,7 @@ def _execute_command(command: List[str]) -> None:
class DesktopEnv(gym.Env): class DesktopEnv(gym.Env):
""" """
DesktopEnv with OpenAI Gym interface. DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
Fixme: refactor the logic when implementing the multi-process version
""" """
def __init__( def __init__(
@@ -55,32 +48,33 @@ class DesktopEnv(gym.Env):
path_to_vm: str, path_to_vm: str,
snapshot_name: str = "init_state", snapshot_name: str = "init_state",
action_space: str = "computer_13", action_space: str = "computer_13",
tmp_dir: str = "tmp",
cache_dir: str = "cache", cache_dir: str = "cache",
screen_size: Tuple[int] = (1920, 1080), screen_size: Tuple[int] = (1920, 1080),
headless: bool = False, headless: bool = False,
require_a11y_tree: bool = True, require_a11y_tree: bool = True,
require_terminal: bool = False,
): ):
""" """
Args: Args:
path_to_vm (str): path to .vmx file path_to_vm (str): path to .vmx file
snapshot_name (str): snapshot name to revert to, default to "init_state"
action_space (str): "computer_13" | "pyautogui" action_space (str): "computer_13" | "pyautogui"
tmp_dir (str): temporary directory to store trajectory stuffs like
the extracted screenshots
cache_dir (str): cache directory to cache task-related stuffs like cache_dir (str): cache directory to cache task-related stuffs like
reference file for evaluation reference file for evaluation
screen_size (Tuple[int]): screen size of the VM
headless (bool): whether to run the VM in headless mode
require_a11y_tree (bool): whether to require accessibility tree
require_terminal (bool): whether to require terminal output
""" """
# Initialize environment variables # Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.snapshot_name = snapshot_name self.snapshot_name = snapshot_name
self.tmp_dir_base: str = tmp_dir
self.cache_dir_base: str = cache_dir self.cache_dir_base: str = cache_dir
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
self.headless = headless self.headless = headless
self.require_a11y_tree = require_a11y_tree self.require_a11y_tree = require_a11y_tree
self.require_terminal = require_terminal
os.makedirs(self.tmp_dir_base, exist_ok=True)
# Initialize emulator and controller # Initialize emulator and controller
logger.info("Initializing...") logger.info("Initializing...")
@@ -89,17 +83,17 @@ class DesktopEnv(gym.Env):
self.controller = PythonController(vm_ip=self.vm_ip) self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base) self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
# Meta info of the VM, move to the reset() function # Meta info of the VM
self.vm_platform: str = "" # self.controller.get_vm_platform() self.vm_platform: str = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
# mode: human or machine # mode: human or machine
self.instruction = None
assert action_space in ["computer_13", "pyautogui"] assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs
# episodic stuffs, like tmp dir and counters, will be updated or reset # episodic stuffs, like counters, will be updated or reset
# when calling self.reset() # when calling self.reset()
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
self._traj_no: int = -1 self._traj_no: int = -1
self._step_no: int = 0 self._step_no: int = 0
self.action_history: List[Dict[str, any]] = [] self.action_history: List[Dict[str, any]] = []
@@ -140,11 +134,7 @@ class DesktopEnv(gym.Env):
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name]) _execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
def _get_screenshot(self): def _get_screenshot(self):
# random_uuid = str(uuid.uuid4()) screenshot = None
# os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
# image_path = os.path.join("tmp", random_uuid, "screenshot.png")
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
# Get the screenshot and save to the image_path # Get the screenshot and save to the image_path
max_retries = 20 max_retries = 20
for _ in range(max_retries): for _ in range(max_retries):
@@ -153,14 +143,18 @@ class DesktopEnv(gym.Env):
break break
time.sleep(1) time.sleep(1)
with open(image_path, "wb") as f: if screenshot is None:
f.write(screenshot) logger.error("Failed to get screenshot!")
return image_path return screenshot
def _get_obs(self): def _get_obs(self):
screenshot_image_path = self._get_screenshot() return {
return screenshot_image_path "screenshot": self._get_screenshot(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
"instruction": self.instruction
}
def _set_task_info(self, task_config: Dict[str, Any]): def _set_task_info(self, task_config: Dict[str, Any]):
self.task_id: str = task_config["id"] self.task_id: str = task_config["id"]
@@ -227,18 +221,10 @@ class DesktopEnv(gym.Env):
self._step_no = 0 self._step_no = 0
self.action_history.clear() self.action_history.clear()
logger.info("Setup new temp dir...")
self.tmp_dir = tempfile.mkdtemp(
prefix="{:d}.{:}.".format(self._traj_no, self.task_id),
dir=self.tmp_dir_base
)
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name)) logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name]) _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
time.sleep(5) time.sleep(5)
print(self.vm_screen_size)
logger.info("Starting emulator...") logger.info("Starting emulator...")
self._start_emulator() self._start_emulator()
logger.info("Emulator started.") logger.info("Emulator started.")
@@ -246,7 +232,6 @@ class DesktopEnv(gym.Env):
logger.info("Get meta info of the VM...") logger.info("Get meta info of the VM...")
self.vm_platform = self.controller.get_vm_platform() self.vm_platform = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size() self.vm_screen_size = self.controller.get_vm_screen_size()
print(self.vm_screen_size)
logger.info("Setting up environment...") logger.info("Setting up environment...")
self.setup_controller.setup(self.config) self.setup_controller.setup(self.config)
@@ -254,10 +239,7 @@ class DesktopEnv(gym.Env):
time.sleep(5) time.sleep(5)
logger.info("Environment setup complete.") logger.info("Environment setup complete.")
observation = { observation = self._get_obs()
"screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
}
return observation return observation
def step(self, action, pause=0.5): def step(self, action, pause=0.5):
@@ -279,7 +261,6 @@ class DesktopEnv(gym.Env):
done = True done = True
info = {"done": True} info = {"done": True}
# fixme: add reminding logic here, decide if the action is valid for the current action_space
if self.action_space == "computer_13": if self.action_space == "computer_13":
# the set of all possible actions defined in the action representation # the set of all possible actions defined in the action representation
self.controller.execute_action(action) self.controller.execute_action(action)
@@ -290,12 +271,7 @@ class DesktopEnv(gym.Env):
# the set of all possible python commands insides `pyautogui` # the set of all possible python commands insides `pyautogui`
self.controller.execute_python_command(action) self.controller.execute_python_command(action)
observation = { observation = self._get_obs()
"screenshot": self._get_obs(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
# "terminal": self.controller.get_terminal_output(),
"instruction": self.instruction
}
return observation, reward, done, info return observation, reward, done, info
@@ -358,7 +334,7 @@ class DesktopEnv(gym.Env):
def render(self, mode='rgb_array'): def render(self, mode='rgb_array'):
if mode == 'rgb_array': if mode == 'rgb_array':
return self._get_obs() return self._get_screenshot()
else: else:
raise ValueError('Unsupported render mode: {}'.format(mode)) raise ValueError('Unsupported render mode: {}'.format(mode))

View File

@@ -1,5 +1,6 @@
import csv import csv
# I want to write a function, reads a csv file, and get all the contents in the third column in the order of rows # I want to write a function, reads a csv file, and get all the contents in the third column in the order of rows
def get_conference_city_in_order(env, config): def get_conference_city_in_order(env, config):
# read the csv file # read the csv file
@@ -12,4 +13,3 @@ def get_conference_city_in_order(env, config):
# get the third column in the order of rows # get the third column in the order of rows
conference_city_list = [row[2] for row in reader] conference_city_list = [row[2] for row in reader]
return conference_city_list return conference_city_list

View File

@@ -99,6 +99,7 @@ from .gimp import (
check_image_file_size check_image_file_size
) )
from .libreoffice import check_libre_locale from .libreoffice import check_libre_locale
from .others import compare_epub, check_mp3_meta
from .pdf import check_pdf_pages from .pdf import check_pdf_pages
from .slides import ( from .slides import (
check_presenter_console_disable, check_presenter_console_disable,
@@ -150,7 +151,7 @@ from .vscode import (
check_html_background_image, check_html_background_image,
compare_zip_files compare_zip_files
) )
from .others import compare_epub, check_mp3_meta
def infeasible(): def infeasible():
pass pass

View File

@@ -2,7 +2,6 @@ import datetime
import json import json
import logging import logging
import os import os
# import wandb
from wrapt_timeout_decorator import * from wrapt_timeout_decorator import *
@@ -14,6 +13,7 @@ with open("./settings.json", "r") as file:
data = json.load(file) data = json.load(file)
time_limit = data["time_limit"] time_limit = data["time_limit"]
@timeout(time_limit, use_signals=False) @timeout(time_limit, use_signals=False)
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
agent.reset() agent.reset()
@@ -21,7 +21,6 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
done = False done = False
step_idx = 0 step_idx = 0
env.controller.start_recording() env.controller.start_recording()
# str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"])
while not done and step_idx < max_steps: while not done and step_idx < max_steps:
response, actions = agent.predict( response, actions = agent.predict(
instruction, instruction,
@@ -38,15 +37,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
# Save screenshot and trajectory information # Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f: "wb") as _f:
with open(obs['screenshot'], "rb") as __f: _f.write(obs['screenshot'])
screenshot = __f.read()
_f.write(screenshot)
# get a11tree and save to wandb
# thisrun_a11tree = env.controller.get_accessibility_tree()
# str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"),
# thisrun_a11tree,
# response, action, action_timestamp, done)
# run.log({"Reward": reward})
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({ f.write(json.dumps({
"step_num": step_idx + 1, "step_num": step_idx + 1,
@@ -62,11 +53,9 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
logger.info("The episode is done.") logger.info("The episode is done.")
break break
step_idx += 1 step_idx += 1
# run.log({"str_trajectory": str_table})
result = env.evaluate() result = env.evaluate()
logger.info("Result: %.2f", result) logger.info("Result: %.2f", result)
scores.append(result) scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n") f.write(f"{result}\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
# run.log({"Result": result})

View File

@@ -47,8 +47,7 @@ def human_agent():
Runs the Gym environment with human input. Runs the Gym environment with human input.
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.") parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx", help="Path to the virtual machine .vmx file.")
parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.") parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
args = parser.parse_args(sys.argv[1:]) args = parser.parse_args(sys.argv[1:])
@@ -56,13 +55,10 @@ def human_agent():
'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json' 'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
with open(example_path, "r", encoding="utf-8") as f: with open(example_path, "r", encoding="utf-8") as f:
example = json.load(f) example = json.load(f)
if args.snapshot is not None:
example['snapshot'] = args.snapshot
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist." assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
env = DesktopEnv( env = DesktopEnv(
path_to_vm=args.path, path_to_vm=args.path,
snapshot_name=args.snapshot,
action_space="computer_13" action_space="computer_13"
) )
# reset the environment to certain snapshot # reset the environment to certain snapshot

View File

@@ -1,8 +1,9 @@
import io
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Tuple, List
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from typing import Tuple, List
def find_leaf_nodes(xlm_file_str): def find_leaf_nodes(xlm_file_str):
if not xlm_file_str: if not xlm_file_str:
@@ -24,8 +25,11 @@ def find_leaf_nodes(xlm_file_str):
collect_leaf_nodes(root, leaf_nodes) collect_leaf_nodes(root, leaf_nodes)
return leaf_nodes return leaf_nodes
state_ns = "uri:deskat:state.at-spi.gnome.org" state_ns = "uri:deskat:state.at-spi.gnome.org"
component_ns = "uri:deskat:component.at-spi.gnome.org" component_ns = "uri:deskat:component.at-spi.gnome.org"
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool: def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
keeps: bool = node.tag.startswith("document") \ keeps: bool = node.tag.startswith("document") \
or node.tag.endswith("item") \ or node.tag.endswith("item") \
@@ -69,6 +73,7 @@ def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0 keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
return keeps return keeps
def filter_nodes(root: ET, platform="ubuntu", check_image=False): def filter_nodes(root: ET, platform="ubuntu", check_image=False):
filtered_nodes = [] filtered_nodes = []
@@ -80,9 +85,10 @@ def filter_nodes(root: ET, platform="ubuntu", check_image=False):
return filtered_nodes return filtered_nodes
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sampling_ratio=1.0): def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0):
# Load the screenshot image # Load the screenshot image
image = Image.open(image_file_path) image_stream = io.BytesIO(image_file_content)
image = Image.open(image_stream)
if float(down_sampling_ratio) != 1.0: if float(down_sampling_ratio) != 1.0:
image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio))) image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
@@ -176,26 +182,14 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
except ValueError: except ValueError:
pass pass
# Save the result output_image_stream = io.BytesIO()
image.save(output_image_file_path) image.save(output_image_stream, format='PNG')
return marks, drew_nodes, "\n".join(text_informations) image_content = output_image_stream.getvalue()
return marks, drew_nodes, "\n".join(text_informations), image_content
def print_nodes_with_indent(nodes, indent=0): def print_nodes_with_indent(nodes, indent=0):
for node in nodes: for node in nodes:
print(' ' * indent, node.tag, node.attrib) print(' ' * indent, node.tag, node.attrib)
print_nodes_with_indent(node, indent + 2) print_nodes_with_indent(node, indent + 2)
if __name__ == '__main__':
import json
with open('3.xml', 'r', encoding='utf-8') as f:
xml_file_str = f.read()
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
print(len(filtered_nodes))
masks = draw_bounding_boxes( filtered_nodes, '3.a.png'
, '3.png'
)
# print(masks)
print(len(masks))

View File

@@ -4,7 +4,6 @@ import logging
import os import os
import re import re
import time import time
import uuid
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
@@ -29,9 +28,8 @@ logger = logging.getLogger("desktopenv.agent")
# Function to encode the image # Function to encode the image
def encode_image(image_path): def encode_image(image_content):
with open(image_path, "rb") as image_file: return base64.b64encode(image_content).decode('utf-8')
return base64.b64encode(image_file.read()).decode('utf-8')
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"): def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
@@ -71,16 +69,11 @@ def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"): def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
# Creat a tmp file to store the screenshot in random name
uuid_str = str(uuid.uuid4())
os.makedirs("tmp/images", exist_ok=True)
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True) nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
# Make tag screenshot # Make tag screenshot
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) marks, drew_nodes, element_list, tagged_screenshot = draw_bounding_boxes(nodes, screenshot)
return marks, drew_nodes, tagged_screenshot_file_path, element_list return marks, drew_nodes, tagged_screenshot, element_list
def parse_actions_from_string(input_string): def parse_actions_from_string(input_string):

16
run.py
View File

@@ -50,12 +50,6 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment") logger = logging.getLogger("desktopenv.experiment")
# wandb config
### set your wandb api key here
# os.environ["WANDB_API_KEY"] = "48ec18fb4da7087238c6d6833eab9907565adbf3"
# wandb.login(key=os.environ.get("WANDB_API_KEY", None))
def config() -> argparse.Namespace: def config() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark" description="Run end-to-end evaluation on the benchmark"
@@ -153,9 +147,6 @@ def test(
for domain in tqdm(test_all_meta, desc="Domain"): for domain in tqdm(test_all_meta, desc="Domain"):
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
# run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
# name=f"{example_id}")
# example setting
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
with open(config_file, "r", encoding="utf-8") as f: with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f) example = json.load(f)
@@ -186,19 +177,12 @@ def test(
scores) scores)
except Exception as e: except Exception as e:
logger.error(f"Exception in {domain}/{example_id}: {e}") logger.error(f"Exception in {domain}/{example_id}: {e}")
# wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])})
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({ f.write(json.dumps({
"Error": f"Time limit exceeded in {domain}/{example_id}" "Error": f"Time limit exceeded in {domain}/{example_id}"
})) }))
f.write("\n") f.write("\n")
# wandb settings
# os.mkdir(os.path.join(wandb.run.dir, "results/"))
# for file in os.listdir(example_result_dir):
# # move file to just under the root dir
# os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}"))
# wandb.finish()
env.close() env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}") logger.info(f"Average score: {sum(scores) / len(scores)}")