ver Dec21stv3

merged zdy to main
This commit is contained in:
David Chang
2023-12-21 16:21:21 +08:00
4 changed files with 80 additions and 16 deletions

View File

@@ -8,13 +8,17 @@ class SetupController:
def __init__(self, http_server: str):
self.http_server = http_server + "/setup"
def setup(self, config):
def setup(self, config: List[Dict[str, Any]]):
"""
Setup Config:
{
download: list[tuple[string]], # a list of tuples of url of file to download and the save path
...
}
Args:
config (List[Dict[str, Any]]): list of dict like {str: Any}. each
config dict has the structure like
{
"type": str, corresponding to the `_{:}_setup` methods of
this class
"parameters": dick like {str, Any} providing the keyword
parameters
}
"""
for cfg in config:

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
import os
import subprocess
import time
import uuid
import platform
#import uuid
#import platform
from typing import List
import tempfile
import gymnasium as gym
import requests
@@ -33,15 +34,48 @@ class DesktopEnv(gym.Env):
self,
path_to_vm: str,
snapshot_path: str = "base",
task_id: str = "",
instruction: str = None,
config: dict = None,
evaluator: dict = None,
action_space: str = "computer_13",
tmp_dir: str = "tmp",
cache_dir: str = "cache"
):
"""
Args:
path_to_vm (str): path to .vmx file
snapshot_path (str): name of the based snapshot, can be found in
`snapshot*.displayName` field in .vmsd file
task_id (str): identifying the task
instruction (str): task instruction
config (List[Dict[str, Any]]): the config dict, refer to the doc of
SetupController
evaluator (Dict[str, Any]): defines the evaluation method. dict
like
{
"func": str as the evaluation method
"paths": paths to the involved files needed for evaluation
}
action_space (str): "computer_13" | "pyautogui"
tmp_dir (str): temporary directory to store trajectory stuffs like
the extracted screenshots
cache_dir (str): cache directory to cache task-related stuffs like
reference file for evaluation
"""
# Initialize environment variables
self.path_to_vm = path_to_vm
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
self.task_id: str = task_id
self.tmp_dir_base: str = tmp_dir
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
self.cache_dir: str = os.path.join(cache_dir, task_id)
os.makedirs(self.cache_dir, exist_ok=True)
# Initialize emulator and controller
print("Initializing...")
self._start_emulator()
@@ -57,12 +91,18 @@ class DesktopEnv(gym.Env):
self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs
# counters
self._traj_no: int = -1
self._step_no: int = 0
def _start_emulator(self):
while True:
try:
output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
output = output.decode()
if self.path_to_vm.lstrip("~/") in output:
output: List[str] = output.splitlines()
#if self.path_to_vm.lstrip("~/") in output:
if self.path_to_vm in output:
print("VM is running.")
break
else:
@@ -89,9 +129,10 @@ class DesktopEnv(gym.Env):
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
def _get_screenshot(self):
random_uuid = str(uuid.uuid4())
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
#random_uuid = str(uuid.uuid4())
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
#image_path = os.path.join("tmp", random_uuid, "screenshot.png")
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
# Get the screenshot and save to the image_path
screenshot = self.controller.get_screenshot()
@@ -107,6 +148,16 @@ class DesktopEnv(gym.Env):
def reset(self, seed=None, options=None):
print("Resetting environment...")
print("Setting counters...")
self._traj_no += 1
self._step_no = 0
print("Setup new temp dir...")
self.tmp_dir = tempfile.mkdtemp( prefix="{:d}.{:}.".format(self._traj_no, self.task_id)
, dir=self.tmp_dir_base
)
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
print("Reverting to snapshot to {}...".format(self.snapshot_path))
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
time.sleep(5)
@@ -125,6 +176,8 @@ class DesktopEnv(gym.Env):
return observation
def step(self, action, pause=0.5):
self._step_no += 1
# fixme: add reminding logic here, decide if the action is valid for the current action_space
if self.action_space == "computer_13":
# the set of all possible actions defined in the action representation
@@ -149,9 +202,13 @@ class DesktopEnv(gym.Env):
Evaluate whether the task is successfully completed.
"""
def copy_file_to_local(_file_info):
random_uuid = str(uuid.uuid4())
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
#random_uuid = str(uuid.uuid4())
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
#_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
_path = os.path.join(self.cache_dir, _file_info["dest"])
if os.path.exists(_path):
return _path
if _file_info["type"] == "cloud_file":
url = _file_info["path"]
response = requests.get(url, stream=True)

View File

@@ -32,10 +32,12 @@
"expected": {
"type": "cloud_file",
"path": "https://drive.usercontent.google.com/download?id=17f1wZuJPvUEc5at_Fy3c18VFdOk0x7xz&export=download&authuser=0&confirm=t&uuid=6d2edffd-0ce0-426e-9820-8af25b4667f3&at=APZUnTVh7JS85dwZBaV2hytWQgDK:1702361510956"
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
},
"actual": {
"type": "vm_file",
"path": "C:\\Users\\tianbaox\\Desktop\\Quarterly_Product_Sales_by_Zone.xlsx"
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
}
}
}

View File

@@ -13,6 +13,7 @@ def human_agent():
env = DesktopEnv( path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
, action_space="computer_13"
, snapshot_path="base_setup"
, task_id=example["id"]
, instruction=example["instruction"]
, config=example["config"]
, evaluator=example["evaluator"]