From 306dcbda71aa73d99c024757b34d5b9c17279647 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Tue, 21 May 2024 21:08:22 +0800 Subject: [PATCH] Add Support for QWEN VL models from API (QWEN-VL-max, etc.); Improve on the robustness of getting observation/files, etc. --- desktop_env/controllers/python.py | 330 ++++++++++++++++++++---------- mm_agents/agent.py | 48 +++-- 2 files changed, 251 insertions(+), 127 deletions(-) diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index 325e0ed..b3dab4a 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -2,7 +2,7 @@ import json import logging import random from typing import Any, Dict, Optional - +import time import requests from desktop_env.envs.actions import KEYBOARD_KEYS @@ -16,66 +16,96 @@ class PythonController: self.vm_ip = vm_ip self.http_server = f"http://{vm_ip}:5000" self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages + self.retry_times = 3 + self.retry_interval = 5 - def get_screenshot(self, retry_times=20): + def get_screenshot(self) -> Optional[bytes]: """ - Gets a screenshot from the server. With the cursor. + Gets a screenshot from the server. With the cursor. None -> no screenshot or unexpected error. """ - response = requests.get(self.http_server + "/screenshot") - if response.status_code == 200: - return response.content - else: - for _ in range(retry_times): - logger.error("Failed to get screenshot. Status code: %d", response.status_code) - logger.info("Retrying to get screenshot.") + + for _ in range(self.retry_times): + try: response = requests.get(self.http_server + "/screenshot") if response.status_code == 200: + logger.info("Got screenshot successfully") return response.content - logger.error("Failed to get screenshot. Status code: %d", response.status_code) - return None + else: + logger.error("Failed to get screenshot. Status code: %d", response.status_code) + logger.info("Retrying to get screenshot.") + except Exception as e: + logger.error("An error occurred while trying to get the screenshot: %s", e) + logger.info("Retrying to get screenshot.") + time.sleep(self.retry_interval) - def get_terminal_output(self, retry_times=20): - """ Gets the terminal output from the server. None -> no terminal output or unexpected error. + logger.error("Failed to get screenshot.") + return None + + def get_accessibility_tree(self) -> Optional[str]: """ - response = requests.get(self.http_server + "/terminal") - if response.status_code == 200: - return response.json()["output"] - else: - for _ in range(retry_times): - logger.error("Failed to get terminal output. Status code: %d", response.status_code) - logger.info("Retrying to get terminal output.") + Gets the accessibility tree from the server. None -> no accessibility tree or unexpected error. + """ + + for _ in range(self.retry_times): + try: + response: requests.Response = requests.get(self.http_server + "/accessibility") + if response.status_code == 200: + logger.info("Got accessibility tree successfully") + return response.json()["AT"] + else: + logger.error("Failed to get accessibility tree. Status code: %d", response.status_code) + logger.info("Retrying to get accessibility tree.") + except Exception as e: + logger.error("An error occurred while trying to get the accessibility tree: %s", e) + logger.info("Retrying to get accessibility tree.") + time.sleep(self.retry_interval) + + logger.error("Failed to get accessibility tree.") + return None + + def get_terminal_output(self) -> Optional[str]: + """ + Gets the terminal output from the server. None -> no terminal output or unexpected error. + """ + + for _ in range(self.retry_times): + try: response = requests.get(self.http_server + "/terminal") if response.status_code == 200: + logger.info("Got terminal output successfully") return response.json()["output"] - logger.error("Failed to get terminal output. Status code: %d", response.status_code) - return None + else: + logger.error("Failed to get terminal output. Status code: %d", response.status_code) + logger.info("Retrying to get terminal output.") + except Exception as e: + logger.error("An error occurred while trying to get the terminal output: %s", e) + logger.info("Retrying to get terminal output.") + time.sleep(self.retry_interval) - def get_accessibility_tree(self, retry_times=20) -> Optional[str]: + logger.error("Failed to get terminal output.") + return None - response: requests.Response = requests.get(self.http_server + "/accessibility") - if response.status_code == 200: - return response.json()["AT"] - else: - for _ in range(retry_times): - logger.error("Failed to get accessibility tree. Status code: %d", response.status_code) - logger.info("Retrying to get accessibility tree.") - response = requests.get(self.http_server + "/accessibility") - if response.status_code == 200: - return response.json()["AT"] - logger.error("Failed to get accessibility tree. Status code: %d", response.status_code) - return None - - def get_file(self, file_path: str): + def get_file(self, file_path: str) -> Optional[bytes]: """ Gets a file from the server. """ - response = requests.post(self.http_server + "/file", data={"file_path": file_path}) - if response.status_code == 200: - logger.info("File downloaded successfully") - return response.content - else: - logger.error("Failed to get file. Status code: %d", response.status_code) - return None + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/file", data={"file_path": file_path}) + if response.status_code == 200: + logger.info("File downloaded successfully") + return response.content + else: + logger.error("Failed to get file. Status code: %d", response.status_code) + logger.info("Retrying to get file.") + except Exception as e: + logger.error("An error occurred while trying to get the file: %s", e) + logger.info("Retrying to get file.") + time.sleep(self.retry_interval) + + logger.error("Failed to get file.") + return None def execute_python_command(self, command: str) -> None: """ @@ -85,19 +115,26 @@ class PythonController: # command_list = ["python", "-c", self.pkgs_prefix.format(command=command)] command_list = ["python", "-c", self.pkgs_prefix.format(command=command)] payload = json.dumps({"command": command_list, "shell": False}) - headers = { - 'Content-Type': 'application/json' - } - try: - response = requests.post(self.http_server + "/execute", headers=headers, data=payload, timeout=90) - if response.status_code == 200: - logger.info("Command executed successfully: %s", response.text) - else: - logger.error("Failed to execute command. Status code: %d", response.status_code) - return response.json() - except requests.exceptions.RequestException as e: - logger.error("An error occurred while trying to execute the command: %s", e) + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/execute", headers={'Content-Type': 'application/json'}, + data=payload, timeout=90) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + return response.json() + else: + logger.error("Failed to execute command. Status code: %d", response.status_code) + logger.info("Retrying to execute command.") + except requests.exceptions.ReadTimeout: + break + except Exception as e: + logger.error("An error occurred while trying to execute the command: %s", e) + logger.info("Retrying to execute command.") + time.sleep(self.retry_interval) + + logger.error("Failed to execute command.") + return None def execute_action(self, action: Dict[str, Any]): """ @@ -272,29 +309,47 @@ class PythonController: """ Starts recording the screen. """ - response = requests.post(self.http_server + "/start_recording") - if response.status_code == 200: - logger.info("Recording started successfully") - else: - logger.error("Failed to start recording. Status code: %d", response.status_code) + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/start_recording") + if response.status_code == 200: + logger.info("Recording started successfully") + return + else: + logger.error("Failed to start recording. Status code: %d", response.status_code) + logger.info("Retrying to start recording.") + except Exception as e: + logger.error("An error occurred while trying to start recording: %s", e) + logger.info("Retrying to start recording.") + time.sleep(self.retry_interval) + + logger.error("Failed to start recording.") def end_recording(self, dest: str): """ Ends recording the screen. """ - try: - response = requests.post(self.http_server + "/end_recording") - if response.status_code == 200: - logger.info("Recording stopped successfully") - with open(dest, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - if chunk: - f.write(chunk) - else: - logger.error("Failed to stop recording. Status code: %d", response.status_code) - return None - except Exception as e: - logger.error("An error occurred while trying to download the recording: %s", e) + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/end_recording") + if response.status_code == 200: + logger.info("Recording stopped successfully") + with open(dest, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return + else: + logger.error("Failed to stop recording. Status code: %d", response.status_code) + logger.info("Retrying to stop recording.") + except Exception as e: + logger.error("An error occurred while trying to stop recording: %s", e) + logger.info("Retrying to stop recording.") + time.sleep(self.retry_interval) + + logger.error("Failed to stop recording.") # Additional info def get_vm_platform(self): @@ -307,60 +362,109 @@ class PythonController: """ Gets the size of the vm screen. """ - response = requests.post(self.http_server + "/screen_size") - if response.status_code == 200: - return response.json() - else: - logger.error("Failed to get screen size. Status code: %d", response.status_code) - return None + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/screen_size") + if response.status_code == 200: + logger.info("Got screen size successfully") + return response.json() + else: + logger.error("Failed to get screen size. Status code: %d", response.status_code) + logger.info("Retrying to get screen size.") + except Exception as e: + logger.error("An error occurred while trying to get the screen size: %s", e) + logger.info("Retrying to get screen size.") + time.sleep(self.retry_interval) + + logger.error("Failed to get screen size.") + return None def get_vm_window_size(self, app_class_name: str): """ Gets the size of the vm app window. """ - response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name}) - if response.status_code == 200: - return response.json() - else: - logger.error("Failed to get window size. Status code: %d", response.status_code) - return None + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/window_size", data={"app_class_name": app_class_name}) + if response.status_code == 200: + logger.info("Got window size successfully") + return response.json() + else: + logger.error("Failed to get window size. Status code: %d", response.status_code) + logger.info("Retrying to get window size.") + except Exception as e: + logger.error("An error occurred while trying to get the window size: %s", e) + logger.info("Retrying to get window size.") + time.sleep(self.retry_interval) + + logger.error("Failed to get window size.") + return None def get_vm_wallpaper(self): """ Gets the wallpaper of the vm. """ - response = requests.post(self.http_server + "/wallpaper") - if response.status_code == 200: - logger.info("Wallpaper downloaded successfully") - return response.content - else: - logger.error("Failed to get wallpaper. Status code: %d", response.status_code) - return None - def get_vm_desktop_path(self): + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/wallpaper") + if response.status_code == 200: + logger.info("Got wallpaper successfully") + return response.content + else: + logger.error("Failed to get wallpaper. Status code: %d", response.status_code) + logger.info("Retrying to get wallpaper.") + except Exception as e: + logger.error("An error occurred while trying to get the wallpaper: %s", e) + logger.info("Retrying to get wallpaper.") + time.sleep(self.retry_interval) + + logger.error("Failed to get wallpaper.") + return None + + def get_vm_desktop_path(self) -> Optional[str]: """ Gets the desktop path of the vm. """ - response = requests.post(self.http_server + "/desktop_path") - if response.status_code == 200: - logger.info("Desktop path downloaded successfully") - return response.json()["desktop_path"] - else: - logger.error("Failed to get desktop path. Status code: %d", response.status_code) - return None - def get_vm_directory_tree(self, path): + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/desktop_path") + if response.status_code == 200: + logger.info("Got desktop path successfully") + return response.json()["desktop_path"] + else: + logger.error("Failed to get desktop path. Status code: %d", response.status_code) + logger.info("Retrying to get desktop path.") + except Exception as e: + logger.error("An error occurred while trying to get the desktop path: %s", e) + logger.info("Retrying to get desktop path.") + time.sleep(self.retry_interval) + + logger.error("Failed to get desktop path.") + return None + + def get_vm_directory_tree(self, path) -> Optional[Dict[str, Any]]: """ Gets the directory tree of the vm. """ payload = json.dumps({"path": path}) - headers = { - 'Content-Type': 'application/json' - } - response = requests.post(self.http_server + "/list_directory", headers=headers, data=payload) - if response.status_code == 200: - logger.info("Directory tree downloaded successfully") - return response.json()["directory_tree"] - else: - logger.error("Failed to get directory tree. Status code: %d", response.status_code) - return None + + for _ in range(self.retry_times): + try: + response = requests.post(self.http_server + "/list_directory", headers={'Content-Type': 'application/json'}, data=payload) + if response.status_code == 200: + logger.info("Got directory tree successfully") + return response.json()["directory_tree"] + else: + logger.error("Failed to get directory tree. Status code: %d", response.status_code) + logger.info("Retrying to get directory tree.") + except Exception as e: + logger.error("An error occurred while trying to get directory tree: %s", e) + logger.info("Retrying to get directory tree.") + time.sleep(self.retry_interval) + + logger.error("Failed to get directory tree.") + return None diff --git a/mm_agents/agent.py b/mm_agents/agent.py index ef98a41..9105e61 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -14,12 +14,11 @@ import backoff import dashscope import google.generativeai as genai import openai -from groq import Groq - import requests import tiktoken from PIL import Image from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest +from groq import Groq from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ @@ -954,8 +953,8 @@ class PromptAgent: } assert len(message["content"]) in [1, 2], "One text, or one text with one image" for part in message["content"]: - qwen_message['content'].append({"image": part['image_url']['url']}) if part[ - 'type'] == "image_url" else None + qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[ + 'type'] == "image_url" else None qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None qwen_messages.append(qwen_message) @@ -966,14 +965,30 @@ class PromptAgent: if flag > 20: break logger.info("Generating content with model: %s", self.model) - response = dashscope.Generation.call( - model=self.model, - messages=qwen_messages, - result_format="message", - max_length=max_tokens, - top_p=top_p, - temperature=temperature - ) + + if self.model in ["qwen-vl-plus", "qwen-vl-max"]: + response = dashscope.MultiModalConversation.call( + model=self.model, + messages=qwen_messages, + result_format="message", + max_length=max_tokens, + top_p=top_p, + temperature=temperature + ) + + elif self.model in ["qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-0428", "qwen-max-0403", + "qwen-max-0107", "qwen-max-longcontext"]: + response = dashscope.Generation.call( + model=self.model, + messages=qwen_messages, + result_format="message", + max_length=max_tokens, + top_p=top_p, + temperature=temperature + ) + + else: + raise ValueError("Invalid model: " + self.model) if response.status_code == HTTPStatus.OK: break @@ -989,11 +1004,16 @@ class PromptAgent: else: for i in range(len(qwen_messages[-1]["content"])): if "text" in qwen_messages[-1]["content"][i]: - qwen_messages[-1]["content"][i]["text"] = ' '.join(qwen_messages[-1]["content"][i]["text"].split()[:-500]) + qwen_messages[-1]["content"][i]["text"] = ' '.join( + qwen_messages[-1]["content"][i]["text"].split()[:-500]) flag = flag + 1 try: - return response['output']['choices'][0]['message']['content'] + if self.model in ["qwen-vl-plus", "qwen-vl-max"]: + return response['output']['choices'][0]['message']['content'][0]['text'] + else: + return response['output']['choices'][0]['message']['content'] + except Exception as e: print("Failed to call LLM: " + str(e)) return ""