Add Support for QWEN VL models from API (QWEN-VL-max, etc.); Improve on the robustness of getting observation/files, etc.

This commit is contained in:
Timothyxxx
2024-05-21 21:08:22 +08:00
parent f9594e476e
commit 306dcbda71
2 changed files with 251 additions and 127 deletions

View File

@@ -2,7 +2,7 @@ import json
import logging
import random
from typing import Any, Dict, Optional
import time
import requests
from desktop_env.envs.actions import KEYBOARD_KEYS
@@ -16,66 +16,96 @@ class PythonController:
self.vm_ip = vm_ip
self.http_server = f"http://{vm_ip}:5000"
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
self.retry_times = 3
self.retry_interval = 5
def get_screenshot(self, retry_times=20):
def get_screenshot(self) -> Optional[bytes]:
"""
Gets a screenshot from the server. With the cursor.
Gets a screenshot from the server. With the cursor. None -> no screenshot or unexpected error.
"""
response = requests.get(self.http_server + "/screenshot")
if response.status_code == 200:
return response.content
else:
for _ in range(retry_times):
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
logger.info("Retrying to get screenshot.")
for _ in range(self.retry_times):
try:
response = requests.get(self.http_server + "/screenshot")
if response.status_code == 200:
logger.info("Got screenshot successfully")
return response.content
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
return None
else:
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
logger.info("Retrying to get screenshot.")
except Exception as e:
logger.error("An error occurred while trying to get the screenshot: %s", e)
logger.info("Retrying to get screenshot.")
time.sleep(self.retry_interval)
def get_terminal_output(self, retry_times=20):
""" Gets the terminal output from the server. None -> no terminal output or unexpected error.
logger.error("Failed to get screenshot.")
return None
def get_accessibility_tree(self) -> Optional[str]:
"""
response = requests.get(self.http_server + "/terminal")
if response.status_code == 200:
return response.json()["output"]
else:
for _ in range(retry_times):
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
logger.info("Retrying to get terminal output.")
Gets the accessibility tree from the server. None -> no accessibility tree or unexpected error.
"""
for _ in range(self.retry_times):
try:
response: requests.Response = requests.get(self.http_server + "/accessibility")
if response.status_code == 200:
logger.info("Got accessibility tree successfully")
return response.json()["AT"]
else:
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
logger.info("Retrying to get accessibility tree.")
except Exception as e:
logger.error("An error occurred while trying to get the accessibility tree: %s", e)
logger.info("Retrying to get accessibility tree.")
time.sleep(self.retry_interval)
logger.error("Failed to get accessibility tree.")
return None
def get_terminal_output(self) -> Optional[str]:
"""
Gets the terminal output from the server. None -> no terminal output or unexpected error.
"""
for _ in range(self.retry_times):
try:
response = requests.get(self.http_server + "/terminal")
if response.status_code == 200:
logger.info("Got terminal output successfully")
return response.json()["output"]
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
return None
else:
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
logger.info("Retrying to get terminal output.")
except Exception as e:
logger.error("An error occurred while trying to get the terminal output: %s", e)
logger.info("Retrying to get terminal output.")
time.sleep(self.retry_interval)
def get_accessibility_tree(self, retry_times=20) -> Optional[str]:
logger.error("Failed to get terminal output.")
return None
response: requests.Response = requests.get(self.http_server + "/accessibility")
if response.status_code == 200:
return response.json()["AT"]
else:
for _ in range(retry_times):
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
logger.info("Retrying to get accessibility tree.")
response = requests.get(self.http_server + "/accessibility")
if response.status_code == 200:
return response.json()["AT"]
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
return None
def get_file(self, file_path: str):
def get_file(self, file_path: str) -> Optional[bytes]:
"""
Gets a file from the server.
"""
response = requests.post(self.http_server + "/file", data={"file_path": file_path})
if response.status_code == 200:
logger.info("File downloaded successfully")
return response.content
else:
logger.error("Failed to get file. Status code: %d", response.status_code)
return None
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/file", data={"file_path": file_path})
if response.status_code == 200:
logger.info("File downloaded successfully")
return response.content
else:
logger.error("Failed to get file. Status code: %d", response.status_code)
logger.info("Retrying to get file.")
except Exception as e:
logger.error("An error occurred while trying to get the file: %s", e)
logger.info("Retrying to get file.")
time.sleep(self.retry_interval)
logger.error("Failed to get file.")
return None
def execute_python_command(self, command: str) -> None:
"""
@@ -85,19 +115,26 @@ class PythonController:
# command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
payload = json.dumps({"command": command_list, "shell": False})
headers = {
'Content-Type': 'application/json'
}
try:
response = requests.post(self.http_server + "/execute", headers=headers, data=payload, timeout=90)
if response.status_code == 200:
logger.info("Command executed successfully: %s", response.text)
else:
logger.error("Failed to execute command. Status code: %d", response.status_code)
return response.json()
except requests.exceptions.RequestException as e:
logger.error("An error occurred while trying to execute the command: %s", e)
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/execute", headers={'Content-Type': 'application/json'},
data=payload, timeout=90)
if response.status_code == 200:
logger.info("Command executed successfully: %s", response.text)
return response.json()
else:
logger.error("Failed to execute command. Status code: %d", response.status_code)
logger.info("Retrying to execute command.")
except requests.exceptions.ReadTimeout:
break
except Exception as e:
logger.error("An error occurred while trying to execute the command: %s", e)
logger.info("Retrying to execute command.")
time.sleep(self.retry_interval)
logger.error("Failed to execute command.")
return None
def execute_action(self, action: Dict[str, Any]):
"""
@@ -272,29 +309,47 @@ class PythonController:
"""
Starts recording the screen.
"""
response = requests.post(self.http_server + "/start_recording")
if response.status_code == 200:
logger.info("Recording started successfully")
else:
logger.error("Failed to start recording. Status code: %d", response.status_code)
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/start_recording")
if response.status_code == 200:
logger.info("Recording started successfully")
return
else:
logger.error("Failed to start recording. Status code: %d", response.status_code)
logger.info("Retrying to start recording.")
except Exception as e:
logger.error("An error occurred while trying to start recording: %s", e)
logger.info("Retrying to start recording.")
time.sleep(self.retry_interval)
logger.error("Failed to start recording.")
def end_recording(self, dest: str):
"""
Ends recording the screen.
"""
try:
response = requests.post(self.http_server + "/end_recording")
if response.status_code == 200:
logger.info("Recording stopped successfully")
with open(dest, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
else:
logger.error("Failed to stop recording. Status code: %d", response.status_code)
return None
except Exception as e:
logger.error("An error occurred while trying to download the recording: %s", e)
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/end_recording")
if response.status_code == 200:
logger.info("Recording stopped successfully")
with open(dest, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
return
else:
logger.error("Failed to stop recording. Status code: %d", response.status_code)
logger.info("Retrying to stop recording.")
except Exception as e:
logger.error("An error occurred while trying to stop recording: %s", e)
logger.info("Retrying to stop recording.")
time.sleep(self.retry_interval)
logger.error("Failed to stop recording.")
# Additional info
def get_vm_platform(self):
@@ -307,60 +362,109 @@ class PythonController:
"""
Gets the size of the vm screen.
"""
response = requests.post(self.http_server + "/screen_size")
if response.status_code == 200:
return response.json()
else:
logger.error("Failed to get screen size. Status code: %d", response.status_code)
return None
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/screen_size")
if response.status_code == 200:
logger.info("Got screen size successfully")
return response.json()
else:
logger.error("Failed to get screen size. Status code: %d", response.status_code)
logger.info("Retrying to get screen size.")
except Exception as e:
logger.error("An error occurred while trying to get the screen size: %s", e)
logger.info("Retrying to get screen size.")
time.sleep(self.retry_interval)
logger.error("Failed to get screen size.")
return None
def get_vm_window_size(self, app_class_name: str):
"""
Gets the size of the vm app window.
"""
response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name})
if response.status_code == 200:
return response.json()
else:
logger.error("Failed to get window size. Status code: %d", response.status_code)
return None
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name})
if response.status_code == 200:
logger.info("Got window size successfully")
return response.json()
else:
logger.error("Failed to get window size. Status code: %d", response.status_code)
logger.info("Retrying to get window size.")
except Exception as e:
logger.error("An error occurred while trying to get the window size: %s", e)
logger.info("Retrying to get window size.")
time.sleep(self.retry_interval)
logger.error("Failed to get window size.")
return None
def get_vm_wallpaper(self):
"""
Gets the wallpaper of the vm.
"""
response = requests.post(self.http_server + "/wallpaper")
if response.status_code == 200:
logger.info("Wallpaper downloaded successfully")
return response.content
else:
logger.error("Failed to get wallpaper. Status code: %d", response.status_code)
return None
def get_vm_desktop_path(self):
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/wallpaper")
if response.status_code == 200:
logger.info("Got wallpaper successfully")
return response.content
else:
logger.error("Failed to get wallpaper. Status code: %d", response.status_code)
logger.info("Retrying to get wallpaper.")
except Exception as e:
logger.error("An error occurred while trying to get the wallpaper: %s", e)
logger.info("Retrying to get wallpaper.")
time.sleep(self.retry_interval)
logger.error("Failed to get wallpaper.")
return None
def get_vm_desktop_path(self) -> Optional[str]:
"""
Gets the desktop path of the vm.
"""
response = requests.post(self.http_server + "/desktop_path")
if response.status_code == 200:
logger.info("Desktop path downloaded successfully")
return response.json()["desktop_path"]
else:
logger.error("Failed to get desktop path. Status code: %d", response.status_code)
return None
def get_vm_directory_tree(self, path):
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/desktop_path")
if response.status_code == 200:
logger.info("Got desktop path successfully")
return response.json()["desktop_path"]
else:
logger.error("Failed to get desktop path. Status code: %d", response.status_code)
logger.info("Retrying to get desktop path.")
except Exception as e:
logger.error("An error occurred while trying to get the desktop path: %s", e)
logger.info("Retrying to get desktop path.")
time.sleep(self.retry_interval)
logger.error("Failed to get desktop path.")
return None
def get_vm_directory_tree(self, path) -> Optional[Dict[str, Any]]:
"""
Gets the directory tree of the vm.
"""
payload = json.dumps({"path": path})
headers = {
'Content-Type': 'application/json'
}
response = requests.post(self.http_server + "/list_directory", headers=headers, data=payload)
if response.status_code == 200:
logger.info("Directory tree downloaded successfully")
return response.json()["directory_tree"]
else:
logger.error("Failed to get directory tree. Status code: %d", response.status_code)
return None
for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/list_directory", headers={'Content-Type': 'application/json'}, data=payload)
if response.status_code == 200:
logger.info("Got directory tree successfully")
return response.json()["directory_tree"]
else:
logger.error("Failed to get directory tree. Status code: %d", response.status_code)
logger.info("Retrying to get directory tree.")
except Exception as e:
logger.error("An error occurred while trying to get directory tree: %s", e)
logger.info("Retrying to get directory tree.")
time.sleep(self.retry_interval)
logger.error("Failed to get directory tree.")
return None