Update: seed agent
This commit is contained in:
@@ -14,6 +14,7 @@ import math
|
|||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from volcenginesdkarkruntime import Ark
|
||||||
|
|
||||||
FINISH_WORD = "finished"
|
FINISH_WORD = "finished"
|
||||||
WAIT_WORD = "wait"
|
WAIT_WORD = "wait"
|
||||||
@@ -356,7 +357,7 @@ def modify_conversations(conversations):
|
|||||||
new_conversations.append(conversation)
|
new_conversations.append(conversation)
|
||||||
return new_conversations
|
return new_conversations
|
||||||
|
|
||||||
class Seed16Agent:
|
class SeedAgent:
|
||||||
"""
|
"""
|
||||||
UI-TARS Agent based on Seed1.5-VL model implementation.
|
UI-TARS Agent based on Seed1.5-VL model implementation.
|
||||||
Integrates the GUI folder UI-TARS-1.5 implementation with the mm_agents architecture.
|
Integrates the GUI folder UI-TARS-1.5 implementation with the mm_agents architecture.
|
||||||
@@ -422,7 +423,7 @@ class Seed16Agent:
|
|||||||
self.platform = "ubuntu"
|
self.platform = "ubuntu"
|
||||||
self.use_thinking = use_thinking
|
self.use_thinking = use_thinking
|
||||||
|
|
||||||
self.inference_func = self.inference_with_thinking
|
self.inference_func = self.inference_with_thinking_ark
|
||||||
self.resize_image = resize_image
|
self.resize_image = resize_image
|
||||||
self.resized_image_width = resized_image_width
|
self.resized_image_width = resized_image_width
|
||||||
self.resized_image_height = resized_image_height
|
self.resized_image_height = resized_image_height
|
||||||
@@ -492,7 +493,6 @@ class Seed16Agent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
response = requests.post(api_url, headers=headers, json=data)
|
||||||
print(response.json()["choices"][0])
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return response.json()["choices"][0]["message"]
|
return response.json()["choices"][0]["message"]
|
||||||
else:
|
else:
|
||||||
@@ -501,6 +501,53 @@ class Seed16Agent:
|
|||||||
"details": response.text
|
"details": response.text
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def inference_with_thinking_ark(self, openai_messages):
    """Stream a chat completion from Ark and return it as one tagged string.

    The model's reasoning text and its final answer arrive as separate
    fields on each streamed chunk. They are accumulated independently and
    returned in the form ``<TOKEN>reasoning</TOKEN>content``, where TOKEN is
    the fixed think-token below. The tags are always emitted, even when no
    reasoning text was streamed, so the result can always be split on the
    closing tag.

    Args:
        openai_messages: Chat messages passed unchanged as ``messages`` to
            ``chat.completions.create``.

    Returns:
        str: The tagged reasoning followed by the answer content.

    Raises:
        KeyError: If ``DOUBAO_API_KEY`` or ``DOUBAO_API_URL`` is not set in
            the environment.
    """
    think_token = "think_never_used_51bce0c785ca2f68081bfa7d91973934"

    # Read the Ark endpoint and credentials from the environment.
    api_key = os.environ['DOUBAO_API_KEY']
    api_url = os.environ['DOUBAO_API_URL']

    # Build the Ark client for this request.
    vlm = Ark(
        base_url=api_url,
        api_key=api_key,
    )

    # Request a streamed completion using the sampling settings on self.
    completion = vlm.chat.completions.create(
        model=self.model,
        stream=True,
        reasoning_effort='high',
        messages=openai_messages,
        max_tokens=self.max_tokens,
        temperature=self.temperature,
        top_p=self.top_p,
    )

    # Collect the streamed pieces and join once at the end instead of
    # growing strings chunk by chunk.
    reasoning_parts = []
    content_parts = []
    for chunk in completion:
        choices = getattr(chunk, 'choices', None)
        if not choices:
            # Chunks without choices carry no text; skip them.
            continue
        delta = choices[0].delta

        reasoning_piece = getattr(delta, 'reasoning_content', None)
        if reasoning_piece:
            reasoning_parts.append(reasoning_piece)

        content_piece = getattr(delta, 'content', None)
        if content_piece:
            content_parts.append(content_piece)

    reasoning_content = ''.join(reasoning_parts)
    content = ''.join(content_parts)

    # Reasoning is always wrapped in the think-token tags, followed by the answer.
    return f"<{think_token}>{reasoning_content}</{think_token}>{content}"
