Update: seed agent

This commit is contained in:
Ubuntu
2025-12-15 11:45:57 +00:00
parent 78433ecfcf
commit 41477a9c40
2 changed files with 57 additions and 16 deletions

View File

@@ -14,6 +14,7 @@ import math
import io import io
import re import re
from PIL import Image from PIL import Image
from volcenginesdkarkruntime import Ark
FINISH_WORD = "finished" FINISH_WORD = "finished"
WAIT_WORD = "wait" WAIT_WORD = "wait"
@@ -356,7 +357,7 @@ def modify_conversations(conversations):
new_conversations.append(conversation) new_conversations.append(conversation)
return new_conversations return new_conversations
class Seed16Agent: class SeedAgent:
""" """
UI-TARS Agent based on Seed1.5-VL model implementation. UI-TARS Agent based on Seed1.5-VL model implementation.
Integrates the GUI folder UI-TARS-1.5 implementation with the mm_agents architecture. Integrates the GUI folder UI-TARS-1.5 implementation with the mm_agents architecture.
@@ -422,7 +423,7 @@ class Seed16Agent:
self.platform = "ubuntu" self.platform = "ubuntu"
self.use_thinking = use_thinking self.use_thinking = use_thinking
self.inference_func = self.inference_with_thinking self.inference_func = self.inference_with_thinking_ark
self.resize_image = resize_image self.resize_image = resize_image
self.resized_image_width = resized_image_width self.resized_image_width = resized_image_width
self.resized_image_height = resized_image_height self.resized_image_height = resized_image_height
@@ -492,7 +493,6 @@ class Seed16Agent:
} }
response = requests.post(api_url, headers=headers, json=data) response = requests.post(api_url, headers=headers, json=data)
print(response.json()["choices"][0])
if response.status_code == 200: if response.status_code == 200:
return response.json()["choices"][0]["message"] return response.json()["choices"][0]["message"]
else: else:
@@ -501,6 +501,53 @@ class Seed16Agent:
"details": response.text "details": response.text
} }
def inference_with_thinking_ark(self, openai_messages):
    """Run one streaming Ark chat completion and return it as a single string.

    The reasoning text and the answer text streamed by the model are
    collected separately and returned as
    ``<TOKEN>{reasoning}</TOKEN>{content}``, where ``TOKEN`` is the
    ``think_never_used_...`` marker. The marker pair is always emitted,
    even when the model streamed no reasoning text.

    Args:
        openai_messages: Message list passed unchanged as ``messages`` to
            ``chat.completions.create``.

    Returns:
        The combined prediction string described above.

    Raises:
        KeyError: If ``DOUBAO_API_KEY`` or ``DOUBAO_API_URL`` is not set in
            the environment.
    """
    # Credentials and endpoint are read from the environment on every call.
    api_key = os.environ['DOUBAO_API_KEY']
    api_url = os.environ['DOUBAO_API_URL']

    # Build the Ark client.
    vlm = Ark(
        base_url=api_url,
        api_key=api_key
    )

    # Request a streamed completion.
    completion = vlm.chat.completions.create(
        model=self.model,
        stream=True,
        reasoning_effort='high',
        messages=openai_messages,
        max_tokens=self.max_tokens,
        temperature=self.temperature,
        top_p=self.top_p
    )

    think_token = "think_never_used_51bce0c785ca2f68081bfa7d91973934"

    # Collect the streamed pieces in lists and join once at the end instead
    # of growing strings with repeated concatenation.
    reasoning_parts = []
    content_parts = []
    for chunk in completion:
        choices = getattr(chunk, 'choices', None)
        if not choices:
            # Chunks without choices carry no text; skip them.
            continue
        delta = choices[0].delta
        reasoning_piece = getattr(delta, 'reasoning_content', None)
        if reasoning_piece:
            reasoning_parts.append(reasoning_piece)
        content_piece = getattr(delta, 'content', None)
        if content_piece:
            content_parts.append(content_piece)

    # The final string is assembled exactly once here. The previous in-loop
    # bookkeeping that appended a closing marker to the prediction was dead
    # code, because the prediction was overwritten by this expression anyway.
    return (
        f"<{think_token}>"
        + "".join(reasoning_parts)
        + f"</{think_token}>"
        + "".join(content_parts)
    )
def inference_without_thinking(self, messages): def inference_without_thinking(self, messages):
api_key = os.environ['DOUBAO_API_KEY'] api_key = os.environ['DOUBAO_API_KEY']
api_url = os.environ['DOUBAO_API_URL'] api_url = os.environ['DOUBAO_API_URL']
@@ -599,7 +646,8 @@ class Seed16Agent:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": history_response "content": history_response.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[-1],
"reasoning_content": history_response.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0].replace("<think_never_used_51bce0c785ca2f68081bfa7d91973934>", "")
}) })
messages.append({ messages.append({
"role": "tool", "role": "tool",
@@ -621,17 +669,10 @@ class Seed16Agent:
while True: while True:
if try_times <= 0: if try_times <= 0:
print(f"Reach max retry times to fetch response from client, as error flag.") print(f"Reach max retry times to fetch response from client, as error flag.")
return prediction, ["FAIL"] raise ValueError("Client error")
try: try:
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}") logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
# json.dump(messages, open("debug_seed16.json", "w"), indent=4, ensure_ascii=False) prediction = self.inference_func(messages)
response = self.inference_func(messages)
content = response["content"]
if "reasoning_content" in response and response["reasoning_content"]:
reasoning_content = response["reasoning_content"]
prediction = f"<think_never_used_51bce0c785ca2f68081bfa7d91973934>{reasoning_content}</think_never_used_51bce0c785ca2f68081bfa7d91973934>{content}"
else:
prediction = content
break break
except Exception as e: except Exception as e:
@@ -650,7 +691,7 @@ class Seed16Agent:
except Exception as e: except Exception as e:
print(f"Parsing action error: {prediction}, with error:\n{e}") print(f"Parsing action error: {prediction}, with error:\n{e}")
return prediction, ["FAIL"] raise ValueError("Parsing action error")
thoughts = prediction.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0] thoughts = prediction.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0]
self.thoughts.append(thoughts) self.thoughts.append(thoughts)

View File

@@ -12,7 +12,7 @@ from multiprocessing import Process, Manager
from multiprocessing import current_process from multiprocessing import current_process
import lib_run_single import lib_run_single
from desktop_env.desktop_env import DesktopEnv from desktop_env.desktop_env import DesktopEnv
from mm_agents.seed16 import Seed16Agent from mm_agents.seed_agent import SeedAgent
import os import os
@@ -184,7 +184,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
client_password=args.client_password client_password=args.client_password
) )
active_environments.append(env) active_environments.append(env)
agent = Seed16Agent( agent = SeedAgent(
model=args.model, model=args.model,
model_type=args.model_type, model_type=args.model_type,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,