Update: seed agent
This commit is contained in:
@@ -14,6 +14,7 @@ import math
|
||||
import io
|
||||
import re
|
||||
from PIL import Image
|
||||
from volcenginesdkarkruntime import Ark
|
||||
|
||||
FINISH_WORD = "finished"
|
||||
WAIT_WORD = "wait"
|
||||
@@ -356,7 +357,7 @@ def modify_conversations(conversations):
|
||||
new_conversations.append(conversation)
|
||||
return new_conversations
|
||||
|
||||
class Seed16Agent:
|
||||
class SeedAgent:
|
||||
"""
|
||||
UI-TARS Agent based on Seed1.5-VL model implementation.
|
||||
Integrates the GUI folder UI-TARS-1.5 implementation with the mm_agents architecture.
|
||||
@@ -422,7 +423,7 @@ class Seed16Agent:
|
||||
self.platform = "ubuntu"
|
||||
self.use_thinking = use_thinking
|
||||
|
||||
self.inference_func = self.inference_with_thinking
|
||||
self.inference_func = self.inference_with_thinking_ark
|
||||
self.resize_image = resize_image
|
||||
self.resized_image_width = resized_image_width
|
||||
self.resized_image_height = resized_image_height
|
||||
@@ -492,7 +493,6 @@ class Seed16Agent:
|
||||
}
|
||||
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
print(response.json()["choices"][0])
|
||||
if response.status_code == 200:
|
||||
return response.json()["choices"][0]["message"]
|
||||
else:
|
||||
@@ -501,6 +501,53 @@ class Seed16Agent:
|
||||
"details": response.text
|
||||
}
|
||||
|
||||
def inference_with_thinking_ark(self, openai_messages):
|
||||
# 打印 Ark 的 URL 和 API Key
|
||||
api_key = os.environ['DOUBAO_API_KEY']
|
||||
api_url = os.environ['DOUBAO_API_URL']
|
||||
|
||||
# 初始化 Ark 实例
|
||||
vlm = Ark(
|
||||
base_url=api_url,
|
||||
api_key=api_key
|
||||
)
|
||||
|
||||
|
||||
# 调用 Ark 的 chat.completions.create 方法
|
||||
completion = vlm.chat.completions.create(
|
||||
model=self.model,
|
||||
stream=True,
|
||||
reasoning_effort='high',
|
||||
messages=openai_messages,
|
||||
max_tokens=self.max_tokens,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p
|
||||
)
|
||||
|
||||
# 初始化预测结果
|
||||
think_token = "think_never_used_51bce0c785ca2f68081bfa7d91973934"
|
||||
added_think_token = False
|
||||
|
||||
# 处理流式返回的结果
|
||||
prediction = ''
|
||||
reasoning_content = ''
|
||||
content = ''
|
||||
for chunk in completion:
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(delta, 'reasoning_content') and delta.reasoning_content:
|
||||
reasoning_content += delta.reasoning_content
|
||||
if hasattr(delta, 'content') and delta.content:
|
||||
if not added_think_token:
|
||||
prediction += f"</{think_token}>"
|
||||
added_think_token = True
|
||||
content += delta.content
|
||||
|
||||
prediction = f"<{think_token}>" + reasoning_content + f"</{think_token}>" + content
|
||||
|
||||
# 返回预测结果
|
||||
return prediction
|
||||
|
||||
def inference_without_thinking(self, messages):
|
||||
api_key = os.environ['DOUBAO_API_KEY']
|
||||
api_url = os.environ['DOUBAO_API_URL']
|
||||
@@ -599,7 +646,8 @@ class Seed16Agent:
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": history_response
|
||||
"content": history_response.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[-1],
|
||||
"reasoning_content": history_response.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0].replace("<think_never_used_51bce0c785ca2f68081bfa7d91973934>", "")
|
||||
})
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
@@ -621,17 +669,10 @@ class Seed16Agent:
|
||||
while True:
|
||||
if try_times <= 0:
|
||||
print(f"Reach max retry times to fetch response from client, as error flag.")
|
||||
return prediction, ["FAIL"]
|
||||
raise ValueError("Client error")
|
||||
try:
|
||||
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
|
||||
# json.dump(messages, open("debug_seed16.json", "w"), indent=4, ensure_ascii=False)
|
||||
response = self.inference_func(messages)
|
||||
content = response["content"]
|
||||
if "reasoning_content" in response and response["reasoning_content"]:
|
||||
reasoning_content = response["reasoning_content"]
|
||||
prediction = f"<think_never_used_51bce0c785ca2f68081bfa7d91973934>{reasoning_content}</think_never_used_51bce0c785ca2f68081bfa7d91973934>{content}"
|
||||
else:
|
||||
prediction = content
|
||||
prediction = self.inference_func(messages)
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
@@ -650,7 +691,7 @@ class Seed16Agent:
|
||||
|
||||
except Exception as e:
|
||||
print(f"Parsing action error: {prediction}, with error:\n{e}")
|
||||
return prediction, ["FAIL"]
|
||||
raise ValueError("Parsing action error")
|
||||
|
||||
thoughts = prediction.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0]
|
||||
self.thoughts.append(thoughts)
|
||||
Reference in New Issue
Block a user