|
||||||
|
|
||||||
def inference_without_thinking(self, messages):
|
def inference_without_thinking(self, messages):
|
||||||
api_key = os.environ['DOUBAO_API_KEY']
|
api_key = os.environ['DOUBAO_API_KEY']
|
||||||
api_url = os.environ['DOUBAO_API_URL']
|
api_url = os.environ['DOUBAO_API_URL']
|
||||||
@@ -599,7 +646,8 @@ class Seed16Agent:
|
|||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": history_response
|
"content": history_response.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[-1],
|
||||||
|
"reasoning_content": history_response.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0].replace("<think_never_used_51bce0c785ca2f68081bfa7d91973934>", "")
|
||||||
})
|
})
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
@@ -621,17 +669,10 @@ class Seed16Agent:
|
|||||||
while True:
|
while True:
|
||||||
if try_times <= 0:
|
if try_times <= 0:
|
||||||
print(f"Reach max retry times to fetch response from client, as error flag.")
|
print(f"Reach max retry times to fetch response from client, as error flag.")
|
||||||
return prediction, ["FAIL"]
|
raise ValueError("Client error")
|
||||||
try:
|
try:
|
||||||
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
|
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
|
||||||
# json.dump(messages, open("debug_seed16.json", "w"), indent=4, ensure_ascii=False)
|
prediction = self.inference_func(messages)
|
||||||
response = self.inference_func(messages)
|
|
||||||
content = response["content"]
|
|
||||||
if "reasoning_content" in response and response["reasoning_content"]:
|
|
||||||
reasoning_content = response["reasoning_content"]
|
|
||||||
prediction = f"<think_never_used_51bce0c785ca2f68081bfa7d91973934>{reasoning_content}</think_never_used_51bce0c785ca2f68081bfa7d91973934>{content}"
|
|
||||||
else:
|
|
||||||
prediction = content
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -650,7 +691,7 @@ class Seed16Agent:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Parsing action error: {prediction}, with error:\n{e}")
|
print(f"Parsing action error: {prediction}, with error:\n{e}")
|
||||||
return prediction, ["FAIL"]
|
raise ValueError("Parsing action error")
|
||||||
|
|
||||||
thoughts = prediction.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0]
|
thoughts = prediction.split("</think_never_used_51bce0c785ca2f68081bfa7d91973934>")[0]
|
||||||
self.thoughts.append(thoughts)
|
self.thoughts.append(thoughts)
|
||||||
@@ -12,7 +12,7 @@ from multiprocessing import Process, Manager
|
|||||||
from multiprocessing import current_process
|
from multiprocessing import current_process
|
||||||
import lib_run_single
|
import lib_run_single
|
||||||
from desktop_env.desktop_env import DesktopEnv
|
from desktop_env.desktop_env import DesktopEnv
|
||||||
from mm_agents.seed16 import Seed16Agent
|
from mm_agents.seed_agent import SeedAgent
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
@@ -184,7 +184,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
|||||||
client_password=args.client_password
|
client_password=args.client_password
|
||||||
)
|
)
|
||||||
active_environments.append(env)
|
active_environments.append(env)
|
||||||
agent = Seed16Agent(
|
agent = SeedAgent(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
model_type=args.model_type,
|
model_type=args.model_type,
|
||||||
max_tokens=args.max_tokens,
|
max_tokens=args.max_tokens,
|
||||||
Reference in New Issue
Block a user