ver Dec21stv2
updated usage of tmp and cache direcotories added cache function for evaluation resources acquiring
This commit is contained in:
@@ -8,13 +8,17 @@ class SetupController:
|
|||||||
def __init__(self, http_server: str):
|
def __init__(self, http_server: str):
|
||||||
self.http_server = http_server + "/setup"
|
self.http_server = http_server + "/setup"
|
||||||
|
|
||||||
def setup(self, config):
|
def setup(self, config: List[Dict[str, Any]]):
|
||||||
"""
|
"""
|
||||||
Setup Config:
|
Args:
|
||||||
{
|
config (List[Dict[str, Any]]): list of dict like {str: Any}. each
|
||||||
download: list[tuple[string]], # a list of tuples of url of file to download and the save path
|
config dict has the structure like
|
||||||
...
|
{
|
||||||
}
|
"type": str, corresponding to the `_{:}_setup` methods of
|
||||||
|
this class
|
||||||
|
"parameters": dick like {str, Any} providing the keyword
|
||||||
|
parameters
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
for cfg in config:
|
for cfg in config:
|
||||||
|
|||||||
@@ -3,9 +3,10 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import uuid
|
#import uuid
|
||||||
import platform
|
#import platform
|
||||||
from typing import List
|
from typing import List
|
||||||
|
import tempfile
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import requests
|
import requests
|
||||||
@@ -33,15 +34,48 @@ class DesktopEnv(gym.Env):
|
|||||||
self,
|
self,
|
||||||
path_to_vm: str,
|
path_to_vm: str,
|
||||||
snapshot_path: str = "base",
|
snapshot_path: str = "base",
|
||||||
|
task_id: str = "",
|
||||||
instruction: str = None,
|
instruction: str = None,
|
||||||
config: dict = None,
|
config: dict = None,
|
||||||
evaluator: dict = None,
|
evaluator: dict = None,
|
||||||
action_space: str = "computer_13",
|
action_space: str = "computer_13",
|
||||||
|
tmp_dir: str = "tmp",
|
||||||
|
cache_dir: str = "cache"
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
path_to_vm (str): path to .vmx file
|
||||||
|
snapshot_path (str): name of the based snapshot, can be found in
|
||||||
|
`snapshot*.displayName` field in .vmsd file
|
||||||
|
|
||||||
|
task_id (str): identifying the task
|
||||||
|
instruction (str): task instruction
|
||||||
|
config (List[Dict[str, Any]]): the config dict, refer to the doc of
|
||||||
|
SetupController
|
||||||
|
evaluator (Dict[str, Any]): defines the evaluation method. dict
|
||||||
|
like
|
||||||
|
{
|
||||||
|
"func": str as the evaluation method
|
||||||
|
"paths": paths to the involved files needed for evaluation
|
||||||
|
}
|
||||||
|
action_space (str): "computer_13" | "pyautogui"
|
||||||
|
|
||||||
|
tmp_dir (str): temporary directory to store trajectory stuffs like
|
||||||
|
the extracted screenshots
|
||||||
|
cache_dir (str): cache directory to cache task-related stuffs like
|
||||||
|
reference file for evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
# Initialize environment variables
|
# Initialize environment variables
|
||||||
self.path_to_vm = path_to_vm
|
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
||||||
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
|
||||||
|
|
||||||
|
self.task_id: str = task_id
|
||||||
|
self.tmp_dir_base: str = tmp_dir
|
||||||
|
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
|
||||||
|
self.cache_dir: str = os.path.join(cache_dir, task_id)
|
||||||
|
os.makedirs(self.cache_dir, exist_ok=True)
|
||||||
|
|
||||||
# Initialize emulator and controller
|
# Initialize emulator and controller
|
||||||
print("Initializing...")
|
print("Initializing...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
@@ -57,12 +91,18 @@ class DesktopEnv(gym.Env):
|
|||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
# todo: define the action space and the observation space as gym did, or extend theirs
|
# todo: define the action space and the observation space as gym did, or extend theirs
|
||||||
|
|
||||||
|
# counters
|
||||||
|
self._traj_no: int = -1
|
||||||
|
self._step_no: int = 0
|
||||||
|
|
||||||
def _start_emulator(self):
|
def _start_emulator(self):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
|
output = subprocess.check_output("vmrun -T ws list", shell=True, stderr=subprocess.STDOUT)
|
||||||
output = output.decode()
|
output = output.decode()
|
||||||
if self.path_to_vm.lstrip("~/") in output:
|
output: List[str] = output.splitlines()
|
||||||
|
#if self.path_to_vm.lstrip("~/") in output:
|
||||||
|
if self.path_to_vm in output:
|
||||||
print("VM is running.")
|
print("VM is running.")
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -89,9 +129,10 @@ class DesktopEnv(gym.Env):
|
|||||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_path])
|
||||||
|
|
||||||
def _get_screenshot(self):
|
def _get_screenshot(self):
|
||||||
random_uuid = str(uuid.uuid4())
|
#random_uuid = str(uuid.uuid4())
|
||||||
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||||
image_path = os.path.join("tmp", random_uuid, "screenshot.png")
|
#image_path = os.path.join("tmp", random_uuid, "screenshot.png")
|
||||||
|
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
|
||||||
|
|
||||||
# Get the screenshot and save to the image_path
|
# Get the screenshot and save to the image_path
|
||||||
screenshot = self.controller.get_screenshot()
|
screenshot = self.controller.get_screenshot()
|
||||||
@@ -107,6 +148,16 @@ class DesktopEnv(gym.Env):
|
|||||||
def reset(self, seed=None, options=None):
|
def reset(self, seed=None, options=None):
|
||||||
print("Resetting environment...")
|
print("Resetting environment...")
|
||||||
|
|
||||||
|
print("Setting counters...")
|
||||||
|
self._traj_no += 1
|
||||||
|
self._step_no = 0
|
||||||
|
|
||||||
|
print("Setup new temp dir...")
|
||||||
|
self.tmp_dir = tempfile.mkdtemp( prefix="{:d}.{:}.".format(self._traj_no, self.task_id)
|
||||||
|
, dir=self.tmp_dir_base
|
||||||
|
)
|
||||||
|
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
||||||
|
|
||||||
print("Reverting to snapshot to {}...".format(self.snapshot_path))
|
print("Reverting to snapshot to {}...".format(self.snapshot_path))
|
||||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
@@ -125,6 +176,8 @@ class DesktopEnv(gym.Env):
|
|||||||
return observation
|
return observation
|
||||||
|
|
||||||
def step(self, action, pause=0.5):
|
def step(self, action, pause=0.5):
|
||||||
|
self._step_no += 1
|
||||||
|
|
||||||
# fixme: add reminding logic here, decide if the action is valid for the current action_space
|
# fixme: add reminding logic here, decide if the action is valid for the current action_space
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
# the set of all possible actions defined in the action representation
|
# the set of all possible actions defined in the action representation
|
||||||
@@ -149,9 +202,13 @@ class DesktopEnv(gym.Env):
|
|||||||
Evaluate whether the task is successfully completed.
|
Evaluate whether the task is successfully completed.
|
||||||
"""
|
"""
|
||||||
def copy_file_to_local(_file_info):
|
def copy_file_to_local(_file_info):
|
||||||
random_uuid = str(uuid.uuid4())
|
#random_uuid = str(uuid.uuid4())
|
||||||
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
#os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||||
_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
|
#_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
|
||||||
|
_path = os.path.join(self.cache_dir, _file_info["dest"])
|
||||||
|
if os.path.exists(_path):
|
||||||
|
return _path
|
||||||
|
|
||||||
if _file_info["type"] == "cloud_file":
|
if _file_info["type"] == "cloud_file":
|
||||||
url = _file_info["path"]
|
url = _file_info["path"]
|
||||||
response = requests.get(url, stream=True, verify=False)
|
response = requests.get(url, stream=True, verify=False)
|
||||||
|
|||||||
@@ -31,11 +31,13 @@
|
|||||||
"paths": {
|
"paths": {
|
||||||
"expected": {
|
"expected": {
|
||||||
"type": "cloud_file",
|
"type": "cloud_file",
|
||||||
"path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx"
|
"path": "http://101.43.24.67/s/BAfFwa3689XTYoo/download/Quarterly_Product_Sales_by_Zone_gold.xlsx",
|
||||||
|
"dest": "Quarterly_Product_Sales_by_Zone_gold.xlsx"
|
||||||
},
|
},
|
||||||
"actual": {
|
"actual": {
|
||||||
"type": "vm_file",
|
"type": "vm_file",
|
||||||
"path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx"
|
"path": "/home/david/Quarterly_Product_Sales_by_Zone.xlsx",
|
||||||
|
"dest": "Quarterly_Product_Sales_by_Zone.xlsx"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
1
main.py
1
main.py
@@ -16,6 +16,7 @@ def human_agent():
|
|||||||
, action_space="computer_13"
|
, action_space="computer_13"
|
||||||
#, snapshot_path="base_setup"
|
#, snapshot_path="base_setup"
|
||||||
, snapshot_path="Init6"
|
, snapshot_path="Init6"
|
||||||
|
, task_id=example["id"]
|
||||||
, instruction=example["instruction"]
|
, instruction=example["instruction"]
|
||||||
, config=example["config"]
|
, config=example["config"]
|
||||||
, evaluator=example["evaluator"]
|
, evaluator=example["evaluator"]
|
||||||
|
|||||||
Reference in New Issue
Block a user