ver Dec22ndv3

Added the ability to switch tasks and reload the setup and
evaluation configs.
This commit is contained in:
David Chang
2023-12-22 15:20:27 +08:00
parent 10c8fbe995
commit fffc4aadca
2 changed files with 43 additions and 47 deletions

View File

@@ -6,7 +6,7 @@ import time
#import uuid #import uuid
#import platform #import platform
from typing import List, Dict from typing import List, Dict
from typing import Callable, Any from typing import Callable, Any, Optional
import tempfile import tempfile
import gymnasium as gym import gymnasium as gym
@@ -37,40 +37,24 @@ class DesktopEnv(gym.Env):
def __init__( def __init__(
self, self,
path_to_vm: str, path_to_vm: str,
snapshot_path: str = "base",
task_id: str = "",
instruction: str = None,
config: dict = None,
evaluator: dict = None,
action_space: str = "computer_13", action_space: str = "computer_13",
task_config: Dict[str, Any] = None,
tmp_dir: str = "tmp", tmp_dir: str = "tmp",
cache_dir: str = "cache" cache_dir: str = "cache"
): ):
""" """
Args: Args:
path_to_vm (str): path to .vmx file path_to_vm (str): path to .vmx file
snapshot_path (str): name of the based snapshot, can be found in
`snapshot*.displayName` field in .vmsd file
task_id (str): identifying the task
instruction (str): task instruction
config (List[Dict[str, Any]]): the config dict, refer to the doc of
SetupController
evaluator (Dict[str, Any]): defines the evaluation method. dict
like
{
"func": str as the evaluation method
"result": {
"type": str
...other keys
}
"condition": {
"type": str
...other keys
}
}
action_space (str): "computer_13" | "pyautogui" action_space (str): "computer_13" | "pyautogui"
task_config (Dict[str, Any]): manages task configs integratedly,
including
* base snapshot
* task id (uuid)
* instruction
* setup config
* evaluator config
tmp_dir (str): temporary directory to store trajectory stuffs like tmp_dir (str): temporary directory to store trajectory stuffs like
the extracted screenshots the extracted screenshots
cache_dir (str): cache directory to cache task-related stuffs like cache_dir (str): cache directory to cache task-related stuffs like
@@ -79,13 +63,8 @@ class DesktopEnv(gym.Env):
# Initialize environment variables # Initialize environment variables
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
self.snapshot_path = snapshot_path # todo: handling the logic of snapshot directory
self.task_id: str = task_id
self.tmp_dir_base: str = tmp_dir self.tmp_dir_base: str = tmp_dir
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset self.cache_dir_base: str = cache_dir
self.cache_dir: str = os.path.join(cache_dir, task_id)
os.makedirs(self.cache_dir, exist_ok=True)
# Initialize emulator and controller # Initialize emulator and controller
print("Initializing...") print("Initializing...")
@@ -93,20 +72,27 @@ class DesktopEnv(gym.Env):
self.host = f"http://{self._get_vm_ip()}:5000" self.host = f"http://{self._get_vm_ip()}:5000"
self.controller = PythonController(http_server=self.host) self.controller = PythonController(http_server=self.host)
self.setup_controller = SetupController(http_server=self.host) self.setup_controller = SetupController(http_server=self.host)
self.instruction = instruction
self.config = config
self.evaluator = evaluator
self.metric: Metric = getattr(metrics, evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(evaluator["expected"]["type"]))
# mode: human or machine # mode: human or machine
assert action_space in ["computer_13", "pyautogui"] assert action_space in ["computer_13", "pyautogui"]
self.action_space = action_space self.action_space = action_space
# todo: define the action space and the observation space as gym did, or extend theirs # todo: define the action space and the observation space as gym did, or extend theirs
# counters # task-aware stuffs
self.snapshot_path = task_config["snapshot"] # todo: handling the logic of snapshot directory
self.task_id: str = task_config["id"]
self.cache_dir: str = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"]
self.evaluator = task_config["evaluator"]
self.metric: Metric = getattr(metrics, self.evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
# episodic stuffs, like tmp dir and counters
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
self._traj_no: int = -1 self._traj_no: int = -1
self._step_no: int = 0 self._step_no: int = 0
@@ -160,9 +146,23 @@ class DesktopEnv(gym.Env):
screenshot_image_path = self._get_screenshot() screenshot_image_path = self._get_screenshot()
return screenshot_image_path return screenshot_image_path
def reset(self, seed=None, options=None): def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None):
print("Resetting environment...") print("Resetting environment...")
print("Switching task...")
if task_config is not None:
self.snapshot_path = task_config["snapshot"]
self.task_id = task_config["id"]
self.cache_dir = os.path.join(self.cache_dir_base, self.task_id)
os.makedirs(self.cache_dir, exist_ok=True)
self.instruction = task_config["instruction"]
self.config = task_config["config"]
self.evaluator = task_config["evaluator"]
self.metric: Metric = getattr(metrics, self.evaluator["func"])
self.result_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["result"]["type"]))
self.expected_getter: Getter = getattr(getters, "get_{:}".format(self.evaluator["expected"]["type"]))
print("Setting counters...") print("Setting counters...")
self._traj_no += 1 self._traj_no += 1
self._step_no = 0 self._step_no = 0

View File

@@ -9,17 +9,13 @@ def human_agent():
with open("evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json", "r") as f: with open("evaluation_examples/examples/f9584479-3d0d-4c79-affa-9ad7afdd8850.json", "r") as f:
example = json.load(f) example = json.load(f)
example["snapshot"] = "Init6"
#env = DesktopEnv( path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx" #env = DesktopEnv( path_to_vm="/home/yuri/vmware/Windows 10 x64/Windows 10 x64.vmx"
# path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx", # path_to_vm="/home/yuri/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
env = DesktopEnv( path_to_vm="/home/david/vmware/KUbuntu 64-bit/KUbuntu 64-bit.vmx" env = DesktopEnv( path_to_vm="/home/david/vmware/KUbuntu 64-bit/KUbuntu 64-bit.vmx"
, action_space="computer_13" , action_space="computer_13"
#, snapshot_path="base_setup" , task_config=example
, snapshot_path="Init6"
, task_id=example["id"]
, instruction=example["instruction"]
, config=example["config"]
, evaluator=example["evaluator"]
) )
# reset the environment to certain snapshot # reset the environment to certain snapshot