Add support for Qwen-VL models via the DashScope API (qwen-vl-max, etc.); improve the robustness of getting observations/files, etc.

This commit is contained in:
Timothyxxx
2024-05-21 21:08:22 +08:00
parent f9594e476e
commit 306dcbda71
2 changed files with 251 additions and 127 deletions

@@ -14,12 +14,11 @@ import backoff
import dashscope
import google.generativeai as genai
import openai
from groq import Groq
import requests
import tiktoken
from PIL import Image
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
from groq import Groq
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
@@ -954,8 +953,8 @@ class PromptAgent:
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]:
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
'type'] == "image_url" else None
qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
'type'] == "image_url" else None
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
qwen_messages.append(qwen_message)
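
The updated hunk rewrites each image part as a local file:// URI through save_to_tmp_img_file, a helper introduced by this commit but not shown in this hunk. A minimal sketch of what such a helper could look like, assuming the OpenAI-style image_url carries a base64 data URL; the body below is illustrative, not the committed implementation:

import base64
import os
import tempfile

def save_to_tmp_img_file(data_url: str) -> str:
    # Keep only the base64 payload, dropping a "data:image/png;base64," prefix if present.
    payload = data_url.split("base64,", 1)[-1]
    img_bytes = base64.b64decode(payload)
    # Write the decoded bytes to a temporary .png and return its path,
    # so the caller can pass "file://" + path to DashScope.
    fd, path = tempfile.mkstemp(suffix=".png")
    with os.fdopen(fd, "wb") as f:
        f.write(img_bytes)
    return path
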
@@ -966,14 +965,30 @@ class PromptAgent:
if flag > 20:
break
logger.info("Generating content with model: %s", self.model)
response = dashscope.Generation.call(
model=self.model,
messages=qwen_messages,
result_format="message",
max_length=max_tokens,
top_p=top_p,
temperature=temperature
)
if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
response = dashscope.MultiModalConversation.call(
model=self.model,
messages=qwen_messages,
result_format="message",
max_length=max_tokens,
top_p=top_p,
temperature=temperature
)
elif self.model in ["qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-0428", "qwen-max-0403",
"qwen-max-0107", "qwen-max-longcontext"]:
response = dashscope.Generation.call(
model=self.model,
messages=qwen_messages,
result_format="message",
max_length=max_tokens,
top_p=top_p,
temperature=temperature
)
else:
raise ValueError("Invalid model: " + self.model)
if response.status_code == HTTPStatus.OK:
break
@@ -989,11 +1004,16 @@ class PromptAgent:
else:
for i in range(len(qwen_messages[-1]["content"])):
if "text" in qwen_messages[-1]["content"][i]:
qwen_messages[-1]["content"][i]["text"] = ' '.join(qwen_messages[-1]["content"][i]["text"].split()[:-500])
qwen_messages[-1]["content"][i]["text"] = ' '.join(
qwen_messages[-1]["content"][i]["text"].split()[:-500])
flag = flag + 1
try:
return response['output']['choices'][0]['message']['content']
if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
return response['output']['choices'][0]['message']['content'][0]['text']
else:
return response['output']['choices'][0]['message']['content']
except Exception as e:
print("Failed to call LLM: " + str(e))
return ""