Add Support for QWEN VL models from API (QWEN-VL-max, etc.); Improve on the robustness of getting observation/files, etc.
This commit is contained in:
@@ -2,7 +2,7 @@ import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import time
|
||||
import requests
|
||||
|
||||
from desktop_env.envs.actions import KEYBOARD_KEYS
|
||||
@@ -16,66 +16,96 @@ class PythonController:
|
||||
self.vm_ip = vm_ip
|
||||
self.http_server = f"http://{vm_ip}:5000"
|
||||
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
||||
self.retry_times = 3
|
||||
self.retry_interval = 5
|
||||
|
||||
def get_screenshot(self, retry_times=20):
|
||||
def get_screenshot(self) -> Optional[bytes]:
|
||||
"""
|
||||
Gets a screenshot from the server. With the cursor.
|
||||
Gets a screenshot from the server. With the cursor. None -> no screenshot or unexpected error.
|
||||
"""
|
||||
response = requests.get(self.http_server + "/screenshot")
|
||||
if response.status_code == 200:
|
||||
return response.content
|
||||
else:
|
||||
for _ in range(retry_times):
|
||||
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get screenshot.")
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.get(self.http_server + "/screenshot")
|
||||
if response.status_code == 200:
|
||||
logger.info("Got screenshot successfully")
|
||||
return response.content
|
||||
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
|
||||
return None
|
||||
else:
|
||||
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get screenshot.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the screenshot: %s", e)
|
||||
logger.info("Retrying to get screenshot.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
def get_terminal_output(self, retry_times=20):
|
||||
""" Gets the terminal output from the server. None -> no terminal output or unexpected error.
|
||||
logger.error("Failed to get screenshot.")
|
||||
return None
|
||||
|
||||
def get_accessibility_tree(self) -> Optional[str]:
|
||||
"""
|
||||
response = requests.get(self.http_server + "/terminal")
|
||||
if response.status_code == 200:
|
||||
return response.json()["output"]
|
||||
else:
|
||||
for _ in range(retry_times):
|
||||
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get terminal output.")
|
||||
Gets the accessibility tree from the server. None -> no accessibility tree or unexpected error.
|
||||
"""
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response: requests.Response = requests.get(self.http_server + "/accessibility")
|
||||
if response.status_code == 200:
|
||||
logger.info("Got accessibility tree successfully")
|
||||
return response.json()["AT"]
|
||||
else:
|
||||
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get accessibility tree.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the accessibility tree: %s", e)
|
||||
logger.info("Retrying to get accessibility tree.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to get accessibility tree.")
|
||||
return None
|
||||
|
||||
def get_terminal_output(self) -> Optional[str]:
|
||||
"""
|
||||
Gets the terminal output from the server. None -> no terminal output or unexpected error.
|
||||
"""
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.get(self.http_server + "/terminal")
|
||||
if response.status_code == 200:
|
||||
logger.info("Got terminal output successfully")
|
||||
return response.json()["output"]
|
||||
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
|
||||
return None
|
||||
else:
|
||||
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get terminal output.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the terminal output: %s", e)
|
||||
logger.info("Retrying to get terminal output.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
def get_accessibility_tree(self, retry_times=20) -> Optional[str]:
|
||||
logger.error("Failed to get terminal output.")
|
||||
return None
|
||||
|
||||
response: requests.Response = requests.get(self.http_server + "/accessibility")
|
||||
if response.status_code == 200:
|
||||
return response.json()["AT"]
|
||||
else:
|
||||
for _ in range(retry_times):
|
||||
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get accessibility tree.")
|
||||
response = requests.get(self.http_server + "/accessibility")
|
||||
if response.status_code == 200:
|
||||
return response.json()["AT"]
|
||||
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
|
||||
return None
|
||||
|
||||
def get_file(self, file_path: str):
|
||||
def get_file(self, file_path: str) -> Optional[bytes]:
|
||||
"""
|
||||
Gets a file from the server.
|
||||
"""
|
||||
response = requests.post(self.http_server + "/file", data={"file_path": file_path})
|
||||
if response.status_code == 200:
|
||||
logger.info("File downloaded successfully")
|
||||
return response.content
|
||||
else:
|
||||
logger.error("Failed to get file. Status code: %d", response.status_code)
|
||||
return None
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/file", data={"file_path": file_path})
|
||||
if response.status_code == 200:
|
||||
logger.info("File downloaded successfully")
|
||||
return response.content
|
||||
else:
|
||||
logger.error("Failed to get file. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get file.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the file: %s", e)
|
||||
logger.info("Retrying to get file.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to get file.")
|
||||
return None
|
||||
|
||||
def execute_python_command(self, command: str) -> None:
|
||||
"""
|
||||
@@ -85,19 +115,26 @@ class PythonController:
|
||||
# command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
|
||||
command_list = ["python", "-c", self.pkgs_prefix.format(command=command)]
|
||||
payload = json.dumps({"command": command_list, "shell": False})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(self.http_server + "/execute", headers=headers, data=payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
logger.info("Command executed successfully: %s", response.text)
|
||||
else:
|
||||
logger.error("Failed to execute command. Status code: %d", response.status_code)
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("An error occurred while trying to execute the command: %s", e)
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/execute", headers={'Content-Type': 'application/json'},
|
||||
data=payload, timeout=90)
|
||||
if response.status_code == 200:
|
||||
logger.info("Command executed successfully: %s", response.text)
|
||||
return response.json()
|
||||
else:
|
||||
logger.error("Failed to execute command. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to execute command.")
|
||||
except requests.exceptions.ReadTimeout:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to execute the command: %s", e)
|
||||
logger.info("Retrying to execute command.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to execute command.")
|
||||
return None
|
||||
|
||||
def execute_action(self, action: Dict[str, Any]):
|
||||
"""
|
||||
@@ -272,29 +309,47 @@ class PythonController:
|
||||
"""
|
||||
Starts recording the screen.
|
||||
"""
|
||||
response = requests.post(self.http_server + "/start_recording")
|
||||
if response.status_code == 200:
|
||||
logger.info("Recording started successfully")
|
||||
else:
|
||||
logger.error("Failed to start recording. Status code: %d", response.status_code)
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/start_recording")
|
||||
if response.status_code == 200:
|
||||
logger.info("Recording started successfully")
|
||||
return
|
||||
else:
|
||||
logger.error("Failed to start recording. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to start recording.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to start recording: %s", e)
|
||||
logger.info("Retrying to start recording.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to start recording.")
|
||||
|
||||
def end_recording(self, dest: str):
|
||||
"""
|
||||
Ends recording the screen.
|
||||
"""
|
||||
try:
|
||||
response = requests.post(self.http_server + "/end_recording")
|
||||
if response.status_code == 200:
|
||||
logger.info("Recording stopped successfully")
|
||||
with open(dest, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
else:
|
||||
logger.error("Failed to stop recording. Status code: %d", response.status_code)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to download the recording: %s", e)
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/end_recording")
|
||||
if response.status_code == 200:
|
||||
logger.info("Recording stopped successfully")
|
||||
with open(dest, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
return
|
||||
else:
|
||||
logger.error("Failed to stop recording. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to stop recording.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to stop recording: %s", e)
|
||||
logger.info("Retrying to stop recording.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to stop recording.")
|
||||
|
||||
# Additional info
|
||||
def get_vm_platform(self):
|
||||
@@ -307,60 +362,109 @@ class PythonController:
|
||||
"""
|
||||
Gets the size of the vm screen.
|
||||
"""
|
||||
response = requests.post(self.http_server + "/screen_size")
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.error("Failed to get screen size. Status code: %d", response.status_code)
|
||||
return None
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/screen_size")
|
||||
if response.status_code == 200:
|
||||
logger.info("Got screen size successfully")
|
||||
return response.json()
|
||||
else:
|
||||
logger.error("Failed to get screen size. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get screen size.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the screen size: %s", e)
|
||||
logger.info("Retrying to get screen size.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to get screen size.")
|
||||
return None
|
||||
|
||||
def get_vm_window_size(self, app_class_name: str):
|
||||
"""
|
||||
Gets the size of the vm app window.
|
||||
"""
|
||||
response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name})
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.error("Failed to get window size. Status code: %d", response.status_code)
|
||||
return None
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name})
|
||||
if response.status_code == 200:
|
||||
logger.info("Got window size successfully")
|
||||
return response.json()
|
||||
else:
|
||||
logger.error("Failed to get window size. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get window size.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the window size: %s", e)
|
||||
logger.info("Retrying to get window size.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to get window size.")
|
||||
return None
|
||||
|
||||
def get_vm_wallpaper(self):
|
||||
"""
|
||||
Gets the wallpaper of the vm.
|
||||
"""
|
||||
response = requests.post(self.http_server + "/wallpaper")
|
||||
if response.status_code == 200:
|
||||
logger.info("Wallpaper downloaded successfully")
|
||||
return response.content
|
||||
else:
|
||||
logger.error("Failed to get wallpaper. Status code: %d", response.status_code)
|
||||
return None
|
||||
|
||||
def get_vm_desktop_path(self):
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/wallpaper")
|
||||
if response.status_code == 200:
|
||||
logger.info("Got wallpaper successfully")
|
||||
return response.content
|
||||
else:
|
||||
logger.error("Failed to get wallpaper. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get wallpaper.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the wallpaper: %s", e)
|
||||
logger.info("Retrying to get wallpaper.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to get wallpaper.")
|
||||
return None
|
||||
|
||||
def get_vm_desktop_path(self) -> Optional[str]:
|
||||
"""
|
||||
Gets the desktop path of the vm.
|
||||
"""
|
||||
response = requests.post(self.http_server + "/desktop_path")
|
||||
if response.status_code == 200:
|
||||
logger.info("Desktop path downloaded successfully")
|
||||
return response.json()["desktop_path"]
|
||||
else:
|
||||
logger.error("Failed to get desktop path. Status code: %d", response.status_code)
|
||||
return None
|
||||
|
||||
def get_vm_directory_tree(self, path):
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/desktop_path")
|
||||
if response.status_code == 200:
|
||||
logger.info("Got desktop path successfully")
|
||||
return response.json()["desktop_path"]
|
||||
else:
|
||||
logger.error("Failed to get desktop path. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get desktop path.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get the desktop path: %s", e)
|
||||
logger.info("Retrying to get desktop path.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to get desktop path.")
|
||||
return None
|
||||
|
||||
def get_vm_directory_tree(self, path) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Gets the directory tree of the vm.
|
||||
"""
|
||||
payload = json.dumps({"path": path})
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
response = requests.post(self.http_server + "/list_directory", headers=headers, data=payload)
|
||||
if response.status_code == 200:
|
||||
logger.info("Directory tree downloaded successfully")
|
||||
return response.json()["directory_tree"]
|
||||
else:
|
||||
logger.error("Failed to get directory tree. Status code: %d", response.status_code)
|
||||
return None
|
||||
|
||||
for _ in range(self.retry_times):
|
||||
try:
|
||||
response = requests.post(self.http_server + "/list_directory", headers={'Content-Type': 'application/json'}, data=payload)
|
||||
if response.status_code == 200:
|
||||
logger.info("Got directory tree successfully")
|
||||
return response.json()["directory_tree"]
|
||||
else:
|
||||
logger.error("Failed to get directory tree. Status code: %d", response.status_code)
|
||||
logger.info("Retrying to get directory tree.")
|
||||
except Exception as e:
|
||||
logger.error("An error occurred while trying to get directory tree: %s", e)
|
||||
logger.info("Retrying to get directory tree.")
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
logger.error("Failed to get directory tree.")
|
||||
return None
|
||||
|
||||
@@ -14,12 +14,11 @@ import backoff
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import openai
|
||||
from groq import Groq
|
||||
|
||||
import requests
|
||||
import tiktoken
|
||||
from PIL import Image
|
||||
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
|
||||
from groq import Groq
|
||||
|
||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
|
||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||
@@ -954,8 +953,8 @@ class PromptAgent:
|
||||
}
|
||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||
for part in message["content"]:
|
||||
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
|
||||
'type'] == "image_url" else None
|
||||
qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
|
||||
'type'] == "image_url" else None
|
||||
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
||||
|
||||
qwen_messages.append(qwen_message)
|
||||
@@ -966,14 +965,30 @@ class PromptAgent:
|
||||
if flag > 20:
|
||||
break
|
||||
logger.info("Generating content with model: %s", self.model)
|
||||
response = dashscope.Generation.call(
|
||||
model=self.model,
|
||||
messages=qwen_messages,
|
||||
result_format="message",
|
||||
max_length=max_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
|
||||
response = dashscope.MultiModalConversation.call(
|
||||
model=self.model,
|
||||
messages=qwen_messages,
|
||||
result_format="message",
|
||||
max_length=max_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
elif self.model in ["qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-0428", "qwen-max-0403",
|
||||
"qwen-max-0107", "qwen-max-longcontext"]:
|
||||
response = dashscope.Generation.call(
|
||||
model=self.model,
|
||||
messages=qwen_messages,
|
||||
result_format="message",
|
||||
max_length=max_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid model: " + self.model)
|
||||
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
break
|
||||
@@ -989,11 +1004,16 @@ class PromptAgent:
|
||||
else:
|
||||
for i in range(len(qwen_messages[-1]["content"])):
|
||||
if "text" in qwen_messages[-1]["content"][i]:
|
||||
qwen_messages[-1]["content"][i]["text"] = ' '.join(qwen_messages[-1]["content"][i]["text"].split()[:-500])
|
||||
qwen_messages[-1]["content"][i]["text"] = ' '.join(
|
||||
qwen_messages[-1]["content"][i]["text"].split()[:-500])
|
||||
flag = flag + 1
|
||||
|
||||
try:
|
||||
return response['output']['choices'][0]['message']['content']
|
||||
if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
|
||||
return response['output']['choices'][0]['message']['content'][0]['text']
|
||||
else:
|
||||
return response['output']['choices'][0]['message']['content']
|
||||
|
||||
except Exception as e:
|
||||
print("Failed to call LLM: " + str(e))
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user