Initialize evaluation protocols and examples; Implement one kind of eval; Update requirements
This commit is contained in:
@@ -20,6 +20,18 @@ class PythonController:
|
||||
print("Failed to get screenshot. Status code:", response.status_code)
|
||||
return None
|
||||
|
||||
def get_file(self, file_path: str):
|
||||
"""
|
||||
Gets a file from the server.
|
||||
"""
|
||||
response = requests.post(self.http_server + "/file", data={"file_path": file_path})
|
||||
if response.status_code == 200:
|
||||
print("File downloaded successfully")
|
||||
return response.content
|
||||
else:
|
||||
print("Failed to get file. Status code:", response.status_code)
|
||||
return None
|
||||
|
||||
def execute_python_command(self, command: str) -> None:
|
||||
"""
|
||||
Executes a python command on the server.
|
||||
|
||||
@@ -8,10 +8,11 @@ import platform
|
||||
from typing import List
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
import requests
|
||||
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
from desktop_env.evaluators import eval_funcs
|
||||
|
||||
|
||||
def _execute_command(command: List[str]) -> None:
|
||||
@@ -32,7 +33,9 @@ class DesktopEnv(gym.Env):
|
||||
self,
|
||||
path_to_vm: str,
|
||||
snapshot_path: str = "base",
|
||||
instruction: str = None,
|
||||
config: dict = None,
|
||||
evaluator: dict = None,
|
||||
action_space: str = "computer_13",
|
||||
):
|
||||
# Initialize environment variables
|
||||
@@ -45,7 +48,9 @@ class DesktopEnv(gym.Env):
|
||||
self.host = f"http://{self._get_vm_ip()}:5000"
|
||||
self.controller = PythonController(http_server=self.host)
|
||||
self.setup_controller = SetupController(http_server=self.host)
|
||||
self.instruction = instruction
|
||||
self.config = config
|
||||
self.evaluator = evaluator
|
||||
|
||||
# mode: human or machine
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
@@ -113,6 +118,9 @@ class DesktopEnv(gym.Env):
|
||||
print("Setting up environment...")
|
||||
self.setup_controller.setup(self.config)
|
||||
|
||||
time.sleep(5)
|
||||
print("Environment setup complete.")
|
||||
|
||||
observation = self._get_obs()
|
||||
return observation
|
||||
|
||||
@@ -127,12 +135,52 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
# todo: maybe for the better here we need to add a logic to wait until the rendering is done
|
||||
time.sleep(pause)
|
||||
observation = self._get_obs()
|
||||
observation = {
|
||||
"screenshot": self._get_obs(),
|
||||
"instruction": self.instruction
|
||||
}
|
||||
reward = 0 # todo: Define reward calculation for each example
|
||||
done = False # todo: Define episode termination condition for each example
|
||||
info = {}
|
||||
return observation, reward, done, info
|
||||
|
||||
def evaluate(self):
|
||||
"""
|
||||
Evaluate whether the task is successfully completed.
|
||||
"""
|
||||
def copy_file_to_local(_file_info):
|
||||
random_uuid = str(uuid.uuid4())
|
||||
os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||
_path = os.path.join("tmp", random_uuid, "tmp.xlsx")
|
||||
if _file_info["type"] == "cloud_file":
|
||||
url = _file_info["path"]
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
elif _file_info["type"] == "vm_file":
|
||||
# fixme: stream this part maybe as well
|
||||
file = self.controller.get_file(_file_info["path"])
|
||||
with open(_path, "wb") as f:
|
||||
f.write(file)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return _path
|
||||
|
||||
# todo: make this more flexible by refactoring
|
||||
eval_func = eval_funcs[self.evaluator["func"]]
|
||||
eval_func_vars = {}
|
||||
|
||||
for var_name, file_info in self.evaluator["paths"].items():
|
||||
path = copy_file_to_local(file_info)
|
||||
eval_func_vars[var_name] = path
|
||||
|
||||
return eval_func(**eval_func_vars)
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
if mode == 'rgb_array':
|
||||
return self._get_obs()
|
||||
|
||||
@@ -74,6 +74,22 @@ def capture_screen_with_cursor():
|
||||
return send_file(file_path, mimetype='image/png')
|
||||
|
||||
|
||||
@app.route('/file', methods=['POST'])
|
||||
def get_file():
|
||||
# Retrieve filename from the POST request
|
||||
if 'file_path' in request.form:
|
||||
file_path = request.form['file_path']
|
||||
else:
|
||||
return jsonify({"error": "file_path is required"}), 400
|
||||
|
||||
try:
|
||||
# Check if the file exists and send it to the user
|
||||
return send_file(file_path, as_attachment=True)
|
||||
except FileNotFoundError:
|
||||
# If the file is not found, return a 404 error
|
||||
return jsonify({"error": "File not found"}), 404
|
||||
|
||||
|
||||
@app.route('/platform', methods=['GET'])
|
||||
def get_platform():
|
||||
return platform.system()
|
||||
|
||||
Reference in New Issue
Block a user