Add Support for QWEN VL models from API (QWEN-VL-max, etc.); Improve on the robustness of getting observation/files, etc.

This commit is contained in:
Timothyxxx
2024-05-21 21:08:22 +08:00
parent f9594e476e
commit 306dcbda71
2 changed files with 251 additions and 127 deletions

View File

@@ -2,7 +2,7 @@ import json
import logging import logging
import random import random
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import time
import requests import requests
from desktop_env.envs.actions import KEYBOARD_KEYS from desktop_env.envs.actions import KEYBOARD_KEYS
@@ -16,66 +16,96 @@ class PythonController:
self.vm_ip = vm_ip self.vm_ip = vm_ip
self.http_server = f"http://{vm_ip}:5000" self.http_server = f"http://{vm_ip}:5000"
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
self.retry_times = 3
self.retry_interval = 5
def get_screenshot(self, retry_times=20): def get_screenshot(self) -> Optional[bytes]:
""" """
Gets a screenshot from the server. With the cursor. Gets a screenshot from the server. With the cursor. None -> no screenshot or unexpected error.
""" """
response = requests.get(self.http_server + "/screenshot")
if response.status_code == 200: for _ in range(self.retry_times):
return response.content try:
else:
for _ in range(retry_times):
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
logger.info("Retrying to get screenshot.")
response = requests.get(self.http_server + "/screenshot") response = requests.get(self.http_server + "/screenshot")
if response.status_code == 200: if response.status_code == 200:
logger.info("Got screenshot successfully")
return response.content return response.content
logger.error("Failed to get screenshot. Status code: %d", response.status_code) else:
return None logger.error("Failed to get screenshot. Status code: %d", response.status_code)
logger.info("Retrying to get screenshot.")
except Exception as e:
logger.error("An error occurred while trying to get the screenshot: %s", e)
logger.info("Retrying to get screenshot.")
time.sleep(self.retry_interval)
def get_terminal_output(self, retry_times=20): logger.error("Failed to get screenshot.")
""" Gets the terminal output from the server. None -> no terminal output or unexpected error. return None
def get_accessibility_tree(self) -> Optional[str]:
""" """
response = requests.get(self.http_server + "/terminal") Gets the accessibility tree from the server. None -> no accessibility tree or unexpected error.
if response.status_code == 200: """
return response.json()["output"]
else: for _ in range(self.retry_times):
for _ in range(retry_times): try:
logger.error("Failed to get terminal output. Status code: %d", response.status_code) response: requests.Response = requests.get(self.http_server + "/accessibility")
logger.info("Retrying to get terminal output.") if response.status_code == 200:
logger.info("Got accessibility tree successfully")
return response.json()["AT"]
else:
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
logger.info("Retrying to get accessibility tree.")
except Exception as e:
logger.error("An error occurred while trying to get the accessibility tree: %s", e)
logger.info("Retrying to get accessibility tree.")
time.sleep(self.retry_interval)
logger.error("Failed to get accessibility tree.")
return None
def get_terminal_output(self) -> Optional[str]:
"""
Gets the terminal output from the server. None -> no terminal output or unexpected error.
"""
for _ in range(self.retry_times):
try:
response = requests.get(self.http_server + "/terminal") response = requests.get(self.http_server + "/terminal")
if response.status_code == 200: if response.status_code == 200:
logger.info("Got terminal output successfully")
return response.json()["output"] return response.json()["output"]
logger.error("Failed to get terminal output. Status code: %d", response.status_code) else:
return None logger.error("Failed to get terminal output. Status code: %d", response.status_code)
logger.info("Retrying to get terminal output.")
except Exception as e:
logger.error("An error occurred while trying to get the terminal output: %s", e)
logger.info("Retrying to get terminal output.")
time.sleep(self.retry_interval)
def get_accessibility_tree(self, retry_times=20) -> Optional[str]: logger.error("Failed to get terminal output.")
return None
response: requests.Response = requests.get(self.http_server + "/accessibility") def get_file(self, file_path: str) -> Optional[bytes]:
if response.status_code == 200:
return response.json()["AT"]
else:
for _ in range(retry_times):
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
logger.info("Retrying to get accessibility tree.")
response = requests.get(self.http_server + "/accessibility")
if response.status_code == 200:
return response.json()["AT"]
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
return None
def get_file(self, file_path: str):
""" """
Gets a file from the server. Gets a file from the server.
""" """
response = requests.post(self.http_server + "/file", data={"file_path": file_path})
if response.status_code == 200: for _ in range(self.retry_times):
logger.info("File downloaded successfully") try:
return response.content response = requests.post(self.http_server + "/file", data={"file_path": file_path})
else: if response.status_code == 200:
logger.error("Failed to get file. Status code: %d", response.status_code) logger.info("File downloaded successfully")
return None return response.content
else:
logger.error("Failed to get file. Status code: %d", response.status_code)
logger.info("Retrying to get file.")
except Exception as e:
logger.error("An error occurred while trying to get the file: %s", e)
logger.info("Retrying to get file.")
time.sleep(self.retry_interval)
logger.error("Failed to get file.")
return None
def execute_python_command(self, command: str) -> None: def execute_python_command(self, command: str) -> None:
""" """
@@ -85,19 +115,26 @@ class PythonController:
# command_list = ["python", "-c", self.pkgs_prefix.format(command=command)] # command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
command_list = ["python", "-c", self.pkgs_prefix.format(command=command)] command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
payload = json.dumps({"command": command_list, "shell": False}) payload = json.dumps({"command": command_list, "shell": False})
headers = {
'Content-Type': 'application/json'
}
try: for _ in range(self.retry_times):
response = requests.post(self.http_server + "/execute", headers=headers, data=payload, timeout=90) try:
if response.status_code == 200: response = requests.post(self.http_server + "/execute", headers={'Content-Type': 'application/json'},
logger.info("Command executed successfully: %s", response.text) data=payload, timeout=90)
else: if response.status_code == 200:
logger.error("Failed to execute command. Status code: %d", response.status_code) logger.info("Command executed successfully: %s", response.text)
return response.json() return response.json()
except requests.exceptions.RequestException as e: else:
logger.error("An error occurred while trying to execute the command: %s", e) logger.error("Failed to execute command. Status code: %d", response.status_code)
logger.info("Retrying to execute command.")
except requests.exceptions.ReadTimeout:
break
except Exception as e:
logger.error("An error occurred while trying to execute the command: %s", e)
logger.info("Retrying to execute command.")
time.sleep(self.retry_interval)
logger.error("Failed to execute command.")
return None
def execute_action(self, action: Dict[str, Any]): def execute_action(self, action: Dict[str, Any]):
""" """
@@ -272,29 +309,47 @@ class PythonController:
""" """
Starts recording the screen. Starts recording the screen.
""" """
response = requests.post(self.http_server + "/start_recording")
if response.status_code == 200: for _ in range(self.retry_times):
logger.info("Recording started successfully") try:
else: response = requests.post(self.http_server + "/start_recording")
logger.error("Failed to start recording. Status code: %d", response.status_code) if response.status_code == 200:
logger.info("Recording started successfully")
return
else:
logger.error("Failed to start recording. Status code: %d", response.status_code)
logger.info("Retrying to start recording.")
except Exception as e:
logger.error("An error occurred while trying to start recording: %s", e)
logger.info("Retrying to start recording.")
time.sleep(self.retry_interval)
logger.error("Failed to start recording.")
def end_recording(self, dest: str): def end_recording(self, dest: str):
""" """
Ends recording the screen. Ends recording the screen.
""" """
try:
response = requests.post(self.http_server + "/end_recording") for _ in range(self.retry_times):
if response.status_code == 200: try:
logger.info("Recording stopped successfully") response = requests.post(self.http_server + "/end_recording")
with open(dest, 'wb') as f: if response.status_code == 200:
for chunk in response.iter_content(chunk_size=8192): logger.info("Recording stopped successfully")
if chunk: with open(dest, 'wb') as f:
f.write(chunk) for chunk in response.iter_content(chunk_size=8192):
else: if chunk:
logger.error("Failed to stop recording. Status code: %d", response.status_code) f.write(chunk)
return None return
except Exception as e: else:
logger.error("An error occurred while trying to download the recording: %s", e) logger.error("Failed to stop recording. Status code: %d", response.status_code)
logger.info("Retrying to stop recording.")
except Exception as e:
logger.error("An error occurred while trying to stop recording: %s", e)
logger.info("Retrying to stop recording.")
time.sleep(self.retry_interval)
logger.error("Failed to stop recording.")
# Additional info # Additional info
def get_vm_platform(self): def get_vm_platform(self):
@@ -307,60 +362,109 @@ class PythonController:
""" """
Gets the size of the vm screen. Gets the size of the vm screen.
""" """
response = requests.post(self.http_server + "/screen_size")
if response.status_code == 200: for _ in range(self.retry_times):
return response.json() try:
else: response = requests.post(self.http_server + "/screen_size")
logger.error("Failed to get screen size. Status code: %d", response.status_code) if response.status_code == 200:
return None logger.info("Got screen size successfully")
return response.json()
else:
logger.error("Failed to get screen size. Status code: %d", response.status_code)
logger.info("Retrying to get screen size.")
except Exception as e:
logger.error("An error occurred while trying to get the screen size: %s", e)
logger.info("Retrying to get screen size.")
time.sleep(self.retry_interval)
logger.error("Failed to get screen size.")
return None
def get_vm_window_size(self, app_class_name: str): def get_vm_window_size(self, app_class_name: str):
""" """
Gets the size of the vm app window. Gets the size of the vm app window.
""" """
response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name})
if response.status_code == 200: for _ in range(self.retry_times):
return response.json() try:
else: response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name})
logger.error("Failed to get window size. Status code: %d", response.status_code) if response.status_code == 200:
return None logger.info("Got window size successfully")
return response.json()
else:
logger.error("Failed to get window size. Status code: %d", response.status_code)
logger.info("Retrying to get window size.")
except Exception as e:
logger.error("An error occurred while trying to get the window size: %s", e)
logger.info("Retrying to get window size.")
time.sleep(self.retry_interval)
logger.error("Failed to get window size.")
return None
def get_vm_wallpaper(self): def get_vm_wallpaper(self):
""" """
Gets the wallpaper of the vm. Gets the wallpaper of the vm.
""" """
response = requests.post(self.http_server + "/wallpaper")
if response.status_code == 200:
logger.info("Wallpaper downloaded successfully")
return response.content
else:
logger.error("Failed to get wallpaper. Status code: %d", response.status_code)
return None
def get_vm_desktop_path(self): for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/wallpaper")
if response.status_code == 200:
logger.info("Got wallpaper successfully")
return response.content
else:
logger.error("Failed to get wallpaper. Status code: %d", response.status_code)
logger.info("Retrying to get wallpaper.")
except Exception as e:
logger.error("An error occurred while trying to get the wallpaper: %s", e)
logger.info("Retrying to get wallpaper.")
time.sleep(self.retry_interval)
logger.error("Failed to get wallpaper.")
return None
def get_vm_desktop_path(self) -> Optional[str]:
""" """
Gets the desktop path of the vm. Gets the desktop path of the vm.
""" """
response = requests.post(self.http_server + "/desktop_path")
if response.status_code == 200:
logger.info("Desktop path downloaded successfully")
return response.json()["desktop_path"]
else:
logger.error("Failed to get desktop path. Status code: %d", response.status_code)
return None
def get_vm_directory_tree(self, path): for _ in range(self.retry_times):
try:
response = requests.post(self.http_server + "/desktop_path")
if response.status_code == 200:
logger.info("Got desktop path successfully")
return response.json()["desktop_path"]
else:
logger.error("Failed to get desktop path. Status code: %d", response.status_code)
logger.info("Retrying to get desktop path.")
except Exception as e:
logger.error("An error occurred while trying to get the desktop path: %s", e)
logger.info("Retrying to get desktop path.")
time.sleep(self.retry_interval)
logger.error("Failed to get desktop path.")
return None
def get_vm_directory_tree(self, path) -> Optional[Dict[str, Any]]:
""" """
Gets the directory tree of the vm. Gets the directory tree of the vm.
""" """
payload = json.dumps({"path": path}) payload = json.dumps({"path": path})
headers = {
'Content-Type': 'application/json' for _ in range(self.retry_times):
} try:
response = requests.post(self.http_server + "/list_directory", headers=headers, data=payload) response = requests.post(self.http_server + "/list_directory", headers={'Content-Type': 'application/json'}, data=payload)
if response.status_code == 200: if response.status_code == 200:
logger.info("Directory tree downloaded successfully") logger.info("Got directory tree successfully")
return response.json()["directory_tree"] return response.json()["directory_tree"]
else: else:
logger.error("Failed to get directory tree. Status code: %d", response.status_code) logger.error("Failed to get directory tree. Status code: %d", response.status_code)
return None logger.info("Retrying to get directory tree.")
except Exception as e:
logger.error("An error occurred while trying to get directory tree: %s", e)
logger.info("Retrying to get directory tree.")
time.sleep(self.retry_interval)
logger.error("Failed to get directory tree.")
return None

