Initialize evaluation protocols and examples; Implement one kind of eval; Update requirements

This commit is contained in:
Timothyxxx
2023-12-12 18:10:55 +08:00
parent 4b3af14b8f
commit 2ca36109b5
8 changed files with 139 additions and 50 deletions

View File

@@ -8,10 +8,11 @@ import platform
from typing import List
import gymnasium as gym
import requests
from desktop_env.controllers.python import PythonController
from desktop_env.controllers.setup import SetupController
from desktop_env.evaluators import eval_funcs
def _execute_command(command: List[str]) -> None:
@@ -32,7 +33,9 @@ class DesktopEnv(gym.Env):
self,
path_to_vm: str,
snapshot_path: str = "base",
instruction: str = None,
config: dict = None,
evaluator: dict = None,
action_space: str = "computer_13",
):
# Initialize environment variables
@@ -45,7 +48,9 @@ class DesktopEnv(gym.Env):
self.host = f"http://{self._get_vm_ip()}:5000"
self.controller = PythonController(http_server=self.host)
self.setup_controller = SetupController(http_server=self.host)
self.instruction = instruction
self.config = config
self.evaluator = evaluator
# mode: human or machine
assert action_space in ["computer_13", "pyautogui"]
@@ -113,6 +118,9 @@ class DesktopEnv(gym.Env):
print("Setting up environment...")
self.setup_controller.setup(self.config)
time.sleep(5)
print("Environment setup complete.")
observation = self._get_obs()
return observation
@@ -127,12 +135,52 @@ class DesktopEnv(gym.Env):
# todo: maybe for the better here we need to add a logic to wait until the rendering is done
time.sleep(pause)
observation = self._get_obs()
observation = {
"screenshot": self._get_obs(),
"instruction": self.instruction
}
reward = 0 # todo: Define reward calculation for each example
done = False # todo: Define episode termination condition for each example
info = {}
return observation, reward, done, info
def evaluate(self):
"""
Evaluate whether the task is successfully completed.
"""
def copy_file_to_local(_file_info):
random_uuid = str(uuid.uuid4())
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
if _file_info["type"] == "cloud_file":
url = _file_info["path"]
response = requests.get(url, stream=True)
response.raise_for_status()
with open(_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
elif _file_info["type"] == "vm_file":
# fixme: stream this part maybe as well
file = self.controller.get_file(_file_info["path"])
with open(_path, "wb") as f:
f.write(file)
else:
raise NotImplementedError
return _path
# todo: make this more flexible by refactoring
eval_func = eval_funcs[self.evaluator["func"]]
eval_func_vars = {}
for var_name, file_info in self.evaluator["paths"].items():
path = copy_file_to_local(file_info)
eval_func_vars[var_name] = path
return eval_func(**eval_func_vars)
def render(self, mode='rgb_array'):
if mode == 'rgb_array':
return self._get_obs()