Add EvoCUA Support (#401)
* evocua init * setup max_token --------- Co-authored-by: xuetaofeng <xuetaofeng@meituan.com> Co-authored-by: Tianbao Xie <47296835+Timothyxxx@users.noreply.github.com>
This commit is contained in:
@@ -462,82 +462,102 @@ def run_single_example_uipath(agent, env, example, max_steps, instruction, args,
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
from mm_agents.os_symphony.utils.common_utils import draw_coordinates
|
||||
from mm_agents.os_symphony.utils.process_context import set_current_result_dir
|
||||
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
def run_single_example_os_symphony(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
set_current_result_dir(example_result_dir)
|
||||
def run_single_example_evocua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
"""
|
||||
Unified run function for EvoCUAAgent (supporting both S1 and S2 modes).
|
||||
"""
|
||||
runtime_logger = setup_logger(example, example_result_dir)
|
||||
|
||||
agent.reset(result_dir=example_result_dir)
|
||||
# Reset Environment
|
||||
env.reset(task_config=example)
|
||||
time.sleep(30) # Wait for the environment to be ready
|
||||
|
||||
# Reset Agent
|
||||
# Handle agent reset signature differences if any
|
||||
try:
|
||||
agent.reset(runtime_logger, vm_ip=env.vm_ip)
|
||||
except Exception:
|
||||
try:
|
||||
agent.reset(runtime_logger)
|
||||
except Exception:
|
||||
agent.reset()
|
||||
|
||||
time.sleep(60) # Wait for the environment to be ready
|
||||
obs = env._get_obs() # Get the initial observation
|
||||
done = False
|
||||
step_idx = 0
|
||||
# env.controller.start_recording()
|
||||
start_time = time.time()
|
||||
|
||||
env.controller.start_recording()
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
obs,
|
||||
step_idx == max_steps - 1
|
||||
)
|
||||
# EvoCUAAgent.predict unified signature: returns (response, actions)
|
||||
# It handles both modes internally.
|
||||
predict_res = agent.predict(instruction, obs)
|
||||
|
||||
# Check return signature logic
|
||||
if len(predict_res) == 3:
|
||||
# Compatibility with S1 original signature if agent was updated to match
|
||||
response, actions, info_dict = predict_res
|
||||
else:
|
||||
response, actions = predict_res
|
||||
info_dict = {}
|
||||
|
||||
logger.info(f"Step {step_idx + 1} Actions: {actions}")
|
||||
|
||||
# Break if no actions (fail-safe)
|
||||
if not actions or (len(actions) == 1 and (actions[0] == "" or "error" in actions[0].lower())):
|
||||
# Allow "FAIL" or "DONE" to process through execution loop if agent outputs them as actions
|
||||
if not (actions and actions[0] in ["FAIL", "DONE"]):
|
||||
logger.warning("No valid actions returned. Breaking loop.")
|
||||
break
|
||||
|
||||
for action in actions:
|
||||
# Save screenshot and trajectory information
|
||||
if "reflection" in response and response["reflection"].get("is_milestone"):
|
||||
img_name = f"step_{step_idx + 1}_milestone.png"
|
||||
else:
|
||||
img_name = f"step_{step_idx + 1}.png"
|
||||
|
||||
with open(os.path.join(example_result_dir, img_name),
|
||||
"wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
if "coordinates" in response and response["coordinates"]:
|
||||
draw_coordinates(
|
||||
image_bytes=obs['screenshot'],
|
||||
coordinates=response["coordinates"],
|
||||
save_path=os.path.join(example_result_dir, img_name[:-4] + "_draw.png")
|
||||
)
|
||||
|
||||
logger.info("Step %d: %s", step_idx + 1, action)
|
||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
|
||||
logger.info("Executing action: %s", action)
|
||||
|
||||
# Execute
|
||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||
|
||||
logger.info("Reward: %.2f", reward)
|
||||
logger.info("Done: %s", done)
|
||||
|
||||
|
||||
# Save screenshot
|
||||
screenshot_file = f"step_{step_idx + 1}_{action_timestamp}.png"
|
||||
with open(os.path.join(example_result_dir, screenshot_file), "wb") as _f:
|
||||
_f.write(obs['screenshot'])
|
||||
|
||||
# Log Trajectory
|
||||
log_entry = {
|
||||
"step_num": step_idx + 1,
|
||||
"action_timestamp": action_timestamp,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": screenshot_file
|
||||
}
|
||||
# Add natural language info if available (S1 style)
|
||||
if info_dict:
|
||||
log_entry["natural_language_action"] = info_dict.get("action")
|
||||
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({
|
||||
"instruction": instruction,
|
||||
"step_num": step_idx + 1,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": img_name
|
||||
}))
|
||||
f.write(json.dumps(log_entry, ensure_ascii=False))
|
||||
f.write("\n")
|
||||
with open(os.path.join(example_result_dir, f"traj_{step_idx+1}.json"), "w", encoding="utf-8") as f:
|
||||
json.dump({
|
||||
"step_num": step_idx + 1,
|
||||
"action": action,
|
||||
"response": response,
|
||||
"done": done,
|
||||
"info": info,
|
||||
"screenshot_file": img_name
|
||||
}, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if done:
|
||||
logger.info("The episode is done.")
|
||||
time.sleep(60)
|
||||
break
|
||||
|
||||
step_idx += 1
|
||||
end_time = time.time()
|
||||
result = float(env.evaluate())
|
||||
|
||||
time.sleep(20) # Wait for environment to settle
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
|
||||
log_task_completion(example, result, example_result_dir, args)
|
||||
|
||||
with open(os.path.join(example_result_dir, "time.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{end_time-start_time:.2f}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
619
mm_agents/evocua/evocua_agent.py
Normal file
619
mm_agents/evocua/evocua_agent.py
Normal file
@@ -0,0 +1,619 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import logging
|
||||
import backoff
|
||||
import openai
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from mm_agents.evocua.utils import (
|
||||
process_image,
|
||||
encode_image,
|
||||
rewrite_pyautogui_text_inputs,
|
||||
project_coordinate_to_absolute_scale,
|
||||
log_messages
|
||||
)
|
||||
|
||||
from mm_agents.evocua.prompts import (
|
||||
S1_SYSTEM_PROMPT,
|
||||
S1_INSTRUTION_TEMPLATE,
|
||||
S1_STEP_TEMPLATE,
|
||||
S1_ACTION_HISTORY_TEMPLATE,
|
||||
S2_ACTION_DESCRIPTION,
|
||||
S2_DESCRIPTION_PROMPT_TEMPLATE,
|
||||
S2_SYSTEM_PROMPT,
|
||||
build_s2_tools_def
|
||||
)
|
||||
|
||||
logger = logging.getLogger("desktopenv.evocua")
|
||||
|
||||
class EvoCUAAgent:
|
||||
"""
|
||||
EvoCUA - A Native GUI agent model for desktop automation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "EvoCUA-S2",
|
||||
max_tokens: int = 32768,
|
||||
top_p: float = 0.9,
|
||||
temperature: float = 0.0,
|
||||
action_space: str = "pyautogui",
|
||||
observation_type: str = "screenshot",
|
||||
max_steps: int = 50,
|
||||
prompt_style: str = "S2", # "S1" or "S2"
|
||||
max_history_turns: int = 4,
|
||||
screen_size: Tuple[int, int] = (1920, 1080),
|
||||
coordinate_type: str = "relative",
|
||||
password: str = "osworld-public-evaluation",
|
||||
resize_factor: int = 32,
|
||||
**kwargs
|
||||
):
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_steps = max_steps
|
||||
|
||||
self.prompt_style = prompt_style
|
||||
assert self.prompt_style in ["S1", "S2"], f"Invalid prompt_style: {self.prompt_style}"
|
||||
|
||||
self.max_history_turns = max_history_turns
|
||||
|
||||
self.screen_size = screen_size
|
||||
self.coordinate_type = coordinate_type
|
||||
self.password = password
|
||||
self.resize_factor = resize_factor
|
||||
|
||||
# Action space assertion
|
||||
assert self.action_space == "pyautogui", f"Invalid action space: {self.action_space}"
|
||||
assert self.observation_type == "screenshot", f"Invalid observation type: {self.observation_type}"
|
||||
|
||||
# State
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.responses = []
|
||||
self.screenshots = [] # Stores encoded string
|
||||
self.cots = [] # For S1 style history
|
||||
|
||||
def reset(self, _logger=None, vm_ip=None):
|
||||
global logger
|
||||
if _logger:
|
||||
logger = _logger
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.responses = []
|
||||
self.screenshots = []
|
||||
self.cots = []
|
||||
|
||||
def predict(self, instruction: str, obs: Dict) -> List:
|
||||
"""
|
||||
Main prediction loop.
|
||||
"""
|
||||
|
||||
logger.info(f"========================== {self.model} ===================================")
|
||||
logger.info(f"Instruction: \n{instruction}")
|
||||
|
||||
screenshot_bytes = obs["screenshot"]
|
||||
|
||||
try:
|
||||
original_img = Image.open(BytesIO(screenshot_bytes))
|
||||
original_width, original_height = original_img.size
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read screenshot size, falling back to screen_size: {e}")
|
||||
original_width, original_height = self.screen_size
|
||||
|
||||
if self.prompt_style == "S1":
|
||||
raw_b64 = encode_image(screenshot_bytes)
|
||||
self.screenshots.append(raw_b64)
|
||||
return self._predict_s1(instruction, obs, raw_b64)
|
||||
else:
|
||||
processed_b64, p_width, p_height = process_image(screenshot_bytes, factor=self.resize_factor)
|
||||
self.screenshots.append(processed_b64)
|
||||
return self._predict_s2(
|
||||
instruction,
|
||||
obs,
|
||||
processed_b64,
|
||||
p_width,
|
||||
p_height,
|
||||
original_width,
|
||||
original_height,
|
||||
)
|
||||
|
||||
|
||||
def _predict_s2(self, instruction, obs, processed_b64, p_width, p_height, original_width, original_height):
|
||||
current_step = len(self.actions)
|
||||
current_history_n = self.max_history_turns
|
||||
|
||||
response = None
|
||||
|
||||
if self.coordinate_type == "absolute":
|
||||
resolution_info = f"* The screen's resolution is {p_width}x{p_height}."
|
||||
else:
|
||||
resolution_info = "* The screen's resolution is 1000x1000."
|
||||
|
||||
description_prompt = S2_DESCRIPTION_PROMPT_TEMPLATE.format(resolution_info=resolution_info)
|
||||
|
||||
tools_def = build_s2_tools_def(description_prompt)
|
||||
|
||||
system_prompt = S2_SYSTEM_PROMPT.format(tools_xml=json.dumps(tools_def))
|
||||
|
||||
# Retry loop for context length
|
||||
while True:
|
||||
messages = self._build_s2_messages(
|
||||
instruction,
|
||||
processed_b64,
|
||||
current_step,
|
||||
current_history_n,
|
||||
system_prompt
|
||||
)
|
||||
|
||||
try:
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"temperature": self.temperature,
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
# Handle Context Too Large
|
||||
if self._should_giveup_on_context_error(e) and current_history_n > 0:
|
||||
current_history_n -= 1
|
||||
logger.warning(f"Context too large, retrying with history_n={current_history_n}")
|
||||
else:
|
||||
logger.error(f"Error in predict: {e}")
|
||||
break
|
||||
|
||||
self.responses.append(response)
|
||||
|
||||
low_level_instruction, pyautogui_code = self._parse_response_s2(
|
||||
response, p_width, p_height, original_width, original_height
|
||||
)
|
||||
|
||||
# new added
|
||||
current_step = len(self.actions) + 1
|
||||
first_action = pyautogui_code[0] if pyautogui_code else ""
|
||||
if current_step >= self.max_steps and str(first_action).upper() not in ("DONE", "FAIL"):
|
||||
logger.warning(f"Reached maximum steps {self.max_steps}. Forcing termination with FAIL.")
|
||||
low_level_instruction = "Fail the task because reaching the maximum step limit."
|
||||
pyautogui_code = ["FAIL"]
|
||||
|
||||
logger.info(f"Low level instruction: {low_level_instruction}")
|
||||
logger.info(f"Pyautogui code: {pyautogui_code}")
|
||||
|
||||
self.actions.append(low_level_instruction)
|
||||
return response, pyautogui_code
|
||||
|
||||
def _build_s2_messages(self, instruction, current_img, step, history_n, system_prompt):
|
||||
messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
|
||||
|
||||
previous_actions = []
|
||||
history_start_idx = max(0, step - history_n)
|
||||
for i in range(history_start_idx):
|
||||
if i < len(self.actions):
|
||||
previous_actions.append(f"Step {i+1}: {self.actions[i]}")
|
||||
previous_actions_str = "\n".join(previous_actions) if previous_actions else "None"
|
||||
|
||||
# Add History
|
||||
history_len = min(history_n, len(self.responses))
|
||||
if history_len > 0:
|
||||
hist_responses = self.responses[-history_len:]
|
||||
hist_imgs = self.screenshots[-history_len-1:-1]
|
||||
|
||||
for i in range(history_len):
|
||||
if i < len(hist_imgs):
|
||||
screenshot_b64 = hist_imgs[i]
|
||||
if i == 0:
|
||||
# First history item: Inject Instruction + Previous Actions Context
|
||||
img_url = f"data:image/png;base64,{screenshot_b64}"
|
||||
instruction_prompt = f"""
|
||||
Please generate the next move according to the UI screenshot, instruction and previous actions.
|
||||
|
||||
Instruction: {instruction}
|
||||
|
||||
Previous actions:
|
||||
{previous_actions_str}"""
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": img_url}},
|
||||
{"type": "text", "text": instruction_prompt}
|
||||
]
|
||||
})
|
||||
else:
|
||||
img_url = f"data:image/png;base64,{screenshot_b64}"
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": img_url}},
|
||||
]
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": hist_responses[i]}]
|
||||
})
|
||||
|
||||
# Current Turn
|
||||
# We re-use previous_actions_str logic for the case where history_len == 0
|
||||
|
||||
if history_len == 0:
|
||||
# First turn logic: Include Instruction + Previous Actions
|
||||
instruction_prompt = f"""
|
||||
Please generate the next move according to the UI screenshot, instruction and previous actions.
|
||||
|
||||
Instruction: {instruction}
|
||||
|
||||
Previous actions:
|
||||
{previous_actions_str}"""
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}},
|
||||
{"type": "text", "text": instruction_prompt}
|
||||
]
|
||||
})
|
||||
else:
|
||||
# Subsequent turns logic (context already in first history message): Image Only
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{current_img}"}}
|
||||
]
|
||||
})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _parse_response_s2(
|
||||
self,
|
||||
response: str,
|
||||
processed_width: int = None,
|
||||
processed_height: int = None,
|
||||
original_width: Optional[int] = None,
|
||||
original_height: Optional[int] = None,
|
||||
) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
Parse LLM response and convert it to low level action and pyautogui code.
|
||||
"""
|
||||
# Prefer the real screenshot resolution (passed from predict), fallback to configured screen_size.
|
||||
if not (original_width and original_height):
|
||||
original_width, original_height = self.screen_size
|
||||
low_level_instruction = ""
|
||||
pyautogui_code: List[str] = []
|
||||
|
||||
if response is None or not response.strip():
|
||||
return low_level_instruction, pyautogui_code
|
||||
|
||||
def adjust_coordinates(x: float, y: float) -> Tuple[int, int]:
|
||||
if not (original_width and original_height):
|
||||
return int(x), int(y)
|
||||
if self.coordinate_type == "absolute":
|
||||
# scale from processed pixels to original
|
||||
if processed_width and processed_height:
|
||||
x_scale = original_width / processed_width
|
||||
y_scale = original_height / processed_height
|
||||
return int(x * x_scale), int(y * y_scale)
|
||||
return int(x), int(y)
|
||||
# relative: scale from 0..999 grid
|
||||
x_scale = original_width / 999
|
||||
y_scale = original_height / 999
|
||||
return int(x * x_scale), int(y * y_scale)
|
||||
|
||||
def process_tool_call(json_str: str) -> None:
|
||||
try:
|
||||
tool_call = json.loads(json_str)
|
||||
if tool_call.get("name") == "computer_use":
|
||||
args = tool_call["arguments"]
|
||||
action = args["action"]
|
||||
|
||||
if action == "left_click":
|
||||
if "coordinate" in args:
|
||||
x, y = args["coordinate"]
|
||||
adj_x, adj_y = adjust_coordinates(x, y)
|
||||
pyautogui_code.append(f"pyautogui.click({adj_x}, {adj_y})")
|
||||
else:
|
||||
pyautogui_code.append("pyautogui.click()")
|
||||
|
||||
elif action == "right_click":
|
||||
if "coordinate" in args:
|
||||
x, y = args["coordinate"]
|
||||
adj_x, adj_y = adjust_coordinates(x, y)
|
||||
pyautogui_code.append(
|
||||
f"pyautogui.rightClick({adj_x}, {adj_y})"
|
||||
)
|
||||
else:
|
||||
pyautogui_code.append("pyautogui.rightClick()")
|
||||
|
||||
elif action == "middle_click":
|
||||
if "coordinate" in args:
|
||||
x, y = args["coordinate"]
|
||||
adj_x, adj_y = adjust_coordinates(x, y)
|
||||
pyautogui_code.append(
|
||||
f"pyautogui.middleClick({adj_x}, {adj_y})"
|
||||
)
|
||||
else:
|
||||
pyautogui_code.append("pyautogui.middleClick()")
|
||||
|
||||
elif action == "double_click":
|
||||
if "coordinate" in args:
|
||||
x, y = args["coordinate"]
|
||||
adj_x, adj_y = adjust_coordinates(x, y)
|
||||
pyautogui_code.append(
|
||||
f"pyautogui.doubleClick({adj_x}, {adj_y})"
|
||||
)
|
||||
else:
|
||||
pyautogui_code.append("pyautogui.doubleClick()")
|
||||
|
||||
elif action == "type":
|
||||
text = args.get("text", "")
|
||||
|
||||
try:
|
||||
text = text.encode('latin-1', 'backslashreplace').decode('unicode_escape')
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to unescape text: {e}")
|
||||
|
||||
logger.info(f"Pyautogui code[before rewrite]: {text}")
|
||||
|
||||
result = ""
|
||||
for char in text:
|
||||
if char == '\n':
|
||||
result += "pyautogui.press('enter')\n"
|
||||
elif char == "'":
|
||||
result += 'pyautogui.press("\'")\n'
|
||||
elif char == '\\':
|
||||
result += "pyautogui.press('\\\\')\n"
|
||||
elif char == '"':
|
||||
result += "pyautogui.press('\"')\n"
|
||||
else:
|
||||
result += f"pyautogui.press('{char}')\n"
|
||||
|
||||
pyautogui_code.append(result)
|
||||
logger.info(f"Pyautogui code[after rewrite]: {pyautogui_code}")
|
||||
|
||||
|
||||
elif action == "key":
|
||||
keys = args.get("keys", [])
|
||||
if isinstance(keys, list):
|
||||
cleaned_keys = []
|
||||
for key in keys:
|
||||
if isinstance(key, str):
|
||||
if key.startswith("keys=["):
|
||||
key = key[6:]
|
||||
if key.endswith("]"):
|
||||
key = key[:-1]
|
||||
if key.startswith("['") or key.startswith('["'):
|
||||
key = key[2:] if len(key) > 2 else key
|
||||
if key.endswith("']") or key.endswith('"]'):
|
||||
key = key[:-2] if len(key) > 2 else key
|
||||
key = key.strip()
|
||||
cleaned_keys.append(key)
|
||||
else:
|
||||
cleaned_keys.append(key)
|
||||
keys = cleaned_keys
|
||||
|
||||
keys_str = ", ".join([f"'{key}'" for key in keys])
|
||||
if len(keys) > 1:
|
||||
pyautogui_code.append(f"pyautogui.hotkey({keys_str})")
|
||||
else:
|
||||
pyautogui_code.append(f"pyautogui.press({keys_str})")
|
||||
|
||||
elif action == "scroll":
|
||||
pixels = args.get("pixels", 0)
|
||||
pyautogui_code.append(f"pyautogui.scroll({pixels})")
|
||||
|
||||
elif action == "wait":
|
||||
pyautogui_code.append("WAIT")
|
||||
|
||||
elif action == "terminate":
|
||||
pyautogui_code.append("DONE")
|
||||
|
||||
elif action == "mouse_move":
|
||||
if "coordinate" in args:
|
||||
x, y = args["coordinate"]
|
||||
adj_x, adj_y = adjust_coordinates(x, y)
|
||||
pyautogui_code.append(
|
||||
f"pyautogui.moveTo({adj_x}, {adj_y})"
|
||||
)
|
||||
else:
|
||||
pyautogui_code.append("pyautogui.moveTo(0, 0)")
|
||||
|
||||
elif action == "left_click_drag":
|
||||
if "coordinate" in args:
|
||||
x, y = args["coordinate"]
|
||||
adj_x, adj_y = adjust_coordinates(x, y)
|
||||
duration = args.get("duration", 0.5)
|
||||
pyautogui_code.append(
|
||||
f"pyautogui.dragTo({adj_x}, {adj_y}, duration={duration})"
|
||||
)
|
||||
else:
|
||||
pyautogui_code.append("pyautogui.dragTo(0, 0)")
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.error(f"Failed to parse tool call: {e}")
|
||||
|
||||
lines = response.split("\n")
|
||||
inside_tool_call = False
|
||||
current_tool_call: List[str] = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.lower().startswith(("action:")):
|
||||
if not low_level_instruction:
|
||||
low_level_instruction = line.split("Action:")[-1].strip()
|
||||
continue
|
||||
|
||||
if line.startswith("<tool_call>"):
|
||||
inside_tool_call = True
|
||||
continue
|
||||
elif line.startswith("</tool_call>"):
|
||||
if current_tool_call:
|
||||
process_tool_call("\n".join(current_tool_call))
|
||||
current_tool_call = []
|
||||
inside_tool_call = False
|
||||
continue
|
||||
|
||||
if inside_tool_call:
|
||||
current_tool_call.append(line)
|
||||
continue
|
||||
|
||||
if line.startswith("{") and line.endswith("}"):
|
||||
try:
|
||||
json_obj = json.loads(line)
|
||||
if "name" in json_obj and "arguments" in json_obj:
|
||||
process_tool_call(line)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
if current_tool_call:
|
||||
process_tool_call("\n".join(current_tool_call))
|
||||
|
||||
if not low_level_instruction and len(pyautogui_code) > 0:
|
||||
action_type = pyautogui_code[0].split(".", 1)[1].split("(", 1)[0]
|
||||
low_level_instruction = f"Performing {action_type} action"
|
||||
|
||||
return low_level_instruction, pyautogui_code
|
||||
|
||||
|
||||
|
||||
def _predict_s1(self, instruction, obs, processed_b64):
|
||||
messages = [{"role": "system", "content": S1_SYSTEM_PROMPT.format(password=self.password)}]
|
||||
|
||||
# Reconstruct History Logic for S1 mode
|
||||
history_step_texts = []
|
||||
|
||||
for i in range(len(self.actions)):
|
||||
cot = self.cots[i] if i < len(self.cots) else {}
|
||||
|
||||
# Step Content string
|
||||
step_content = S1_STEP_TEMPLATE.format(step_num=i+1) + S1_ACTION_HISTORY_TEMPLATE.format(action=cot.get('action', ''))
|
||||
|
||||
if i > len(self.actions) - self.max_history_turns:
|
||||
# Recent history: Add User(Image) and Assistant(Text)
|
||||
if i < len(self.screenshots) - 1: # Screenshot exists for this step
|
||||
img = self.screenshots[i]
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img}"}}
|
||||
]
|
||||
})
|
||||
messages.append({"role": "assistant", "content": step_content})
|
||||
else:
|
||||
# Old history: Collect text
|
||||
history_step_texts.append(step_content)
|
||||
# If this is the last step before the recent window, flush collected texts
|
||||
if i == len(self.actions) - self.max_history_turns:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": "\n".join(history_step_texts)
|
||||
})
|
||||
|
||||
# Current
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{processed_b64}"}},
|
||||
{"type": "text", "text": S1_INSTRUTION_TEMPLATE.format(instruction=instruction)}
|
||||
]
|
||||
})
|
||||
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens
|
||||
})
|
||||
|
||||
low_level, codes, cot_data = self._parse_response_s1(response)
|
||||
|
||||
self.observations.append(obs)
|
||||
self.cots.append(cot_data)
|
||||
self.actions.append(low_level)
|
||||
self.responses.append(response)
|
||||
|
||||
return response, codes
|
||||
|
||||
|
||||
def _parse_response_s1(self, response):
|
||||
sections = {}
|
||||
# Simple Regex Parsing
|
||||
for key, pattern in [
|
||||
('observation', r'#{1,2}\s*Observation\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
|
||||
('thought', r'#{1,2}\s*Thought\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)'),
|
||||
('action', r'#{1,2}\s*Action\s*:?[\n\r]+(.*?)(?=^#{1,2}\s|$)')
|
||||
]:
|
||||
m = re.search(pattern, response, re.DOTALL | re.MULTILINE)
|
||||
if m: sections[key] = m.group(1).strip()
|
||||
|
||||
code_blocks = re.findall(r'```(?:code|python)?\s*(.*?)\s*```', response, re.DOTALL | re.IGNORECASE)
|
||||
code = code_blocks[-1].strip() if code_blocks else "FAIL"
|
||||
|
||||
sections['code'] = code
|
||||
|
||||
# Post-process code
|
||||
if "computer.terminate" in code:
|
||||
final_code = ["DONE"] if "success" in code.lower() else ["FAIL"]
|
||||
elif "computer.wait" in code:
|
||||
final_code = ["WAIT"]
|
||||
else:
|
||||
# Project coordinates
|
||||
code = project_coordinate_to_absolute_scale(
|
||||
code,
|
||||
self.screen_size[0],
|
||||
self.screen_size[1],
|
||||
self.coordinate_type,
|
||||
self.resize_factor
|
||||
)
|
||||
logger.info(f"[rewrite before]: {code}")
|
||||
final_code = [rewrite_pyautogui_text_inputs(code)]
|
||||
logger.info(f"[rewrite after]: {final_code}")
|
||||
|
||||
return sections.get('action', 'Acting'), final_code, sections
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _should_giveup_on_context_error(e):
|
||||
"""对于 context length 相关的错误,立即放弃重试,交给外层处理"""
|
||||
error_str = str(e)
|
||||
return "Too Large" in error_str or "context_length_exceeded" in error_str or "413" in error_str
|
||||
|
||||
@backoff.on_exception(backoff.constant, Exception, interval=30, max_tries=10, giveup=_should_giveup_on_context_error.__func__)
|
||||
def call_llm(self, payload):
|
||||
"""Unified OpenAI-compatible API call"""
|
||||
# Get env vars
|
||||
base_url = os.environ.get("OPENAI_BASE_URL", "url-xxx")
|
||||
api_key = os.environ.get("OPENAI_API_KEY", "sk-xxx")
|
||||
|
||||
client = openai.OpenAI(base_url=base_url, api_key=api_key)
|
||||
|
||||
messages = payload["messages"]
|
||||
log_messages(messages, "LLM Request")
|
||||
|
||||
params = {
|
||||
"model": payload["model"],
|
||||
"messages": messages,
|
||||
"max_tokens": payload["max_tokens"],
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p
|
||||
}
|
||||
|
||||
try:
|
||||
resp = client.chat.completions.create(**params)
|
||||
content = resp.choices[0].message.content
|
||||
logger.info(f"LLM Response:\n{content}")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.error(f"LLM Call failed: {e}")
|
||||
raise e
|
||||
145
mm_agents/evocua/prompts.py
Normal file
145
mm_agents/evocua/prompts.py
Normal file
@@ -0,0 +1,145 @@
|
||||
S1_SYSTEM_PROMPT = """You are a GUI agent. You are given a task, a screenshot of the screen and your previous interactions with the computer. You need to perform a series of actions to complete the task. The password of the computer is "{password}", use it when you need sudo rights. You need to **wait** explicitly for installation, waiting website loading or running commands to finish. Don't terminate the task unless you are sure the task is finished. If you find that you can't finish the task, or the task is not finished exactly as the instruction indicates (you have made progress but not finished the task completely), or the task is impossible to complete, you must report **failure**.
|
||||
|
||||
For each step, provide your response in this format:
|
||||
# Step: {{step number}}
|
||||
## Thought:
|
||||
{{thought}}
|
||||
## Action:
|
||||
{{action}}
|
||||
## Code:
|
||||
{{code}}
|
||||
|
||||
For the Thought section, you should include the following parts:
|
||||
- Reflection on the task when there is previous action:
|
||||
- Consider the correnctness of previous action and its outcomes
|
||||
- If the previous action was correct, describe the change in the state of the computer and reason
|
||||
- If the previous action was incorrect, reflect on what went wrong and why
|
||||
- Step by Step Progress Assessment:
|
||||
- Add necessary information according to the history screenshots, former actions and current screenshot.
|
||||
- Analyze what parts of the task have already been completed and how they contribute to the overall goal.
|
||||
- Make a plan on how to complete the task based on the history and currect screenshot.
|
||||
- Next Action Prediction:
|
||||
- Propose the most possible next action and state the reason
|
||||
- For Text Input Actions:
|
||||
- Note current cursor position
|
||||
- Consolidate repetitive actions (specify count for multiple keypresses)
|
||||
- Describe expected final text outcome
|
||||
- Use first-person perspective in reasoning
|
||||
|
||||
For the action section, you should provide clear, concise, and actionable instructions in one sentence.
|
||||
- If the action involves interacting with a specific target:
|
||||
- Describe target explicitly (if multiple elements share that name, you should distinguish the target) without using coordinates
|
||||
- Specify element names when possible (use original language if non-English)
|
||||
- Describe features (shape, color, position) if name unavailable
|
||||
- If the action involves keyboard actions like 'press', 'write', 'hotkey':
|
||||
- Consolidate repetitive keypresses with count
|
||||
- Specify expected text outcome for typing actions
|
||||
|
||||
For the code section, you should output the corresponding code for the action. The code should be either PyAutoGUI code or one of the following functions warped in the code block:
|
||||
- {{"name": "computer.wait", "description": "Make the computer wait for 20 seconds for installation, running code, etc.", "parameters": {{"type": "object", "properties": {{}}, "required": []}}}}
|
||||
- {{"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {{"type": "object", "properties": {{"status": {{"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, {{"answer": {{"type": "string", "description": "The answer of the task"}}}}, "required": ["status"]}}}}
|
||||
Examples for the code section:
|
||||
```python
|
||||
pyautogui.click(x=123, y=456)
|
||||
```
|
||||
```code
|
||||
computer.terminate(status="success")
|
||||
```
|
||||
```code
|
||||
computer.terminate(status="success", answer='''text''')
|
||||
```"""
|
||||
|
||||
|
||||
# S1 prompt templates for generating trajectories
|
||||
S1_STEP_TEMPLATE = "# Step {step_num}:\n"
|
||||
S1_INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"
|
||||
|
||||
S1_ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"
|
||||
|
||||
|
||||
# S2 Prompts
|
||||
S2_ACTION_DESCRIPTION = """
|
||||
* `key`: Performs key down presses on the arguments passed in order, then performs key releases in reverse order.
|
||||
* `type`: Type a string of text on the keyboard.
|
||||
* `mouse_move`: Move the cursor to a specified (x, y) pixel coordinate on the screen.
|
||||
* `left_click`: Click the left mouse button at a specified (x, y) pixel coordinate on the screen.
|
||||
* `left_click_drag`: Click and drag the cursor to a specified (x, y) pixel coordinate on the screen.
|
||||
* `right_click`: Click the right mouse button at a specified (x, y) pixel coordinate on the screen.
|
||||
* `middle_click`: Click the middle mouse button at a specified (x, y) pixel coordinate on the screen.
|
||||
* `double_click`: Double-click the left mouse button at a specified (x, y) pixel coordinate on the screen.
|
||||
* `triple_click`: Triple-click the left mouse button at a specified (x, y) pixel coordinate on the screen (simulated as double-click since it's the closest action).
|
||||
* `scroll`: Performs a scroll of the mouse scroll wheel.
|
||||
* `hscroll`: Performs a horizontal scroll (mapped to regular scroll).
|
||||
* `wait`: Wait specified seconds for the change to happen.
|
||||
* `terminate`: Terminate the current task and report its completion status.
|
||||
* `answer`: Answer a question.
|
||||
"""
|
||||
|
||||
S2_DESCRIPTION_PROMPT_TEMPLATE = """Use a mouse and keyboard to interact with a computer, and take screenshots.
|
||||
* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.
|
||||
* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.
|
||||
{resolution_info}
|
||||
* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.
|
||||
* If you tried clicking on a program or link but it failed to load even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.
|
||||
* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked."""
|
||||
|
||||
S2_SYSTEM_PROMPT = """# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{tools_xml}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{{"name": <function-name>, "arguments": <args-json-object>}}
|
||||
</tool_call>
|
||||
|
||||
# Response format
|
||||
|
||||
Response format for every step:
|
||||
1) Action: a short imperative describing what to do in the UI.
|
||||
2) A single <tool_call>...</tool_call> block containing only the JSON: {{"name": <function-name>, "arguments": <args-json-object>}}.
|
||||
|
||||
Rules:
|
||||
- Output exactly in the order: Action, <tool_call>.
|
||||
- Be brief: one sentence for Action.
|
||||
- Do not output anything else outside those parts.
|
||||
- If finishing, use action=terminate in the tool call."""
|
||||
|
||||
|
||||
def build_s2_tools_def(description_prompt):
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name_for_human": "computer_use",
|
||||
"name": "computer_use",
|
||||
"description": description_prompt,
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"action": {
|
||||
"description": S2_ACTION_DESCRIPTION,
|
||||
"enum": ["key", "type", "mouse_move", "left_click", "left_click_drag",
|
||||
"right_click", "middle_click", "double_click", "scroll", "wait", "terminate"],
|
||||
"type": "string"
|
||||
},
|
||||
"keys": {"description": "Required only by `action=key`.", "type": "array"},
|
||||
"text": {"description": "Required only by `action=type`.", "type": "string"},
|
||||
"coordinate": {"description": "The x,y coordinates for mouse actions.", "type": "array"},
|
||||
"pixels": {"description": "The amount of scrolling.", "type": "number"},
|
||||
"time": {"description": "The seconds to wait.", "type": "number"},
|
||||
"status": {
|
||||
"description": "The status of the task.",
|
||||
"type": "string",
|
||||
"enum": ["success", "failure"]
|
||||
}
|
||||
},
|
||||
"required": ["action"],
|
||||
"type": "object"
|
||||
},
|
||||
"args_format": "Format the arguments as a JSON object."
|
||||
}
|
||||
}
|
||||
|
||||
302
mm_agents/evocua/utils.py
Normal file
302
mm_agents/evocua/utils.py
Normal file
@@ -0,0 +1,302 @@
|
||||
import base64
|
||||
import re
|
||||
import ast
|
||||
import logging
|
||||
from io import BytesIO
|
||||
import json
|
||||
from PIL import Image
|
||||
|
||||
from mm_agents.utils.qwen_vl_utils import smart_resize
|
||||
|
||||
logger = logging.getLogger("desktopenv.evocua.utils")
|
||||
|
||||
def encode_image(image_content):
|
||||
"""Encode image bytes to base64 string."""
|
||||
return base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
|
||||
def process_image(image_bytes, factor=32):
|
||||
"""
|
||||
Process an image for VL models.
|
||||
factor: 32 for S2 mode, 28 for S1 mode default
|
||||
"""
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
|
||||
resized_height, resized_width = smart_resize(
|
||||
height=height,
|
||||
width=width,
|
||||
factor=factor,
|
||||
max_pixels=16 * 16 * 4 * 12800, # Large buffer
|
||||
)
|
||||
|
||||
image = image.resize((resized_width, resized_height))
|
||||
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
processed_bytes = buffer.getvalue()
|
||||
|
||||
return base64.b64encode(processed_bytes).decode("utf-8"), resized_width, resized_height
|
||||
|
||||
def _fallback_rewrite_pyautogui_text_inputs(code: str) -> str:
|
||||
"""
|
||||
Regex-based fallback to handle malformed pyautogui.write/typewrite calls.
|
||||
"""
|
||||
logger.info(f"SyntaxError detected in code, using regex fallback. Original code: {code}")
|
||||
|
||||
def _replacer(match):
|
||||
call_content = match.group(0)
|
||||
m = re.search(r'pyautogui\.(?:write|typewrite)\s*\(', call_content)
|
||||
if not m:
|
||||
return call_content
|
||||
|
||||
args_part = call_content[m.end():].strip()
|
||||
args_part = re.sub(r'^(?:message|text)\s*=\s*', '', args_part)
|
||||
|
||||
text_content = ""
|
||||
if args_part.startswith(("'''", '"""')):
|
||||
quote_type = args_part[:3]
|
||||
content = args_part[3:]
|
||||
end_idx = content.rfind(quote_type)
|
||||
if end_idx != -1:
|
||||
text_content = content[:end_idx]
|
||||
else:
|
||||
text_content = content[:-1] if content.endswith(')') else content
|
||||
elif args_part.startswith(("'", '"')):
|
||||
quote_type = args_part[0]
|
||||
content = args_part[1:]
|
||||
if content.endswith(quote_type + ")"):
|
||||
text_content = content[:-2]
|
||||
elif content.endswith(")"):
|
||||
if len(content) > 1 and content[-2] == quote_type:
|
||||
text_content = content[:-2]
|
||||
else:
|
||||
text_content = content[:-1]
|
||||
elif content.endswith(quote_type):
|
||||
text_content = content[:-1]
|
||||
else:
|
||||
text_content = content
|
||||
else:
|
||||
text_content = args_part[:-1] if args_part.endswith(')') else args_part
|
||||
|
||||
new_cmds = []
|
||||
for char in text_content:
|
||||
p = "enter" if char == "\n" else char
|
||||
p_esc = p.replace("'", "\\'")
|
||||
new_cmds.append(f"pyautogui.press('{p_esc}')")
|
||||
|
||||
return "; ".join(new_cmds)
|
||||
|
||||
pattern = r"pyautogui\.(?:write|typewrite)\s*\(.*?(?=\s*;|\s*$|\n)"
|
||||
new_code = re.sub(pattern, _replacer, code)
|
||||
|
||||
if new_code == code and ("pyautogui.write" in code or "pyautogui.typewrite" in code):
|
||||
new_code = re.sub(r"pyautogui\.(?:write|typewrite)\s*\(.*", _replacer, code)
|
||||
|
||||
return new_code
|
||||
|
||||
def rewrite_pyautogui_text_inputs(code: str) -> str:
|
||||
"""
|
||||
Expand pyautogui.write/typewrite string literals into per-character presses.
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(code)
|
||||
|
||||
class _TextCallRewriter(ast.NodeTransformer):
|
||||
def _extract_text(self, call: ast.Call):
|
||||
if not (
|
||||
isinstance(call.func, ast.Attribute)
|
||||
and isinstance(call.func.value, ast.Name)
|
||||
and call.func.value.id == "pyautogui"
|
||||
and call.func.attr in ("write", "typewrite")
|
||||
):
|
||||
return None
|
||||
|
||||
message_node = call.args[0] if call.args else None
|
||||
if message_node is None:
|
||||
for kw in call.keywords:
|
||||
if kw.arg in ("message", "text"):
|
||||
message_node = kw.value
|
||||
break
|
||||
|
||||
if isinstance(message_node, ast.Constant) and isinstance(message_node.value, str):
|
||||
return message_node.value
|
||||
return None
|
||||
|
||||
def visit_Expr(self, node):
|
||||
self.generic_visit(node)
|
||||
if isinstance(node.value, ast.Call):
|
||||
text = self._extract_text(node.value)
|
||||
if text is not None:
|
||||
new_nodes = []
|
||||
for char in text:
|
||||
press_value = "enter" if char == "\n" else char
|
||||
press_call = ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(id="pyautogui", ctx=ast.Load()),
|
||||
attr="press",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=[ast.Constant(value=press_value)],
|
||||
keywords=[],
|
||||
)
|
||||
)
|
||||
new_nodes.append(press_call)
|
||||
return new_nodes if new_nodes else node
|
||||
return node
|
||||
|
||||
tree = _TextCallRewriter().visit(tree)
|
||||
tree = ast.fix_missing_locations(tree)
|
||||
new_code = ast.unparse(tree)
|
||||
return new_code
|
||||
|
||||
except (SyntaxError, Exception):
|
||||
return _fallback_rewrite_pyautogui_text_inputs(code)
|
||||
|
||||
|
||||
|
||||
def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative", resize_factor=28):
|
||||
"""
|
||||
Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
|
||||
"""
|
||||
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
|
||||
if coordinate_type == "qwen25":
|
||||
height, width = smart_resize(
|
||||
height=screen_height,
|
||||
width=screen_width,
|
||||
factor=resize_factor,
|
||||
min_pixels=3136,
|
||||
max_pixels=12845056
|
||||
)
|
||||
if 0 <= x <= 1 and 0 <= y <= 1:
|
||||
# If already normalized, treat like "relative"
|
||||
return int(round(x * width)), int(round(y * height))
|
||||
return int(x / width * screen_width), int(y / height * screen_height)
|
||||
else:
|
||||
raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected 'qwen25'")
|
||||
|
||||
pattern = r'(pyautogui\.\w+\([^\)]*\))'
|
||||
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
|
||||
|
||||
new_code = pyautogui_code_relative_coordinates
|
||||
|
||||
for full_call in matches:
|
||||
func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
|
||||
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
|
||||
if not func_match:
|
||||
continue
|
||||
|
||||
func_name = func_match.group(1)
|
||||
args_str = func_match.group(2)
|
||||
|
||||
try:
|
||||
parsed = ast.parse(f"func({args_str})").body[0].value
|
||||
parsed_args = parsed.args
|
||||
parsed_keywords = parsed.keywords
|
||||
|
||||
except SyntaxError:
|
||||
continue
|
||||
|
||||
function_parameters = {
|
||||
'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
|
||||
'rightClick': ['x', 'y', 'duration', 'tween', 'pause'],
|
||||
'middleClick': ['x', 'y', 'duration', 'tween', 'pause'],
|
||||
'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
|
||||
'tripleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
|
||||
'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
|
||||
'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
|
||||
}
|
||||
|
||||
func_base_name = func_name.split('.')[-1]
|
||||
|
||||
param_names = function_parameters.get(func_base_name, [])
|
||||
|
||||
args = {}
|
||||
for idx, arg in enumerate(parsed_args):
|
||||
if idx < len(param_names):
|
||||
param_name = param_names[idx]
|
||||
try:
|
||||
arg_value = ast.literal_eval(arg)
|
||||
args[param_name] = arg_value
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
for kw in parsed_keywords:
|
||||
param_name = kw.arg
|
||||
arg_value = ast.literal_eval(kw.value)
|
||||
args[param_name] = arg_value
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing keyword arguments: {e}")
|
||||
continue
|
||||
|
||||
updated = False
|
||||
if 'x' in args and 'y' in args:
|
||||
try:
|
||||
x_rel = float(args['x'])
|
||||
y_rel = float(args['y'])
|
||||
# Only project if they look like relative coords (e.g. <= 1.0 or depending on type)
|
||||
# Projection applies unconditionally if type is relative
|
||||
x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
|
||||
|
||||
# Apply coordinate transformation
|
||||
args['x'] = x_abs
|
||||
args['y'] = y_abs
|
||||
updated = True
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if updated:
|
||||
reconstructed_args = []
|
||||
for idx, param_name in enumerate(param_names):
|
||||
if param_name in args:
|
||||
arg_value = args[param_name]
|
||||
if isinstance(arg_value, str):
|
||||
arg_repr = f"'{arg_value}'"
|
||||
else:
|
||||
arg_repr = str(arg_value)
|
||||
reconstructed_args.append(arg_repr)
|
||||
else:
|
||||
break
|
||||
|
||||
used_params = set(param_names[:len(reconstructed_args)])
|
||||
for kw in parsed_keywords:
|
||||
if kw.arg not in used_params:
|
||||
arg_value = args[kw.arg]
|
||||
if isinstance(arg_value, str):
|
||||
arg_repr = f"{kw.arg}='{arg_value}'"
|
||||
else:
|
||||
arg_repr = f"{kw.arg}={arg_value}"
|
||||
reconstructed_args.append(arg_repr)
|
||||
|
||||
new_args_str = ', '.join(reconstructed_args)
|
||||
new_full_call = f"{func_name}({new_args_str})"
|
||||
new_code = new_code.replace(full_call, new_full_call)
|
||||
|
||||
return new_code
|
||||
|
||||
|
||||
def log_messages(messages, prefix="LLM Messages"):
|
||||
"""Log messages with truncated base64 images"""
|
||||
try:
|
||||
log_msgs = []
|
||||
for msg in messages:
|
||||
msg_copy = msg.copy()
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
new_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "image_url":
|
||||
item_copy = item.copy()
|
||||
url = item_copy.get("image_url", {}).get("url", "")
|
||||
if len(url) > 100:
|
||||
item_copy["image_url"] = {"url": url[:30] + "...[base64_truncated]..." + url[-10:]}
|
||||
new_content.append(item_copy)
|
||||
else:
|
||||
new_content.append(item)
|
||||
msg_copy["content"] = new_content
|
||||
log_msgs.append(msg_copy)
|
||||
logger.info(f"{prefix}:\n{json.dumps(log_msgs, indent=2, ensure_ascii=False)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to log messages: {e}")
|
||||
554
run_multienv_evocua.py
Normal file
554
run_multienv_evocua.py
Normal file
@@ -0,0 +1,554 @@
|
||||
"""
|
||||
Script to run EvoCUA native agent model on OSWorld tasks.
|
||||
|
||||
export AWS_ACCESS_KEY_ID="xx"
|
||||
export AWS_SECRET_ACCESS_KEY="xx"
|
||||
export AWS_REGION="xx"
|
||||
export AWS_SECURITY_GROUP_ID="xx"
|
||||
export AWS_SUBNET_ID="xx"
|
||||
export OPENAI_API_KEY="xxxx"
|
||||
export OPENAI_BASE_URL="xxxx"
|
||||
|
||||
Example Usage (S2):
|
||||
python3 run_multienv_evocua.py \
|
||||
--headless \
|
||||
--provider_name aws \
|
||||
--observation_type screenshot \
|
||||
--model EvoCUA-S2 \
|
||||
--result_dir ./evocua_s2 \
|
||||
--test_all_meta_path evaluation_examples/test_nogdrive.json \
|
||||
--max_steps 50 \
|
||||
--num_envs 30 \
|
||||
--max_history_turns 4 \
|
||||
--coordinate_type relative \
|
||||
--resize_factor 32 \
|
||||
--prompt_style S2
|
||||
|
||||
|
||||
Example Usage (S1):
|
||||
python3 run_multienv_evocua.py \
|
||||
--headless \
|
||||
--provider_name aws \
|
||||
--observation_type screenshot \
|
||||
--model EvoCUA-S1 \
|
||||
--result_dir ./evocua_s1 \
|
||||
--test_all_meta_path evaluation_examples/test_nogdrive.json \
|
||||
--max_steps 50 \
|
||||
--num_envs 30 \
|
||||
--max_history_turns 3 \
|
||||
--coordinate_type qwen25 \
|
||||
--max_tokens 10240 \
|
||||
--resize_factor 28 \
|
||||
--prompt_style S1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from typing import List
|
||||
from multiprocessing import Process, Manager, Queue
|
||||
from multiprocessing import current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.evocua.evocua_agent import EvoCUAAgent
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
processes = []
|
||||
is_terminating = False
|
||||
|
||||
# load the environment variables from .env file
|
||||
if os.path.exists(".env"):
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation with EvoCUAAgent"
|
||||
)
|
||||
|
||||
# environment config
|
||||
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--headless", action="store_true", help="Run in headless machine"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--action_space", type=str, default="pyautogui", help="Action type"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observation_type",
|
||||
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
default="screenshot",
|
||||
help="Observation type",
|
||||
)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=5.0)
|
||||
parser.add_argument("--max_steps", type=int, default=50)
|
||||
|
||||
# evaluation config
|
||||
parser.add_argument(
|
||||
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||
)
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="evocua", help="Model name.")
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--max_tokens", type=int, default=32768)
|
||||
parser.add_argument("--stop_token", type=str, default=None)
|
||||
|
||||
parser.add_argument("--prompt_style", type=str, default="S2", choices=["S1", "S2"], help="Prompt style: 'S1' (structured reasoning) or 'S2' (tool calling)")
|
||||
parser.add_argument("--history_type", type=str, default="action_history", help="[S1] History type")
|
||||
parser.add_argument("--coordinate_type", type=str, default="relative", help="Coordinate type: relative, absolute, qwen25")
|
||||
parser.add_argument("--password", type=str, default="osworld-public-evaluation", help="VM Password")
|
||||
|
||||
# Unified History Parameter
|
||||
parser.add_argument("--max_history_turns", type=int, default=3, help="Number of history turns to include")
|
||||
parser.add_argument("--resize_factor", type=int, default=32, help="Image resize factor (S1: 28, S2: 32)")
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
parser.add_argument(
|
||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
|
||||
)
|
||||
|
||||
# logging related
|
||||
parser.add_argument("--result_dir", type=str, default="./results")
|
||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
default='INFO', help="Set the logging level")
|
||||
# aws config
|
||||
parser.add_argument(
|
||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client_password", type=str, default="", help="Client password"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_width", type=int, default=1920, help="Screen width"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--screen_height", type=int, default=1080, help="Screen height"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args = config()
|
||||
|
||||
logger = logging.getLogger()
|
||||
log_level = getattr(logging, args.log_level.upper())
|
||||
logger.setLevel(log_level)
|
||||
|
||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||
|
||||
file_handler = logging.FileHandler(
|
||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
debug_handler = logging.FileHandler(
|
||||
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||
)
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
file_handler.setLevel(logging.INFO)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setLevel(log_level)
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
debug_handler.setFormatter(formatter)
|
||||
stdout_handler.setFormatter(formatter)
|
||||
|
||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.addHandler(stdout_handler)
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||
all_tasks = []
|
||||
for domain, examples in test_all_meta.items():
|
||||
for example_id in examples:
|
||||
all_tasks.append((domain, example_id))
|
||||
return all_tasks
|
||||
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
REGION = args.region
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
|
||||
# Determine snapshot based on provider
|
||||
snapshot_name = "init_state"
|
||||
if args.provider_name == "aws":
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION].get((1920, 1080)))
|
||||
snapshot_name = ami_id
|
||||
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=REGION,
|
||||
snapshot_name=snapshot_name,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get(timeout=5)
|
||||
except Exception:
|
||||
break
|
||||
domain, example_id = item
|
||||
try:
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
|
||||
# Initialize EvoCUAAgent
|
||||
agent = EvoCUAAgent(
|
||||
model=args.model,
|
||||
max_tokens=args.max_tokens,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
max_steps=args.max_steps,
|
||||
prompt_style=args.prompt_style,
|
||||
max_history_turns=args.max_history_turns,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
coordinate_type=args.coordinate_type,
|
||||
password=args.password,
|
||||
resize_factor=args.resize_factor,
|
||||
)
|
||||
|
||||
try:
|
||||
lib_run_single.run_single_example_evocua(
|
||||
agent,
|
||||
env,
|
||||
example,
|
||||
args.max_steps,
|
||||
example["instruction"],
|
||||
args,
|
||||
example_result_dir,
|
||||
shared_scores,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
|
||||
f.write("\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
try:
|
||||
if env:
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
|
||||
global is_terminating, active_environments, processes
|
||||
|
||||
if is_terminating:
|
||||
return
|
||||
|
||||
is_terminating = True
|
||||
logger.info(f"Received signal {signum}. Gracefully shutting down...")
|
||||
|
||||
for env in active_environments:
|
||||
try:
|
||||
logger.info(f"Closing environment...")
|
||||
env.close()
|
||||
logger.info(f"Environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing environment: {e}")
|
||||
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Sending termination signal to process {p.name}...")
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending termination signal to process: {e}")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Forcefully terminating process {p.name}...")
|
||||
import signal as sig
|
||||
os.kill(p.pid, sig.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error forcefully terminating process: {e}")
|
||||
|
||||
logger.info("Shutdown complete. Exiting.")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
global processes
|
||||
logger.info("Args: %s", args)
|
||||
all_tasks = distribute_tasks(test_all_meta)
|
||||
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||
with Manager() as manager:
|
||||
shared_scores = manager.list()
|
||||
task_queue = manager.Queue()
|
||||
for item in all_tasks:
|
||||
task_queue.put(item)
|
||||
num_envs = args.num_envs
|
||||
processes = []
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
)
|
||||
p.daemon = True
|
||||
p.start()
|
||||
processes.append(p)
|
||||
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||
try:
|
||||
while True:
|
||||
alive_count = 0
|
||||
for idx, p in enumerate(processes):
|
||||
if not p.is_alive():
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
)
|
||||
new_p.daemon = True
|
||||
new_p.start()
|
||||
processes[idx] = new_p
|
||||
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||
else:
|
||||
alive_count += 1
|
||||
if task_queue.empty():
|
||||
logger.info("All tasks finished.")
|
||||
break
|
||||
if alive_count == 0:
|
||||
logger.error("All processes died, exiting.")
|
||||
break
|
||||
time.sleep(5)
|
||||
for p in processes:
|
||||
p.join()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||
for p in processes:
|
||||
if p.is_alive():
|
||||
try:
|
||||
logger.info(f"Terminating process {p.name} due to error...")
|
||||
p.terminate()
|
||||
except Exception as term_e:
|
||||
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||
raise
|
||||
scores = list(shared_scores)
|
||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
def get_unfinished(
|
||||
action_space, use_model, observation_type, result_dir, total_file_json
|
||||
):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
return total_file_json
|
||||
|
||||
finished = {}
|
||||
for domain in os.listdir(target_dir):
|
||||
finished[domain] = []
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
if example_id == "onboard":
|
||||
continue
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" not in os.listdir(example_path):
|
||||
for file in os.listdir(example_path):
|
||||
os.remove(os.path.join(example_path, file))
|
||||
else:
|
||||
finished[domain].append(example_id)
|
||||
|
||||
if not finished:
|
||||
return total_file_json
|
||||
|
||||
for domain, examples in finished.items():
|
||||
if domain in total_file_json:
|
||||
total_file_json[domain] = [
|
||||
x for x in total_file_json[domain] if x not in examples
|
||||
]
|
||||
|
||||
return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
if not os.path.exists(target_dir):
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
|
||||
all_result = []
|
||||
|
||||
for domain in os.listdir(target_dir):
|
||||
domain_path = os.path.join(target_dir, domain)
|
||||
if os.path.isdir(domain_path):
|
||||
for example_id in os.listdir(domain_path):
|
||||
example_path = os.path.join(domain_path, example_id)
|
||||
if os.path.isdir(example_path):
|
||||
if "result.txt" in os.listdir(example_path):
|
||||
try:
|
||||
all_result.append(
|
||||
float(
|
||||
open(
|
||||
os.path.join(example_path, "result.txt"), "r"
|
||||
).read()
|
||||
)
|
||||
)
|
||||
except:
|
||||
all_result.append(0.0)
|
||||
|
||||
if not all_result:
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
else:
|
||||
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||
return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Register signal handlers
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
try:
|
||||
args = config()
|
||||
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
if args.domain != "all":
|
||||
test_all_meta = {args.domain: test_all_meta[args.domain]}
|
||||
|
||||
test_file_list = get_unfinished(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
left_info = ""
|
||||
for domain in test_file_list:
|
||||
left_info += f"{domain}: {len(test_file_list[domain])}\n"
|
||||
logger.info(f"Left tasks:\n{left_info}")
|
||||
|
||||
get_result(
|
||||
args.action_space,
|
||||
args.model,
|
||||
args.observation_type,
|
||||
args.result_dir,
|
||||
test_all_meta,
|
||||
)
|
||||
test(args, test_file_list)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Main process received KeyboardInterrupt.")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
|
||||
signal_handler(signal.SIGTERM, None)
|
||||
finally:
|
||||
logger.info("Main process final cleanup...")
|
||||
for env in active_environments:
|
||||
if env is not None:
|
||||
try:
|
||||
logger.info("Closing environment in final cleanup...")
|
||||
env.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error during final environment cleanup: {e}")
|
||||
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
p.terminate()
|
||||
except Exception as e:
|
||||
logger.error(f"Error terminating process: {e}")
|
||||
|
||||
time.sleep(1)
|
||||
for p in processes:
|
||||
if p is not None and p.is_alive():
|
||||
try:
|
||||
os.kill(p.pid, signal.SIGKILL)
|
||||
except Exception as e:
|
||||
logger.error(f"Error force killing process: {e}")
|
||||
Reference in New Issue
Block a user