Merge branch 'zdy'

This commit is contained in:
David Chang
2023-12-22 15:23:56 +08:00
2 changed files with 43 additions and 46 deletions

View File

@@ -6,7 +6,7 @@ import time
#import uuid
#import platform
from typing import List, Dict
from typing import Callable, Any
from typing import Callable, Any, Optional
import tempfile
import gymnasium as gym
@@ -37,40 +37,24 @@ class DesktopEnv(gym.Env):
def __init__(
self,
path_to_vm: str,
snapshot_path: str = "base",
task_id: str = "",
instruction: str = None,
config: dict = None,
evaluator: dict = None,
action_space: str = "computer_13",
task_config: Dict[str, Any] = None,
tmp_dir: str = "tmp",
cache_dir: str = "cache"
):
"""
Args:
path_to_vm (str): path to .vmx file
snapshot_path (str): name of the base snapshot; it can be found in the
`snapshot*.displayName` field of the .vmsd file
task_id (str): identifying the task
instruction (str): task instruction
config (List[Dict[str, Any]]): the config dict, refer to the doc of
SetupController
evaluator (Dict[str, Any]): defines the evaluation method. dict
like
{
"func": str as the evaluation method
"result": {
"type": str
...other keys
}
"condition": {
"type": str
...other keys
}
}
action_space (str): "computer_13" | "pyautogui"
task_config (Dict[str, Any]): bundles all task-specific settings in a
single dict, including
* base snapshot
* task id (uuid)
* instruction
* setup config
* evaluator config
tmp_dir (str): temporary directory for storing trajectory artifacts such as
the extracted screenshots
cache_dir (str): cache directory to cache task-related stuffs like
@@ -79,13 +63,8 @@ class DesktopEnv(gym.Env):
# Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
self.task_id: str = task_id
self.tmp_dir_base: str = tmp_dir
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
self.cache_dir: str = os.path.join(cache_dir, task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.cache_dir_base: str = cache_dir
# Initialize emulator and controller
print("Initializing...")
@@ -93,20 +72,27 @@ class DesktopEnv(gym.Env):
self.host = f"http://{self._get_vm_ip()}:5000"
self.controller = PythonController(http_server=self.host)
self.setup_controller = SetupController(http_server=self.host)
self.instruction = instruction
self.config = config
self.evaluator = evaluator
self.metric: Metric = getattr(metrics, evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"]))
# mode: human or machine
assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs
# counters
# task-specific settings
self.snapshot_path = task_config["snapshot"] # todo: handling the logic of snapshot directory
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"]
self.evaluator = task_config["evaluator"]
self.metric: Metric = getattr(metrics, self.evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
# per-episode state, such as the tmp dir and counters
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
self._traj_no: int = -1
self._step_no: int = 0
@@ -160,9 +146,23 @@ class DesktopEnv(gym.Env):
screenshot_image_path = self._get_screenshot()
return screenshot_image_path
def reset(self, seed=None, options=None):
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None):
print("Resetting environment...")
print("Switching task...")
if task_config is not None:
self.snapshot_path = task_config["snapshot"]
self.task_id = task_config["id"]
self.cache_dir = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"]
self.evaluator = task_config["evaluator"]
self.metric: Metric = getattr(metrics, self.evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
print("Setting counters...")
self._traj_no += 1
self._step_no = 0

View File

@@ -9,14 +9,11 @@ def human_agent():
with open("evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json", "r") as f:
example = json.load(f)
example["snapshot"] = "base_setup"
env = DesktopEnv( path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx"
, action_space="computer_13"
, snapshot_path="base_setup"
, task_id=example["id"]
, instruction=example["instruction"]
, config=example["config"]
, evaluator=example["evaluator"]
, task_config=example
)
# reset the environment to certain snapshot