From f9594e476e9758052c25cfc22beef0b3d77bfb40 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Mon, 20 May 2024 00:47:43 +0800 Subject: [PATCH] Add Support for QWEN models from API (QWEN-max, etc.); Improve on the robustness of getting observation --- desktop_env/controllers/python.py | 27 +++++++++++++--- desktop_env/envs/desktop_env.py | 17 +--------- mm_agents/agent.py | 54 ++++++++++++++++++++----------- 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/desktop_env/controllers/python.py b/desktop_env/controllers/python.py index b9b3387..325e0ed 100644 --- a/desktop_env/controllers/python.py +++ b/desktop_env/controllers/python.py @@ -11,12 +11,13 @@ logger = logging.getLogger("desktopenv.pycontroller") class PythonController: - def __init__(self, vm_ip: str, pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"): + def __init__(self, vm_ip: str, + pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"): self.vm_ip = vm_ip self.http_server = f"http://{vm_ip}:5000" self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages - def get_screenshot(self): + def get_screenshot(self, retry_times=20): """ Gets a screenshot from the server. With the cursor. """ @@ -24,25 +25,43 @@ class PythonController: if response.status_code == 200: return response.content else: + for _ in range(retry_times): + logger.error("Failed to get screenshot. Status code: %d", response.status_code) + logger.info("Retrying to get screenshot.") + response = requests.get(self.http_server + "/screenshot") + if response.status_code == 200: + return response.content logger.error("Failed to get screenshot. Status code: %d", response.status_code) return None - def get_terminal_output(self): + def get_terminal_output(self, retry_times=20): """ Gets the terminal output from the server. None -> no terminal output or unexpected error. """ response = requests.get(self.http_server + "/terminal") if response.status_code == 200: return response.json()["output"] else: + for _ in range(retry_times): + logger.error("Failed to get terminal output. Status code: %d", response.status_code) + logger.info("Retrying to get terminal output.") + response = requests.get(self.http_server + "/terminal") + if response.status_code == 200: + return response.json()["output"] logger.error("Failed to get terminal output. Status code: %d", response.status_code) return None - def get_accessibility_tree(self) -> Optional[str]: + def get_accessibility_tree(self, retry_times=20) -> Optional[str]: response: requests.Response = requests.get(self.http_server + "/accessibility") if response.status_code == 200: return response.json()["AT"] else: + for _ in range(retry_times): + logger.error("Failed to get accessibility tree. Status code: %d", response.status_code) + logger.info("Retrying to get accessibility tree.") + response = requests.get(self.http_server + "/accessibility") + if response.status_code == 200: + return response.json()["AT"] logger.error("Failed to get accessibility tree. Status code: %d", response.status_code) return None diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 2f84efe..6f061f9 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -138,24 +138,9 @@ class DesktopEnv(gym.Env): def _save_state(self): _execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name]) - def _get_screenshot(self): - screenshot = None - # Get the screenshot and save to the image_path - max_retries = 20 - for _ in range(max_retries): - screenshot = self.controller.get_screenshot() - if screenshot is not None: - break - time.sleep(1) - - if screenshot is None: - logger.error("Failed to get screenshot!") - - return screenshot - def _get_obs(self): return { - "screenshot": self._get_screenshot(), + "screenshot": self.controller.get_screenshot(), "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None, "terminal": self.controller.get_terminal_output() if self.require_terminal else None, "instruction": self.instruction diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 893d45f..ef98a41 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -943,8 +943,7 @@ class PromptAgent: messages = payload["messages"] max_tokens = payload["max_tokens"] top_p = payload["top_p"] - if payload["temperature"]: - logger.warning("Qwen model does not support temperature parameter, it will be ignored.") + temperature = payload["temperature"] qwen_messages = [] @@ -961,23 +960,42 @@ class PromptAgent: qwen_messages.append(qwen_message) - response = dashscope.MultiModalConversation.call( - model='qwen-vl-plus', - messages=messages, - max_length=max_tokens, - top_p=top_p, - ) - # The response status_code is HTTPStatus.OK indicate success, - # otherwise indicate request is failed, you can get error code - # and message from code and message. - if response.status_code == HTTPStatus.OK: + flag = 0 + while True: try: - return response.json()['output']['choices'][0]['message']['content'] - except Exception: - return "" - else: - print(response.code) # The error code. - print(response.message) # The error message. + if flag > 20: + break + logger.info("Generating content with model: %s", self.model) + response = dashscope.Generation.call( + model=self.model, + messages=qwen_messages, + result_format="message", + max_length=max_tokens, + top_p=top_p, + temperature=temperature + ) + + if response.status_code == HTTPStatus.OK: + break + else: + logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % ( + response.request_id, response.status_code, + response.code, response.message + )) + raise Exception("Failed to call LLM: " + response.message) + except: + if flag == 0: + qwen_messages = [qwen_messages[0]] + qwen_messages[-1:] + else: + for i in range(len(qwen_messages[-1]["content"])): + if "text" in qwen_messages[-1]["content"][i]: + qwen_messages[-1]["content"][i]["text"] = ' '.join(qwen_messages[-1]["content"][i]["text"].split()[:-500]) + flag = flag + 1 + + try: + return response['output']['choices'][0]['message']['content'] + except Exception as e: + print("Failed to call LLM: " + str(e)) return "" else: