1. Fix quote and \ characters in execute_command; 2. Add terminal output text as an extra observation; 3. Move get_vm_*() calls to reset()

This commit is contained in:
rhythmcao
2024-01-12 18:09:05 +08:00
parent 186df65683
commit d4116458ff
3 changed files with 56 additions and 8 deletions

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger("desktopenv.pycontroller")
class PythonController: class PythonController:
def __init__(self, vm_ip: str, pkgs_prefix: str = "python -c \"import pyautogui; {command}\""): def __init__(self, vm_ip: str, pkgs_prefix: str = "import pyautogui; {command}"):
self.vm_ip = vm_ip self.vm_ip = vm_ip
self.http_server = f"http://{vm_ip}:5000" self.http_server = f"http://{vm_ip}:5000"
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
@@ -27,6 +27,16 @@ class PythonController:
logger.error("Failed to get screenshot. Status code: %d", response.status_code) logger.error("Failed to get screenshot. Status code: %d", response.status_code)
return None return None
def get_terminal_output(self) -> Optional[str]:
    """Fetch the terminal text from the VM-side server.

    Sends a GET request to the ``/terminal`` endpoint of ``self.http_server``
    and returns the value stored under the ``"output"`` key of the JSON reply.

    Returns:
        The terminal text, or ``None`` when the server reports no terminal
        output, answers with a non-200 status code, or sends a body that
        cannot be decoded into the expected JSON object.
    """
    response = requests.get(self.http_server + "/terminal")
    if response.status_code != 200:
        logger.error("Failed to get terminal output. Status code: %d", response.status_code)
        return None
    try:
        return response.json()["output"]
    except (ValueError, KeyError, TypeError) as err:
        # A 200 reply with a non-JSON body, a missing "output" key, or a
        # non-object payload is an "unexpected error": report None as the
        # contract above promises instead of raising into the caller.
        logger.error("Failed to parse terminal output response: %s", err)
        return None
def get_accessibility_tree(self) -> Optional[str]: def get_accessibility_tree(self) -> Optional[str]:
response: requests.Response = requests.get(self.http_server + "/accessibility") response: requests.Response = requests.get(self.http_server + "/accessibility")
@@ -53,8 +63,8 @@ class PythonController:
Executes a python command on the server. Executes a python command on the server.
It can be used to execute the pyautogui commands, or... any other python command. who knows? It can be used to execute the pyautogui commands, or... any other python command. who knows?
""" """
command = self.pkgs_prefix.format(command=command) command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
payload = json.dumps({"command": command, "shell": True}) payload = json.dumps({"command": command_list, "shell": False})
headers = { headers = {
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
@@ -186,7 +196,8 @@ class PythonController:
elif action_type == "TYPING": elif action_type == "TYPING":
if "text" not in parameters: if "text" not in parameters:
raise Exception(f"Unknown parameters: {parameters}") raise Exception(f"Unknown parameters: {parameters}")
text = parameters["text"] # deal with special ' and \ characters
text = parameters["text"].replace("\\", "\\\\").replace("'", "\\'")
self.execute_python_command(f"pyautogui.typewrite('{text}')") self.execute_python_command(f"pyautogui.typewrite('{text}')")
elif action_type == "PRESS": elif action_type == "PRESS":

View File

@@ -83,9 +83,9 @@ class DesktopEnv(gym.Env):
self.controller = PythonController(vm_ip=self.vm_ip) self.controller = PythonController(vm_ip=self.vm_ip)
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir) self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir)
# Meta info of the VM # Meta info of the VM, move to the reset() function
self.vm_platform = self.controller.get_vm_platform() self.vm_platform: str = "" # self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size() self.vm_screen_size = None # self.controller.get_vm_screen_size()
# mode: human or machine # mode: human or machine
assert action_space in ["computer_13", "pyautogui"] assert action_space in ["computer_13", "pyautogui"]
@@ -188,6 +188,9 @@ class DesktopEnv(gym.Env):
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path]) _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_path])
time.sleep(5) time.sleep(5)
self.vm_platform = self.controller.get_vm_platform()
self.vm_screen_size = self.controller.get_vm_screen_size()
logger.info("Starting emulator...") logger.info("Starting emulator...")
self._start_emulator() self._start_emulator()
logger.info("Emulator started.") logger.info("Emulator started.")
@@ -217,6 +220,7 @@ class DesktopEnv(gym.Env):
time.sleep(pause) time.sleep(pause)
observation = { observation = {
"screenshot": self._get_obs(), "screenshot": self._get_obs(),
"terminal": self.controller.get_terminal_output(),
"instruction": self.instruction "instruction": self.instruction
} }
reward = 0 # todo: Define reward calculation for each example reward = 0 # todo: Define reward calculation for each example

