Add support for Qwen-VL models from the API (qwen-vl-max, etc.); improve the robustness of getting observations/files, etc.
@@ -14,12 +14,11 @@ import backoff
+import dashscope
 import google.generativeai as genai
 import openai
 from groq import Groq
 
 import requests
 import tiktoken
 from PIL import Image
 from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
-from groq import Groq
-
 from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
 from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
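
The dashscope import is the new dependency backing the Qwen endpoints. As a minimal setup sketch (not part of this diff): the SDK resolves credentials from the DASHSCOPE_API_KEY environment variable, or you can set the module attribute explicitly.

    import os
    import dashscope

    # dashscope also picks up DASHSCOPE_API_KEY from the environment on its own;
    # setting the attribute is the explicit equivalent.
    dashscope.api_key = os.environ["DASHSCOPE_API_KEY"]
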
@@ -954,8 +953,8 @@ class PromptAgent:
                 }
                 assert len(message["content"]) in [1, 2], "One text, or one text with one image"
                 for part in message["content"]:
-                    qwen_message['content'].append({"image": part['image_url']['url']}) if part[
-                        'type'] == "image_url" else None
+                    qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
+                        'type'] == "image_url" else None
                     qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
 
                 qwen_messages.append(qwen_message)
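
This change feeds the multimodal endpoint local file:// paths instead of raw data URLs. The body of save_to_tmp_img_file is outside this hunk; a minimal sketch of what it plausibly does, assuming the incoming image_url carries an OpenAI-style base64 data URL:

    import base64
    import tempfile

    def save_to_tmp_img_file(data_url: str) -> str:
        # Data URLs look like "data:image/png;base64,<payload>"
        _, b64_payload = data_url.split(",", 1)
        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp.write(base64.b64decode(b64_payload))
        tmp.close()
        # The caller prefixes the returned path with "file://" for DashScope
        return tmp.name
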
@@ -966,14 +965,30 @@ class PromptAgent:
                 if flag > 20:
                     break
                 logger.info("Generating content with model: %s", self.model)
-                response = dashscope.Generation.call(
-                    model=self.model,
-                    messages=qwen_messages,
-                    result_format="message",
-                    max_length=max_tokens,
-                    top_p=top_p,
-                    temperature=temperature
-                )
+
+                if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
+                    response = dashscope.MultiModalConversation.call(
+                        model=self.model,
+                        messages=qwen_messages,
+                        result_format="message",
+                        max_length=max_tokens,
+                        top_p=top_p,
+                        temperature=temperature
+                    )
+
+                elif self.model in ["qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-0428", "qwen-max-0403",
+                                    "qwen-max-0107", "qwen-max-longcontext"]:
+                    response = dashscope.Generation.call(
+                        model=self.model,
+                        messages=qwen_messages,
+                        result_format="message",
+                        max_length=max_tokens,
+                        top_p=top_p,
+                        temperature=temperature
+                    )
+
+                else:
+                    raise ValueError("Invalid model: " + self.model)
 
                 if response.status_code == HTTPStatus.OK:
                     break
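
The new branch routes vision-language models through dashscope.MultiModalConversation and text-only models through dashscope.Generation. A standalone sketch of the same dispatch, assuming the wrapper name call_qwen and the default sampling values (both are illustrative, not from the diff):

    from http import HTTPStatus

    import dashscope

    def call_qwen(model, messages, max_tokens=1024, top_p=0.9, temperature=0.5):
        if model in ["qwen-vl-plus", "qwen-vl-max"]:
            # Vision-language models go through the multimodal endpoint
            call = dashscope.MultiModalConversation.call
        else:
            # Text-only Qwen models use the plain generation endpoint
            call = dashscope.Generation.call
        response = call(model=model, messages=messages, result_format="message",
                        max_length=max_tokens, top_p=top_p, temperature=temperature)
        if response.status_code != HTTPStatus.OK:
            # DashScope responses carry code/message fields on failure
            raise RuntimeError("dashscope call failed: " + str(response.message))
        return response
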
@@ -989,11 +1004,16 @@ class PromptAgent:
                 else:
                     for i in range(len(qwen_messages[-1]["content"])):
                         if "text" in qwen_messages[-1]["content"][i]:
-                            qwen_messages[-1]["content"][i]["text"] = ' '.join(qwen_messages[-1]["content"][i]["text"].split()[:-500])
+                            qwen_messages[-1]["content"][i]["text"] = ' '.join(
+                                qwen_messages[-1]["content"][i]["text"].split()[:-500])
                 flag = flag + 1
 
         try:
-            return response['output']['choices'][0]['message']['content']
+            if self.model in ["qwen-vl-plus", "qwen-vl-max"]:
+                return response['output']['choices'][0]['message']['content'][0]['text']
+            else:
+                return response['output']['choices'][0]['message']['content']
+
         except Exception as e:
             print("Failed to call LLM: " + str(e))
             return ""
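
The two endpoints also differ in response shape, which is why the VL path indexes ['content'][0]['text']: Generation with result_format="message" returns content as a plain string, while MultiModalConversation returns a list of typed parts. A small shape-tolerant extractor sketch (the function name is illustrative):

    def extract_text(response):
        content = response['output']['choices'][0]['message']['content']
        if isinstance(content, list):
            # Multimodal responses wrap text in a list like [{"text": "..."}]
            return "".join(part.get("text", "") for part in content)
        return content
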