Add Support for QWEN models from API (QWEN-max, etc.); Improve on the robustness of getting observation
This commit is contained in:
@@ -11,12 +11,13 @@ logger = logging.getLogger("desktopenv.pycontroller")
|
|||||||
|
|
||||||
|
|
||||||
class PythonController:
|
class PythonController:
|
||||||
def __init__(self, vm_ip: str, pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
|
def __init__(self, vm_ip: str,
|
||||||
|
pkgs_prefix: str = "import pyautogui; import time; pyautogui.FAILSAFE = False; {command}"):
|
||||||
self.vm_ip = vm_ip
|
self.vm_ip = vm_ip
|
||||||
self.http_server = f"http://{vm_ip}:5000"
|
self.http_server = f"http://{vm_ip}:5000"
|
||||||
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
self.pkgs_prefix = pkgs_prefix # fixme: this is a hacky way to execute python commands. fix it and combine it with installation of packages
|
||||||
|
|
||||||
def get_screenshot(self):
|
def get_screenshot(self, retry_times=20):
|
||||||
"""
|
"""
|
||||||
Gets a screenshot from the server. With the cursor.
|
Gets a screenshot from the server. With the cursor.
|
||||||
"""
|
"""
|
||||||
@@ -24,25 +25,43 @@ class PythonController:
|
|||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.content
|
return response.content
|
||||||
else:
|
else:
|
||||||
|
for _ in range(retry_times):
|
||||||
|
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
|
||||||
|
logger.info("Retrying to get screenshot.")
|
||||||
|
response = requests.get(self.http_server + "/screenshot")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.content
|
||||||
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
|
logger.error("Failed to get screenshot. Status code: %d", response.status_code)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_terminal_output(self):
|
def get_terminal_output(self, retry_times=20):
|
||||||
""" Gets the terminal output from the server. None -> no terminal output or unexpected error.
|
""" Gets the terminal output from the server. None -> no terminal output or unexpected error.
|
||||||
"""
|
"""
|
||||||
response = requests.get(self.http_server + "/terminal")
|
response = requests.get(self.http_server + "/terminal")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.json()["output"]
|
return response.json()["output"]
|
||||||
else:
|
else:
|
||||||
|
for _ in range(retry_times):
|
||||||
|
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
|
||||||
|
logger.info("Retrying to get terminal output.")
|
||||||
|
response = requests.get(self.http_server + "/terminal")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()["output"]
|
||||||
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
|
logger.error("Failed to get terminal output. Status code: %d", response.status_code)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_accessibility_tree(self) -> Optional[str]:
|
def get_accessibility_tree(self, retry_times=20) -> Optional[str]:
|
||||||
|
|
||||||
response: requests.Response = requests.get(self.http_server + "/accessibility")
|
response: requests.Response = requests.get(self.http_server + "/accessibility")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.json()["AT"]
|
return response.json()["AT"]
|
||||||
else:
|
else:
|
||||||
|
for _ in range(retry_times):
|
||||||
|
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
|
||||||
|
logger.info("Retrying to get accessibility tree.")
|
||||||
|
response = requests.get(self.http_server + "/accessibility")
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()["AT"]
|
||||||
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
|
logger.error("Failed to get accessibility tree. Status code: %d", response.status_code)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -138,24 +138,9 @@ class DesktopEnv(gym.Env):
|
|||||||
def _save_state(self):
|
def _save_state(self):
|
||||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
|
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
|
||||||
|
|
||||||
def _get_screenshot(self):
|
|
||||||
screenshot = None
|
|
||||||
# Get the screenshot and save to the image_path
|
|
||||||
max_retries = 20
|
|
||||||
for _ in range(max_retries):
|
|
||||||
screenshot = self.controller.get_screenshot()
|
|
||||||
if screenshot is not None:
|
|
||||||
break
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
if screenshot is None:
|
|
||||||
logger.error("Failed to get screenshot!")
|
|
||||||
|
|
||||||
return screenshot
|
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return {
|
return {
|
||||||
"screenshot": self._get_screenshot(),
|
"screenshot": self.controller.get_screenshot(),
|
||||||
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||||
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
|
"terminal": self.controller.get_terminal_output() if self.require_terminal else None,
|
||||||
"instruction": self.instruction
|
"instruction": self.instruction
|
||||||
|
|||||||
@@ -943,8 +943,7 @@ class PromptAgent:
|
|||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
max_tokens = payload["max_tokens"]
|
max_tokens = payload["max_tokens"]
|
||||||
top_p = payload["top_p"]
|
top_p = payload["top_p"]
|
||||||
if payload["temperature"]:
|
temperature = payload["temperature"]
|
||||||
logger.warning("Qwen model does not support temperature parameter, it will be ignored.")
|
|
||||||
|
|
||||||
qwen_messages = []
|
qwen_messages = []
|
||||||
|
|
||||||
@@ -961,23 +960,42 @@ class PromptAgent:
|
|||||||
|
|
||||||
qwen_messages.append(qwen_message)
|
qwen_messages.append(qwen_message)
|
||||||
|
|
||||||
response = dashscope.MultiModalConversation.call(
|
flag = 0
|
||||||
model='qwen-vl-plus',
|
while True:
|
||||||
messages=messages,
|
|
||||||
max_length=max_tokens,
|
|
||||||
top_p=top_p,
|
|
||||||
)
|
|
||||||
# The response status_code is HTTPStatus.OK indicate success,
|
|
||||||
# otherwise indicate request is failed, you can get error code
|
|
||||||
# and message from code and message.
|
|
||||||
if response.status_code == HTTPStatus.OK:
|
|
||||||
try:
|
try:
|
||||||
return response.json()['output']['choices'][0]['message']['content']
|
if flag > 20:
|
||||||
except Exception:
|
break
|
||||||
return ""
|
logger.info("Generating content with model: %s", self.model)
|
||||||
else:
|
response = dashscope.Generation.call(
|
||||||
print(response.code) # The error code.
|
model=self.model,
|
||||||
print(response.message) # The error message.
|
messages=qwen_messages,
|
||||||
|
result_format="message",
|
||||||
|
max_length=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == HTTPStatus.OK:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
|
||||||
|
response.request_id, response.status_code,
|
||||||
|
response.code, response.message
|
||||||
|
))
|
||||||
|
raise Exception("Failed to call LLM: " + response.message)
|
||||||
|
except:
|
||||||
|
if flag == 0:
|
||||||
|
qwen_messages = [qwen_messages[0]] + qwen_messages[-1:]
|
||||||
|
else:
|
||||||
|
for i in range(len(qwen_messages[-1]["content"])):
|
||||||
|
if "text" in qwen_messages[-1]["content"][i]:
|
||||||
|
qwen_messages[-1]["content"][i]["text"] = ' '.join(qwen_messages[-1]["content"][i]["text"].split()[:-500])
|
||||||
|
flag = flag + 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
return response['output']['choices'][0]['message']['content']
|
||||||
|
except Exception as e:
|
||||||
|
print("Failed to call LLM: " + str(e))
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user