View File

@@ -14,7 +14,7 @@ from pyatspi import Value as ATValue
from pyatspi import Action as ATAction from pyatspi import Action as ATAction
from typing import List, Dict from typing import List, Dict
from typing import Any from typing import Any, Optional
import Xlib import Xlib
import pyautogui import pyautogui
@@ -114,6 +114,39 @@ def capture_screen_with_cursor():
return send_file(file_path, mimetype='image/png') return send_file(file_path, mimetype='image/png')
def _has_active_terminal(desktop: Accessible) -> bool:
    """Report whether a gnome-terminal-server window is currently active.

    Iterates over the children of ``desktop`` looking for an entry whose role
    name is "application" and whose name is "gnome-terminal-server", then
    checks whether any of its "frame" children carries
    ``pyatspi.STATE_ACTIVE`` in its state set.

    Returns:
        ``True`` as soon as one such active frame is found, ``False`` otherwise.
    """
    for candidate in desktop:
        # Skip everything that is not the terminal server application.
        if candidate.getRoleName() != "application":
            continue
        if candidate.name != "gnome-terminal-server":
            continue
        if any(
            window.getRoleName() == "frame"
            and window.getState().contains(pyatspi.STATE_ACTIVE)
            for window in candidate
        ):
            return True
    return False
@app.route('/terminal', methods=['GET'])
def get_terminal_output():
    """Return the text of the focused terminal tab as a JSON response.

    The reply is ``{"output": <str or None>, "status": "success"}`` when the
    lookup ran without error, and ``{"output": None, "status": "error"}`` when
    the platform is not Linux or the lookup raised.

    ``output`` is the right-stripped text of the terminal element only when
    exactly one element matches the XPath below; otherwise it is ``None``.
    """
    if platform.system() != "Linux":
        # windows and macos platform is not implemented currently
        return jsonify({"output": None, "status": "error"})

    output: Optional[str] = None
    try:
        desktop: Accessible = pyatspi.Registry.getDesktop(0)
        if _has_active_terminal(desktop):
            desktop_xml: _Element = _create_node(desktop)
            # 1. the terminal window (frame of application is st:active) is open and active
            # 2. the terminal tab (terminal status is st:focused) is focused
            xpath = '//application[@name="gnome-terminal-server"]/frame[@st:active="true"]//terminal[@st:focused="true"]'
            terminals: List[_Element] = desktop_xml.xpath(xpath, namespaces=_accessibility_ns_map)
            if len(terminals) == 1:
                # An element without text yields None here instead of failing
                # on .rstrip() (assumes .text may be None — TODO confirm).
                text = terminals[0].text
                output = text.rstrip() if text is not None else None
    except Exception:
        # Keep the best-effort contract (always answer with JSON), but record
        # the failure and let SystemExit / KeyboardInterrupt propagate.
        app.logger.exception("Failed to read terminal output")
        return jsonify({"output": None, "status": "error"})
    return jsonify({"output": output, "status": "success"})
_accessibility_ns_map = { "st": "uri:deskat:state.at-spi.gnome.org" _accessibility_ns_map = { "st": "uri:deskat:state.at-spi.gnome.org"
, "attr": "uri:deskat:attributes.at-spi.gnome.org" , "attr": "uri:deskat:attributes.at-spi.gnome.org"
, "cp": "uri:deskat:component.at-spi.gnome.org" , "cp": "uri:deskat:component.at-spi.gnome.org"