Add Support for QWEN models from API (QWEN-max, etc.); Improve on the robustness of getting observation

This commit is contained in:
Timothyxxx
2024-05-20 00:47:43 +08:00
parent 25e808cc91
commit f9594e476e
3 changed files with 60 additions and 38 deletions

View File

@@ -11,12 +11,13 @@ logger = logging.getLogger("desktopenv.pycontroller")
class PythonController:
def __init__(self, vm_ip: str, pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
def __init__(self, vm_ip: str,
pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
self.vm_ip = vm_ip
self.http_server = f"http://{vm_ip}:5000"
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
def get_screenshot(self):
def get_screenshot(self, retry_times=20):
"""
Gets a screenshot from the server. With the cursor.
"""
@@ -24,25 +25,43 @@ class PythonController:
if response.status_code == 200:
return response.content
else:
for _ in range(retry_times):
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
logger.info("Retrying to get screenshot.")
response = requests.get(self.http_server + "/screenshot")
if response.status_code == 200:
return response.content
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
return None
def get_terminal_output(self):
def get_terminal_output(self, retry_times=20):
""" Gets the terminal output from the server. None -> no terminal output or unexpected error.
"""
response = requests.get(self.http_server + "/terminal")
if response.status_code == 200:
return response.json()["output"]
else:
for _ in range(retry_times):
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
logger.info("Retrying to get terminal output.")
response = requests.get(self.http_server + "/terminal")
if response.status_code == 200:
return response.json()["output"]
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
return None
def get_accessibility_tree(self) -> Optional[str]:
def get_accessibility_tree(self, retry_times=20) -> Optional[str]:
response: requests.Response = requests.get(self.http_server + "/accessibility")
if response.status_code == 200:
return response.json()["AT"]
else:
for _ in range(retry_times):
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
logger.info("Retrying to get accessibility tree.")
response = requests.get(self.http_server + "/accessibility")
if response.status_code == 200:
return response.json()["AT"]
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
return None

View File

@@ -138,24 +138,9 @@ class DesktopEnv(gym.Env):
def _save_state(self):
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
def _get_screenshot(self):
screenshot = None
# Get the screenshot and save to the image_path
max_retries = 20
for _ in range(max_retries):
screenshot = self.controller.get_screenshot()
if screenshot is not None:
break
time.sleep(1)
if screenshot is None:
logger.error("Failed to get screenshot!")
return screenshot
def _get_obs(self):
return {
"screenshot": self._get_screenshot(),
"screenshot": self.controller.get_screenshot(),
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
"instruction": self.instruction