1. fix quote and \ characters in execute_command ; 2. add terminal output text as extra observation ; 3. move get_vm_*() to reset()

This commit is contained in:
rhythmcao
2024-01-12 18:09:05 +08:00
parent 186df65683
commit d4116458ff
3 changed files with 56 additions and 8 deletions

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger("desktopenv.pycontroller")
class PythonController:
def __init__(self, vm_ip: str, pkgs_prefix: str = "python -c \"import pyautogui; {command}\""):
def __init__(self, vm_ip: str, pkgs_prefix: str = "import pyautogui; {command}"):
self.vm_ip = vm_ip
self.http_server = f"http://{vm_ip}:5000"
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
@@ -27,6 +27,16 @@ class PythonController:
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
return None
def get_terminal_output(self):
""" Gets the terminal output from the server. None -> no terminal output or unexpected error.
"""
response = requests.get(self.http_server + "/terminal")
if response.status_code == 200:
return response.json()["output"]
else:
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
return None
def get_accessibility_tree(self) -> Optional[str]:
response: requests.Response = requests.get(self.http_server + "/accessibility")
@@ -53,8 +63,8 @@ class PythonController:
Executes a python command on the server.
It can be used to execute the pyautogui commands, or... any other python command. who knows?
"""
command = self.pkgs_prefix.format(command=command)
payload = json.dumps({"command": command, "shell": True})
command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
payload = json.dumps({"command": command_list, "shell": False})
headers = {
'Content-Type': 'application/json'
}
@@ -186,7 +196,8 @@ class PythonController:
elif action_type == "TYPING":
if "text" not in parameters:
raise Exception(f"Unknown parameters: {parameters}")
text = parameters["text"]
# deal with special ' and \ characters
text = parameters["text"].replace("\\", "\\\\").replace("'", "\\'")
self.execute_python_command(f"pyautogui.typewrite('{text}')")
elif action_type == "PRESS":

View File

@@ -83,9 +83,9 @@ class DesktopEnv(gym.Env):
self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
# Meta info of the VM
self.vm_platform = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
# Meta info of the VM, move to the reset() function
self.vm_platform: str = "" # self.controller.get_vm_platform()
self.vm_screen_size = None # self.controller.get_vm_screen_size()
# mode: human or machine
assert action_space in ["computer_13", "pyautogui"]
@@ -188,6 +188,9 @@ class DesktopEnv(gym.Env):
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
time.sleep(5)
self.vm_platform = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
logger.info("Starting emulator...")
self._start_emulator()
logger.info("Emulator started.")
@@ -217,6 +220,7 @@ class DesktopEnv(gym.Env):
time.sleep(pause)
observation = {
"screenshot": self._get_obs(),
"terminal": self.controller.get_terminal_output(),
"instruction": self.instruction
}
reward = 0 # todo: Define reward calculation for each example

View File

@@ -14,7 +14,7 @@ from pyatspi import Value as ATValue
from pyatspi import Action as ATAction
from typing import List, Dict
from typing import Any
from typing import Any, Optional
import Xlib
import pyautogui
@@ -114,6 +114,39 @@ def capture_screen_with_cursor():
return send_file(file_path, mimetype='image/png')
def _has_active_terminal(desktop: Accessible) -> bool:
""" A quick check whether the terminal window is open and active.
"""
for app in desktop:
if app.getRoleName() == "application" and app.name == "gnome-terminal-server":
for frame in app:
if frame.getRoleName() == "frame" and frame.getState().contains(pyatspi.STATE_ACTIVE):
return True
return False
@app.route('/terminal', methods=['GET'])
def get_terminal_output():
user_platform = platform.system()
output: Optional[str] = None
try:
if user_platform == "Linux":
desktop: Accessible = pyatspi.Registry.getDesktop(0)
if _has_active_terminal(desktop):
desktop_xml: _Element = _create_node(desktop)
# 1. the terminal window (frame of application is st:active) is open and active
# 2. the terminal tab (terminal status is st:focused) is focused
xpath = '//application[@name="gnome-terminal-server"]/frame[@st:active="true"]//terminal[@st:focused="true"]'
terminals: List[_Element] = desktop_xml.xpath(xpath, namespaces=_accessibility_ns_map)
output = terminals[0].text.rstrip() if len(terminals) == 1 else None
else: # windows and macos platform is not implemented currently
raise NotImplementedError
return jsonify({"output": output, "status": "success"})
except:
return jsonify({"output": None, "status": "error"})
_accessibility_ns_map = { "st": "uri:deskat:state.at-spi.gnome.org"
, "attr": "uri:deskat:attributes.at-spi.gnome.org"
, "cp": "uri:deskat:component.at-spi.gnome.org"