View File

@@ -14,12 +14,11 @@ import backoff
import dashscope import dashscope
import google.generativeai as genai import google.generativeai as genai
import openai import openai
from groq import Groq
import requests import requests
import tiktoken import tiktoken
from PIL import Image from PIL import Image
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
from groq import Groq
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
@@ -954,8 +953,8 @@ class PromptAgent:
} }
assert len(message["content"]) in [1, 2], "One text, or one text with one image" assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]: for part in message["content"]:
qwen_message['content'].append({"image": part['image_url']['url']}) if part[ qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
'type'] == "image_url" else None 'type'] == "image_url" else None
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
qwen_messages.append(qwen_message) qwen_messages.append(qwen_message)
@@ -966,14 +965,30 @@ class PromptAgent:
if flag > 20: if flag > 20:
break break
logger.info("Generating content with model: %s", self.model) logger.info("Generating content with model: %s", self.model)
response = dashscope.Generation.call(
model=self.model, if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
messages=qwen_messages, response = dashscope.MultiModalConversation.call(
result_format="message", model=self.model,
max_length=max_tokens, messages=qwen_messages,
top_p=top_p, result_format="message",
temperature=temperature max_length=max_tokens,
) top_p=top_p,
temperature=temperature
)
elif self.model in ["qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-0428", "qwen-max-0403",
"qwen-max-0107", "qwen-max-longcontext"]:
response = dashscope.Generation.call(
model=self.model,
messages=qwen_messages,
result_format="message",
max_length=max_tokens,
top_p=top_p,
temperature=temperature
)
else:
raise ValueError("Invalid model: " + self.model)
if response.status_code == HTTPStatus.OK: if response.status_code == HTTPStatus.OK:
break break
@@ -989,11 +1004,16 @@ class PromptAgent:
else: else:
for i in range(len(qwen_messages[-1]["content"])): for i in range(len(qwen_messages[-1]["content"])):
if "text" in qwen_messages[-1]["content"][i]: if "text" in qwen_messages[-1]["content"][i]:
qwen_messages[-1]["content"][i]["text"] = ' '.join(qwen_messages[-1]["content"][i]["text"].split()[:-500]) qwen_messages[-1]["content"][i]["text"] = ' '.join(
qwen_messages[-1]["content"][i]["text"].split()[:-500])
flag = flag + 1 flag = flag + 1
try: try:
return response['output']['choices'][0]['message']['content'] if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
return response['output']['choices'][0]['message']['content'][0]['text']
else:
return response['output']['choices'][0]['message']['content']
except Exception as e: except Exception as e:
print("Failed to call LLM: " + str(e)) print("Failed to call LLM: " + str(e))
return "" return ""