update aworldguiAgent code (#342)
This commit is contained in:
230
mm_agents/aworldguiagent/workflow.py
Normal file
230
mm_agents/aworldguiagent/workflow.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
This code is adapted from AgentS2 (https://github.com/simular-ai/Agent-S)
|
||||
with modifications to suit specific requirements.
|
||||
"""
|
||||
import logging
|
||||
import textwrap
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from aworld.config.conf import AgentConfig
|
||||
from aworld.agents.llm_agent import Agent
|
||||
from aworld.core.common import Observation
|
||||
|
||||
from aworld.core.task import Task
|
||||
from aworld.core.context.base import Context
|
||||
from aworld.core.event.base import Message
|
||||
from aworld.models.llm import get_llm_model
|
||||
from aworld.utils.common import sync_exec
|
||||
|
||||
from mm_agents.aworldguiagent.grounding import ACI
|
||||
from mm_agents.aworldguiagent.prompt import GENERATOR_SYS_PROMPT, REFLECTION_SYS_PROMPT
|
||||
from mm_agents.aworldguiagent.utils import encode_image, extract_first_agent_function, parse_single_code_from_string, sanitize_code
|
||||
from mm_agents.aworldguiagent.utils import prune_image_messages, reps_action_result
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
class Worker:
|
||||
def __init__(
|
||||
self,
|
||||
engine_params: Dict,
|
||||
grounding_agent: ACI,
|
||||
platform: str = "ubuntu",
|
||||
max_trajectory_length: int = 16,
|
||||
enable_reflection: bool = True,
|
||||
):
|
||||
"""
|
||||
Worker receives the main task and generates actions, without the need of hierarchical planning
|
||||
Args:
|
||||
engine_params: Dict
|
||||
Parameters for the multimodal engine
|
||||
grounding_agent: Agent
|
||||
The grounding agent to use
|
||||
platform: str
|
||||
OS platform the agent runs on (darwin, linux, windows)
|
||||
max_trajectory_length: int
|
||||
The amount of images turns to keep
|
||||
enable_reflection: bool
|
||||
Whether to enable reflection
|
||||
"""
|
||||
# super().__init__(engine_params, platform)
|
||||
|
||||
self.grounding_agent = grounding_agent
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.enable_reflection = enable_reflection
|
||||
self.use_thinking = engine_params.get("model", "") in [
|
||||
"claude-3-7-sonnet-20250219"
|
||||
]
|
||||
|
||||
self.generator_agent_config = AgentConfig(
|
||||
llm_provider=engine_params.get("engine_type", "openai"),
|
||||
llm_model_name=engine_params.get("model", "openai/o3",),
|
||||
llm_temperature=engine_params.get("temperature", 1.0),
|
||||
llm_base_url=engine_params.get("base_url", "https://openrouter.ai/api/v1"),
|
||||
llm_api_key=engine_params.get("api_key", ""),
|
||||
)
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
|
||||
self.generator_agent = Agent(
|
||||
name="generator_agent",
|
||||
conf=self.generator_agent_config,
|
||||
system_prompt=GENERATOR_SYS_PROMPT,
|
||||
resp_parse_func=reps_action_result
|
||||
)
|
||||
|
||||
self.reflection_agent = Agent(
|
||||
name="reflection_agent",
|
||||
conf=self.generator_agent_config,
|
||||
system_prompt=REFLECTION_SYS_PROMPT,
|
||||
resp_parse_func=reps_action_result
|
||||
)
|
||||
|
||||
self.turn_count = 0
|
||||
self.worker_history = []
|
||||
self.reflections = []
|
||||
self.cost_this_turn = 0
|
||||
self.screenshot_inputs = []
|
||||
|
||||
self.dummy_task = Task()
|
||||
self.dummy_context = Context()
|
||||
self.dummy_context.set_task(self.dummy_task)
|
||||
self.dummy_message = Message(headers={'context': self.dummy_context})
|
||||
|
||||
self.planning_model = get_llm_model(self.generator_agent_config)
|
||||
|
||||
self.first_done = False
|
||||
self.first_image = None
|
||||
|
||||
def generate_next_action(
|
||||
self,
|
||||
instruction: str,
|
||||
obs: Dict,
|
||||
) -> Tuple[Dict, List]:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
"""
|
||||
agent = self.grounding_agent
|
||||
generator_message = (
|
||||
""
|
||||
if self.turn_count > 0
|
||||
else "The initial screen is provided. No action has been taken yet."
|
||||
)
|
||||
|
||||
# Load the task into the system prompt
|
||||
if self.turn_count == 0:
|
||||
self.generator_agent.system_prompt = self.generator_agent.system_prompt.replace(
|
||||
"TASK_DESCRIPTION", instruction)
|
||||
|
||||
# Get the per-step reflection
|
||||
reflection = None
|
||||
reflection_thoughts = None
|
||||
if self.enable_reflection:
|
||||
# Load the initial message
|
||||
if self.turn_count == 0:
|
||||
text_content = textwrap.dedent(
|
||||
f"""
|
||||
Task Description: {instruction}
|
||||
Current Trajectory below:
|
||||
"""
|
||||
)
|
||||
updated_sys_prompt = (
|
||||
self.reflection_agent.system_prompt + "\n" + text_content
|
||||
)
|
||||
self.reflection_agent.system_prompt = updated_sys_prompt
|
||||
|
||||
image_content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"The initial screen is provided. No action has been taken yet."
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "data:image/png;base64," + encode_image(obs["screenshot"])
|
||||
}
|
||||
}
|
||||
]
|
||||
self.reflection_agent._init_context(context=self.dummy_context)
|
||||
|
||||
sync_exec(
|
||||
self.reflection_agent._add_human_input_to_memory,
|
||||
image_content,
|
||||
self.dummy_context,
|
||||
"message"
|
||||
)
|
||||
|
||||
# Load the latest action
|
||||
else:
|
||||
|
||||
image = "data:image/png;base64," + encode_image(obs["screenshot"])
|
||||
reflection_message = self.worker_history[-1] + "\n" + f"Here is function execute result: {obs['action_response']}.\n"
|
||||
|
||||
reflection_observation = Observation(content=reflection_message, image=image)
|
||||
|
||||
self.reflection_agent._init_context(context=self.dummy_context)
|
||||
reflection_actions = self.reflection_agent.policy(reflection_observation, message=self.dummy_message)
|
||||
|
||||
reflection = reflection_actions[0].action_name
|
||||
reflection_thoughts = reflection_actions[0].policy_info
|
||||
|
||||
self.reflections.append(reflection)
|
||||
|
||||
generator_message += f"Here is your function execute result: {obs['action_response']}.\n"
|
||||
|
||||
generator_message += f"REFLECTION: You may use this reflection on the previous action and overall trajectory:\n{reflection}\n"
|
||||
logger.info("REFLECTION: %s", reflection)
|
||||
|
||||
if self.first_done:
|
||||
pass
|
||||
|
||||
else:
|
||||
# Add finalized message to conversation
|
||||
generator_message += f"\nCurrent Text Buffer = [{','.join(agent.notes)}]\n"
|
||||
|
||||
image = "data:image/png;base64," + encode_image(obs["screenshot"])
|
||||
generator_observation = Observation(content=generator_message, image=image)
|
||||
|
||||
self.generator_agent._init_context(context=self.dummy_context)
|
||||
generator_actions = self.generator_agent.policy(generator_observation, message=self.dummy_message)
|
||||
|
||||
plan = generator_actions[0].action_name
|
||||
plan_thoughts = generator_actions[0].policy_info
|
||||
|
||||
prune_image_messages(self.generator_agent.memory.memory_store, 16)
|
||||
prune_image_messages(self.reflection_agent.memory.memory_store, 16)
|
||||
|
||||
self.worker_history.append(plan)
|
||||
|
||||
logger.info("FULL PLAN:\n %s", plan)
|
||||
|
||||
# self.generator_agent.add_message(plan, role="assistant")
|
||||
# Use the grounding agent to convert agent_action("desc") into agent_action([x, y])
|
||||
|
||||
try:
|
||||
agent.assign_coordinates(plan, obs)
|
||||
plan_code = parse_single_code_from_string(plan.split("Grounded Action")[-1])
|
||||
plan_code = sanitize_code(plan_code)
|
||||
plan_code = extract_first_agent_function(plan_code)
|
||||
exec_code = eval(plan_code)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error in parsing plan code: %s", e)
|
||||
plan_code = "agent.wait(1.0)"
|
||||
exec_code = eval(plan_code)
|
||||
|
||||
executor_info = {
|
||||
"full_plan": plan,
|
||||
"executor_plan": plan,
|
||||
"plan_thoughts": plan_thoughts,
|
||||
"plan_code": plan_code,
|
||||
"reflection": reflection,
|
||||
"reflection_thoughts": reflection_thoughts,
|
||||
}
|
||||
self.turn_count += 1
|
||||
|
||||
self.screenshot_inputs.append(obs["screenshot"])
|
||||
|
||||
return executor_info, [exec_code]
|
||||
Reference in New Issue
Block a user