From ddb8372a6cbb51a29583cc1c0fe8c090e61219b7 Mon Sep 17 00:00:00 2001 From: Yan98 <291157173@qq.com> Date: Tue, 7 Oct 2025 01:16:31 +1100 Subject: [PATCH] init public release (#350) --- mm_agents/gta1/cua_tool.py | 302 ++++++++++++++ mm_agents/gta1/format_message.py | 73 ++++ mm_agents/gta1/gta15_agent.py | 616 +++++++++++++++++++++++++++++ mm_agents/{ => gta1}/gta1_agent.py | 9 +- run_multienv_gta1.py | 132 +++++-- 5 files changed, 1106 insertions(+), 26 deletions(-) create mode 100644 mm_agents/gta1/cua_tool.py create mode 100644 mm_agents/gta1/format_message.py create mode 100644 mm_agents/gta1/gta15_agent.py rename mm_agents/{ => gta1}/gta1_agent.py (99%) diff --git a/mm_agents/gta1/cua_tool.py b/mm_agents/gta1/cua_tool.py new file mode 100644 index 0000000..b3e45f6 --- /dev/null +++ b/mm_agents/gta1/cua_tool.py @@ -0,0 +1,302 @@ +tools = [ + { + "type": "function", + "function": { + "name": "click", + "description": "Click on the element", + "parameters": { + "type": "object", + "properties": { + "instruction": { + "type": "string", + "description": "Decribe the element you want to interact with in detail including the visual description and function description. And make it clear and concise. For example you can describe what the element looks like, and what will be the expected result when you interact with it." + }, + "num_clicks": { + "type": "integer", + "description": "Number of times to click the element.", + "default": 1 + }, + "button_type": { + "type": "string", + "enum": ["left", "middle", "right"], + "description": "Which mouse button to press.", + "default": "left" + }, + "hold_keys": { + "type": "array", + "items": {"type": "string"}, + "description": "List of keys to hold while clicking", + "default": [] + } + }, + "required": ["instruction"] + } + } + }, + { + "type": "function", + "function": { + "name": "drag_and_drop", + "description": "Drag from the starting description to the ending description", + "parameters": { + "type": "object", + "properties": { + "starting_description": { + "type": "string", + "description": "A very detailed description of where to start the drag action. This description should be at least a full sentence. And make it clear and concise." + }, + "ending_description": { + "type": "string", + "description": "A very detailed description of where to end the drag action. This description should be at least a full sentence. And make it clear and concise." + }, + "hold_keys": { + "type": "array", + "items": {"type": "string"}, + "description": "List of keys to hold while dragging", + "default": [] + } + }, + "required": ["starting_description", "ending_description"] + } + } + }, + { + "type": "function", + "function": { + "name": "highlight_text_span", + "description": "Highlight a text span between a provided starting phrase and ending phrase. Use this to highlight words, lines, and paragraphs.", + "parameters": { + "type": "object", + "properties": { + "starting_phrase": { + "type": "string", + "description": "The phrase that denotes the start of the text span you want to highlight. If you only want to highlight one word, just pass in that single word." + }, + "ending_phrase": { + "type": "string", + "description": "The phrase that denotes the end of the text span you want to highlight. If you only want to highlight one word, just pass in that single word." 
+ } + }, + "required": ["starting_phrase", "ending_phrase"] + } + } + }, + { + "type": "function", + "function": { + "name": "hold_and_press", + "description": "Hold a list of keys and press a list of keys", + "parameters": { + "type": "object", + "properties": { + "hold_keys": { + "type": "array", + "items": {"type": "string"}, + "description": "List of keys to hold" + }, + "press_keys": { + "type": "array", + "items": {"type": "string"}, + "description": "List of keys to press in a sequence" + } + }, + "required": ["hold_keys", "press_keys"] + } + } + }, + { + "type": "function", + "function": { + "name": "hotkey", + "description": "Press a hotkey combination", + "parameters": { + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": {"type": "string"}, + "description": "List the keys to press in combination in a list format (e.g. ['ctrl', 'c'])" + } + }, + "required": ["keys"] + } + } + }, + { + "type": "function", + "function": { + "name": "open", + "description": "Open any application or file with name app_or_filename. Use this action to open applications or files on the desktop, do not open manually.", + "parameters": { + "type": "object", + "properties": { + "app_or_filename": { + "type": "string", + "description": "The name of the application or filename to open" + } + }, + "required": ["app_or_filename"] + } + } + }, + { + "type": "function", + "function": { + "name": "scroll", + "description": "Scroll the element in the specified direction", + "parameters": { + "type": "object", + "properties": { + "instruction": { + "type": "string", + "description": "A very detailed description of which element to enter scroll in. This description should be at least a full sentence. And make it clear and concise." + }, + "clicks": { + "type": "integer", + "description": "The number of clicks to scroll can be positive (up) or negative (down)." + }, + "shift": { + "type": "boolean", + "description": "Whether to use shift+scroll for horizontal scrolling", + "default": False + } + }, + "required": ["instruction", "clicks"] + } + } + }, + { + "type": "function", + "function": { + "name": "set_cell_values", + "description": """Use this to set individual cell values or formulas in a spreadsheet. For setting values: pass {"A2": "hello", "B2": "world"} to set text, or {"A1": 42, "B1": 3.14} for numbers. For setting formulas: start with '=' like {"A2": "=B2+C2", "C1": "=SUM(A1:A10)"}. The sheet must be opened before this command can be used.""", + "parameters": { + "type": "object", + "properties": { + "cell_values": { + "type": "object", + "description": """A dictionary of cell values or formulas to set in the spreadsheet. Keys are cell coordinates like "A1", "B2", etc. Examples: For values: {"A2": "hello", "B1": 42}. For formulas: {"A2": "=B2+C2", "C1": "=SUM(A1:A10)"}. Always start formulas with '='.""", + "additionalProperties": { + "type": ["number", "string"] + }, + "default": {} + }, + "app_name": { + "type": "string", + "description": "Spreadsheet application/file name (e.g., 'Some_sheet.xlsx')." + }, + "sheet_name": { + "type": "string", + "description": "Sheet name (e.g., 'Sheet1')." 
+ } + }, + "required": ["cell_values", "app_name", "sheet_name"] + } + } + }, + { + "type": "function", + "function": { + "name": "switch_applications", + "description": "Switch to a different application that is already open", + "parameters": { + "type": "object", + "properties": { + "app_code": { + "type": "string", + "description": "The code/name of the application to switch to from the open apps list." + } + }, + "required": ["app_code"] + } + } + }, + { + "type": "function", + "function": { + "name": "type", + "description": "Type text into a specific element", + "parameters": { + "type": "object", + "properties": { + "element_description": { + "type": ["string", "null"], + "description": "Detailed, full-sentence description of the element to type into. If omitted, types into the focused element.", + "default": None + }, + "text": { + "type": "string", + "description": "The text to type.", + "default": "" + }, + "overwrite": { + "type": "boolean", + "description": "If true, clear existing text before typing.", + "default": False + }, + "enter": { + "type": "boolean", + "description": "If true, press Enter after typing.", + "default": False + } + }, + "required": ["text"] + } + } + }, + { + "type": "function", + "function": { + "name": "wait", + "description": "Wait for a specified amount of time", + "parameters": { + "type": "object", + "properties": { + "time": { + "type": "number", + "description": "Time to wait in seconds." + } + }, + "required": ["time"] + } + } + }, + { + "type": "function", + "function": { + "name": "fast_open_terminal", + "description": "Save the file in focus, close it, and open a terminal.", + "parameters": { + "type": "object", + "properties": {}, + "required": [] + } + } + } +] + +def to_response_api_tools(completion_tools): + """ + Convert completion-style tools (nested under the 'function' key) + into response-style tools (flattened: type/name/description/parameters). 
+ + Example: + {"type": "function", "function": {"name": "click", ...}} + -> + {"type": "function", "name": "click", ...} + """ + response_tools = [] + for tool in completion_tools or []: + if isinstance(tool, dict) and tool.get("type") == "function" and isinstance(tool.get("function"), dict): + fn = tool["function"] + response_tools.append({ + "type": "function", + "name": fn.get("name"), + "description": fn.get("description"), + "parameters": fn.get("parameters"), + }) + else: + response_tools.append(tool) + return response_tools + +response_api_tools = to_response_api_tools(tools) \ No newline at end of file diff --git a/mm_agents/gta1/format_message.py b/mm_agents/gta1/format_message.py new file mode 100644 index 0000000..fd1f1b1 --- /dev/null +++ b/mm_agents/gta1/format_message.py @@ -0,0 +1,73 @@ + +import base64 +import os +from typing import Dict, Any, List, Union +import numpy as np +import cv2 + +class FormatMessage: + def __init__(self): + self.text_key = "input_text" + self.image_key = "input_image" + + def encode_image(self, image_content: bytes) -> str: + return base64.b64encode(image_content).decode('utf-8') + + def format_image(self, image: bytes, detail: str="high") -> Dict[str, Any]: + return { + "type": self.image_key, + "image_url": f"data:image/png;base64,{self.encode_image(image)}", + "detail": detail + } + + def format_text_message(self, text: str) -> Dict[str, Any]: + return {"type": self.text_key, "text": text} + + def create_system_message(self, content: str) -> Dict[str, Any]: + return { + "role": "system", + "content": [self.format_text_message(content)] + } + + def create_user_message(self, text: str=None, image: bytes=None, detail: str="high", image_first: bool=False) -> Dict[str, Any]: + if text is None and image is None: + raise ValueError("At least one of text or image must be provided") + + content = [] + + # Add text if provided + if text is not None: + content.append(self.format_text_message(text)) + + # Add image if provided + if image is not None: + content.append(self.format_image(image, detail)) + + if image_first: + content.reverse() + return { + "role": "user", + "content": content + } + + def create_assistant_message(self, text: str) -> Dict[str, Any]: + return { + "role": "assistant", + "content": [{"type": "output_text", "text": text}] + } + + +def encode_numpy_image_to_base64(image: np.ndarray) -> str: + # Convert numpy array to bytes + success, buffer = cv2.imencode('.png', image) + if not success: + raise ValueError("Failed to encode image to png format") + + # Convert bytes to base64 string + image_bytes = buffer.tobytes() + base64_string = base64.b64encode(image_bytes).decode('utf-8') + + return base64_string + +def encode_image_bytes(image_content): + return base64.b64encode(image_content).decode('utf-8') \ No newline at end of file diff --git a/mm_agents/gta1/gta15_agent.py b/mm_agents/gta1/gta15_agent.py new file mode 100644 index 0000000..f920ec5 --- /dev/null +++ b/mm_agents/gta1/gta15_agent.py @@ -0,0 +1,616 @@ +import json +import logging +import os +import time +from typing import Any, Dict, List, Tuple, Callable +from desktop_env.desktop_env import DesktopEnv +from openai import OpenAI +from mm_agents.gta1.format_message import FormatMessage +from mm_agents.gta1.cua_tool import response_api_tools as CUA_TOOLS +import inspect +import concurrent.futures +import re +from mm_agents.utils.qwen_vl_utils import smart_resize +from mm_agents.gta1.gta1_agent import OSWorldACI +import httpx +import numpy as np +from PIL import Image +from io 
import BytesIO +from mm_agents.gta1.format_message import encode_numpy_image_to_base64, encode_image_bytes + + +GTA1_SERVICE_URL=os.getenv("GTA1_SERVICE_URL",None) + +GTA1_GROUNDING_SYSTEM_PROMPT=( + "You are a GUI agent. You are given a task and a screenshot of the screen. " + "You need to perform a series of pyautogui actions to complete the task." +) + +CUA_SYSTEM_PROMPT_GPT5 = """# Role and Objective +- An agent with strong computer knowledge and a good internet connection, designed to execute desktop computer tasks on Ubuntu precisely as instructed by the user. +- Assumes tool calls will run to control the computer. +- Has access to all its reasoning and knowledge for use in tasks. + +# Instructions +- Begin each user task with a concise checklist (3–7 items) of conceptual, non-implementation sub-tasks. +- Revise the sub-tasks checklist as the task progresses, based on the latest screenshot and previous actions. +- Interact solely using the provided tool actions; do not invent or assume any unlisted methods. Use only tools explicitly listed in the available actions for every step. +- Base every action on observable elements in the latest screenshot; never anticipate or assume elements not yet present or visible. +- For each step, you will receive a new screenshot, tool execution results, and the remaining number of steps allowed in the user task. +- If an option or input is not specified in the user task (e.g., creating a new file without specifying a name), use the default settings. + +## Action Execution Guidelines +- Execute exactly one tool call per interaction. +- Prefer the `hotkey` action (tool call) over `click` or `drag_and_drop` where possible. +- For spreadsheet value or formula changes in LibreOffice Calc, Writer, Impress, always use `set_cell_values` for both single-cell and multi-cell value or formula editing. +- When highlighting text, use only the `highlight_text_span` or `hotkey` (tool calls). +- Dismiss "Authentication required" prompts by clicking "Cancel". +- All tool calls are permitted within the provided action list; do not attempt actions outside this set. + +# Additional Information +- Leave windows/applications open at task completion. +- Upon fully completing the user's task, briefly summarize results if applicable, then return `TERMINATE`. +- **Feasibility First**: Confirm the task can be completed with available files, applications, and environments before starting. +- **Strict Adherence**: Only perform actions the user has explicitly requested; avoid unnecessary steps. +- **Completion Criteria**: Only return "TERMINATE" when all user requirements are met in full. +- **Impossibility Handling**: Return "INFEASIBLE" if completion is blocked by environmental constraints. +- **Screenshot Verification**: Always check the screenshot before proceeding. + +# Additional Rules +- The sudo password is "{CLIENT_PASSWORD}"; use it if sudo privileges are required. +- Leave all windows and applications open after completing the task. +- Only use `TERMINATE` when all user requirements have been fully satisfied; provide a brief summary of results if applicable. +- Before proceeding, confirm that the task is feasible with the currently available files, applications, and environment; if it is impossible to complete due to environmental constraints, return `INFEASIBLE`. +- Strictly follow user instructions, avoiding unnecessary or extraneous steps. +- Always review the latest screenshot before every action. 
+ +# Execution Procedure +- Briefly review prior actions, the current checklist, and the latest screenshot before each tool call. +- Before each action, state in one line the purpose and required minimal inputs. +- After each action, validate the result in 1–2 lines using the updated screenshot. If the action was unsuccessful, adapt your approach before proceeding. +- Only return the selected action(s); do not elaborate or output other information. +- Work deliberately and avoid unnecessary or extraneous steps; strictly adhere to user instructions. + +Proceed methodically and efficiently, ensuring all user requirements are met before terminating.""" + +CUA_START_MESSAGE = """ +Please check the screenshot and see if the task is impossible to complete due to environmental constraints. If it is, reply with 'INFEASIBLE'. +If it is possible to complete, please complete the task, and before making any tool call, you should reasoning the next move according to the UI screenshot and instruction, while refer to the previous actions (tool calls), screenshots, and observations for reflection. + +User task: +{instruction} + +""".strip() + + +CUA_DEFAULT_REPLY = """Note the user task is: + +{instruction} + +If you have completed the user task, reply with 'TERMINATE'. +If the task is impossible to complete due to environmental constraints, reply with 'INFEASIBLE'.""" + + +GTA1_JUDGE_SYSTEM_PROMPT='''# Role and Objective +Assess the planning and reasoning of a UI agent to determine the most effective action for advancing toward a specified task goal. You may use the computer password '{CLIENT_PASSWORD}' during this process if needed. + +# Workflow Checklist +Begin each assessment by generating a concise checklist (adapt as appropriate for task complexity) of evaluation steps to ensure a systematic and methodical analysis. +# Inputs +For each assessment, you will receive: +- The task goal +- The history of planning and actions performed +- A current UI screenshot +- A list of {N_PLANNING} alternative planning approaches for achieving the goal, in the current context. Each approach will be formatted as: + - Thought: + - Action: + +# Action Function Definition +Actions are formatted as function calls. The specification for these calls is provided here: +{FUNCTION_CALL_DEFINITION} + +# Assessment Criteria +- Correctness: Does the proposed action logically advance the goal? +- Effectiveness: Is immediate progress made? +- Alignment: Does it support both the step and overall objective? +- Planning Quality: Reasoning is clear, concise, and logical. +- Appropriateness: Action is valid/executable in the current context. +- Matchness: Does the action correspond exactly to names/nouns in the user task? Avoid generalization or conflation. +- Exactness: Does the action relate to the user task? No extra or unnecessary steps are performed. +- Completeness: If terminate, does the action complete the user task? + +Be aware that some planning approaches may be similar—evaluate each on its own merits, and do not allow the frequency of similar approaches to bias your assessment. +Carefully assess each approach and select the best one based on the above criteria. + +# Output Format +Produce a single, strictly valid JSON object with the following fields: +- `explaining` (string, required): A concise (1–4 sentences) justification for why the chosen approach is optimal in light of the assessment criteria; or, if none are effective, briefly explain why. 
+- `index` (integer, required): The 0-based index (0, 1, ..., {N_INDEX}) identifying the best approach. You must choose one of the approaches. +Do not output anything except the required JSON object. + +**Carefully evaluate each approach and select the best one based on the criteria.**''' + +def make_single_request(client: OpenAI, logger: logging.Logger, *args, **kwargs): + for retry in range(10): + try: + response = client.responses.create( + *args, + **kwargs + ) + response.output_text + return response + except Exception as e: + if os.getenv("VERBOSEDEBUG", None) is not None: + print(f"Error in response.create: {e}") + time.sleep(min(retry**2, 16)) + return None + +def extract_answer_from_response(response): + if not response or not isinstance(response, str): + raise ValueError("Response must be a non-empty string") + json_pattern = r'```json\s*(.*?)\s*```' + json_match = re.search(json_pattern, response, re.DOTALL) + + if json_match: + json_str = json_match.group(1) + try: + answer = json.loads(json_str) + if "explaining" in answer and "index" in answer: + answer["index"] = int(answer["index"]) + return answer + else: + raise ValueError("JSON missing required fields 'explaining' or 'index'") + + except json.JSONDecodeError: + pass + + direct_json_pattern = r'\{[\s\S]*?"explaining"[\s\S]*?"index"[\s\S]*?\}' + direct_match = re.search(direct_json_pattern, response) + + if direct_match: + try: + json_str = direct_match.group(0) + json_str = json_str.replace(''', "'").replace(''', "'").replace('"', '"').replace('"', '"') + answer = json.loads(json_str) + answer["index"] = int(answer["index"]) + return answer + except json.JSONDecodeError: + pass + index_pattern = r'"index"\s*:\s*(\d+)' + index_match = re.search(index_pattern, response) + + explaining_pattern = r'"explaining"\s*:\s*"(.*?)"(?=,|\s*})' + explaining_match = re.search(explaining_pattern, response, re.DOTALL) + + if not explaining_match: + explaining_pattern = r'"explaining"\s*:\s*(.*?)(?=,\s*"index"|\s*})' + explaining_match = re.search(explaining_pattern, response, re.DOTALL) + + if index_match and explaining_match: + return { + "index": int(index_match.group(1)), + "explaining": explaining_match.group(1).strip('" \t\n') + } + if index_match: + return { + "index": int(index_match.group(1)), + "explaining": "Explanation not found in response" + } + raise ValueError("Could not extract valid answer from response") + +def select_response(summary_info, responses, client_password): + summary_info, curr_obs, instruction = summary_info + + MAX_RETRY_TIMES = 10 + + system_promt = GTA1_JUDGE_SYSTEM_PROMPT.format(N_PLANNING=len(responses), N_INDEX=len(responses)-1, CLIENT_PASSWORD=client_password, FUNCTION_CALL_DEFINITION=json.dumps(CUA_TOOLS,indent=2)) + + message_formater = FormatMessage() + messages = [ + message_formater.create_system_message(system_promt), + message_formater.create_user_message(text=f"The goal of the task is:\n{instruction}\n\n\n"), + + ] + + if len(summary_info) == 0: + messages.append(message_formater.create_user_message(text=f"No history available. 
The action just started.\n")) + else: + for idx, (curr_obs, action_call, content_text) in enumerate(summary_info): + name = action_call['name'] + args = action_call['arguments'] + action = f"{name}({args})" + if os.getenv("JUDGE_SCREENSHOT_PROMPT", None) is not None and idx >= len(summary_info) - 5: + messages.append(message_formater.create_user_message(text=f"\n### {idx} Screenshot before taking the action:\n")) + messages.append(message_formater.create_user_message(image=curr_obs['screenshot'])) + messages.append(message_formater.create_user_message(text=f"\n")) + messages.append(message_formater.create_user_message(text=f"### Past step {idx}:\nThought:{content_text}\nAction:{action_call}\n\n\n")) + messages.append(message_formater.create_user_message(text=f"Here are the different plans to compare:\n")) + for idx, plan in enumerate(responses): + messages.append(message_formater.create_user_message(text=f"### Index {idx}:\n{plan}\n\n\n")) + + messages.append(message_formater.create_user_message(text=f"Here are the current screenshot:\n")) + messages.append(message_formater.create_user_message(image=curr_obs['screenshot'])) + messages.append(message_formater.create_user_message(text=f"Here are the different plans to compare for completing the task:\n")) + for idx, rsp in enumerate(responses): + content_text = rsp.output_text + action = "No Action is performed." + for i, o in enumerate(rsp.output): + typ = o["type"] if isinstance(o, dict) else getattr(o, "type", None) + if typ == 'function_call': + name = o.name + args = json.loads(o.arguments) + action = f"{name}({args})" + break + messages.append(message_formater.create_user_message(text=f"### Index {idx}:\nThought:{content_text}\nAction:{action}\n\n\n")) + + messages.append(message_formater.create_user_message(text=f"Please select the best plan to complete the task.")) + + if os.getenv("X_API_KEY") and os.getenv("X_API_URL"): + client = OpenAI(base_url=os.getenv("X_API_URL"), api_key="dummy", default_headers = {"X-Api-Key": os.getenv("X_API_KEY")}) + else: + client = OpenAI() + wait = 1 + for _ in range(MAX_RETRY_TIMES): + try: + prediction = client.responses.create( + model="gpt-5", + input=messages, + reasoning={"effort": "high"}, + max_output_tokens=4096 * 4, + timeout=100, + ) + prediction = prediction.output_text + if os.getenv("VERBOSEDEBUG", None) is not None: + print(f"Prediction: {prediction}") + prediction = extract_answer_from_response(prediction) + return responses[prediction['index']] + except: + time.sleep(wait) + wait *=2 + wait = min(wait,16) + continue + return responses[0] + +def call_openai_cua(client: OpenAI, + history_inputs: list, + cua_model: str, + logger: logging.Logger = None, + tts_step: int = 1, + summary_info: List[Any] = None, + client_password: str = "", + ) -> Tuple[Any, float]: + retry = 0 + response = None + if tts_step == 1: + response = make_single_request(client, logger, + model=cua_model, + tools=CUA_TOOLS, + parallel_tool_calls=False, + reasoning={"effort": "high"}, + max_output_tokens=4096 * 4, + input=history_inputs, + timeout=500) + else: + potential_responses = [] + retry = 0 + while len(potential_responses) < tts_step and retry < 5: + retry += 1 + with concurrent.futures.ThreadPoolExecutor(max_workers=tts_step-len(potential_responses)) as executor: + futures = [executor.submit(make_single_request, client, logger, + model=cua_model, + tools=CUA_TOOLS, + parallel_tool_calls=False, + reasoning={"effort": "high"}, + max_output_tokens=4096 * 4, + input=history_inputs, + timeout=500) for _ in 
range(tts_step-len(potential_responses))] + responses = [future.result() for future in concurrent.futures.as_completed(futures)] + responses = [response for response in responses if response is not None] + potential_responses.extend(responses) + responses = potential_responses + if os.getenv("VERBOSEDEBUG", None) is not None: + print(f"Responses: {responses}") + response = select_response(summary_info,responses,client_password) + return response + +def _tool_call_to_pyautogui(agent: OSWorldACI, + action_call: Dict[str, Any], + obs: Dict[str, Any], + request_vllm: Callable, + logger: logging.Logger = None) -> Tuple[str, str]: + tool_output = "Action (tool call) is executed. For your reference, you have maximum of {max_steps} steps, and current step is {step_no} out of {max_steps}." + method = None + try: + name = action_call['name'] + args = action_call['arguments'] + # Default: no coordinates needed + agent.coords1, agent.coords2 = None, None + + # Compute coordinates for description-based actions + if name == "click" and isinstance(args.get("instruction"), str): + agent.coords1 = agent.generate_coords(args["instruction"], obs, request_vllm) + elif name == "type": + element_description = args.get("element_description") + if isinstance(element_description, str) and element_description: + agent.coords1 = agent.generate_coords(element_description, obs, request_vllm) + elif name == "scroll" and isinstance(args.get("instruction"), str): + agent.coords1 = agent.generate_coords(args["instruction"], obs, request_vllm) + elif name == "drag_and_drop": + sd = args.get("starting_description") + ed = args.get("ending_description") + if isinstance(sd, str) and isinstance(ed, str): + agent.coords1 = agent.generate_coords(sd, obs, request_vllm) + agent.coords2 = agent.generate_coords(ed, obs, request_vllm) + elif name == "highlight_text_span": + sp = args.get("starting_phrase") + ep = args.get("ending_phrase") + if isinstance(sp, str) and isinstance(ep, str): + agent.coords1 = agent.generate_text_coords(sp, obs, alignment="start") + agent.coords2 = agent.generate_text_coords(ep, obs, alignment="end") + + # Dispatch to OSWorldACI method to build pyautogui command + if hasattr(agent, name): + method = getattr(agent, name) + # Some arguments may be missing; rely on method defaults + return method(**args),tool_output + except Exception as e: + if os.getenv("VERBOSEDEBUG", None) is not None: + print(f"Error in _tool_call_to_pyautogui: {e}") + tool_output = "Error: " + str(e).replace("OSWorldACI.","").strip() + if method is not None: + sig = inspect.signature(method) + tool_output += f"\nThe tool signature is: {method.__name__}{sig}" + + return "WAIT", tool_output + +def request_vllm(image, prompt): + CLICK_REGEXES = [ + # pyautogui.click(x=123, y=456) + re.compile(r"click\s*\(\s*x\s*=\s*(\d+)\s*,\s*y\s*=\s*(\d+)\s*\)", re.IGNORECASE), + # pyautogui.click(123, 456) or click(123,456) + re.compile(r"click\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", re.IGNORECASE), + ] + + def parse_xy_from_text(text: str): + if "click" not in text.lower(): + return [-1, -1] + for rx in CLICK_REGEXES: + m = rx.search(text) + if m: + try: + return int(m.group(1)), int(m.group(2)) + except Exception: + continue + return None + + if isinstance(image, bytes): + image = np.array(Image.open(BytesIO(image)).convert('RGB')) + H, W, C = image.shape + H, W = smart_resize( + H, + W, + factor=28, + min_pixels=1000, + max_pixels=1000000000000, + ) + assert C == 3 + if isinstance(image, np.ndarray): + image_base64 = encode_numpy_image_to_base64(image) + 
elif isinstance(image, bytes): + image_base64 = encode_image_bytes(image) + else: + raise ValueError(f"Invalid image type: {type(image)}") + messages=[ + {"role": "system", "content": GTA1_GROUNDING_SYSTEM_PROMPT}, + { + "role": "user", + "content": [ + { + "type": "image","image": f"data:image/png;base64,{image_base64}" + }, + { + "type": "text", + "text": prompt + }, + ], + }] + base_url = GTA1_SERVICE_URL + payload = { + "messages": messages, + "max_new_tokens": 100, + "temperature": 0.0, + "top_p": 0.9, + } + for _ in range(10): + try: + httpx_client = httpx.Client() + r = httpx_client.post(f"{base_url}/call_llm", json=payload, timeout=10) + r.raise_for_status() + resp = r.json() + if isinstance(resp, dict): + result_items = [resp] + else: + result_items = resp + first = result_items[0] + x,y = parse_xy_from_text(first.get("response")) + x = x/W + y = y/H + return x,y + except: + if os.getenv("VERBOSEDEBUG", None) is not None: + print(resp) + time.sleep(1) + continue + raise RuntimeError(f"Failed to execute grounding") + + + +def _prune_history_images(messages: List[Dict[str, Any]], max_recent_images: int) -> None: + """Keep only the very first image message and the latest N image messages. + + - Preserves the earliest image-containing message (initial screenshot) + - Preserves up to `max_recent_images` most recent image messages + - Removes any other image messages + """ + try: + if max_recent_images is None: + return + if max_recent_images < 0: + return + + image_indices: List[int] = [] + for idx, msg in enumerate(messages): + if isinstance(msg, dict) and isinstance(msg.get('content'), list): + for blk in msg['content']: + if isinstance(blk, dict) and blk.get('type') in ('image_url', 'input_image'): + image_indices.append(idx) + break + + if len(image_indices) <= 1: + return # Zero or one image message — nothing to prune + + first_image_idx = image_indices[0] + recent_keep: List[int] = image_indices[-max_recent_images:] if max_recent_images > 0 else [] + keep_set = set([first_image_idx] + recent_keep) + delete_indices = [i for i in image_indices if i not in keep_set] + + # Remove from end to avoid reindexing issues + if os.getenv("VERBOSEDEBUG", None) is not None: + print(f"Pruning history images: {delete_indices}") + for i in sorted(delete_indices, reverse=True): + messages.pop(i) + except Exception: + # Be conservative: never fail the main loop due to pruning + pass + +def run_cua_gpt5gta1( + env: DesktopEnv, + instruction: str, + max_steps: int, + save_path: str = './', + sleep_after_execution: float = 0.3, + client_password: str = "", + cua_model: str = "gpt-5", + tts_step: int = 8, + purge_history_images: int = 8, + request_vllm: Callable = request_vllm, + logger: logging.Logger = None, + **kwargs: Any, +): + if os.getenv("X_API_KEY"): + client = OpenAI(base_url=os.getenv("X_API_URL"), api_key="dummy", default_headers = {"X-Api-Key": os.getenv("X_API_KEY")}) + else: + client = OpenAI() + agent = OSWorldACI(platform="linux") + message_formater = FormatMessage() + default_reply = CUA_DEFAULT_REPLY.format(instruction=instruction) + + # 0 / reset & first screenshot + os.makedirs(save_path, exist_ok=True) + obs_bytes = env.controller.get_screenshot() + with open(os.path.join(save_path, "initial_screenshot.png"), "wb") as f: + f.write(obs_bytes) + traj = [] + history_inputs = [ + message_formater.create_system_message(CUA_SYSTEM_PROMPT_GPT5.format(CLIENT_PASSWORD=client_password)), + 
message_formater.create_user_message(text=CUA_START_MESSAGE.format(instruction=instruction),image=obs_bytes,image_first=False), + ] + + curr_obs = {"screenshot": obs_bytes} + + summary_info = [] + step_no = 0 + logger.info(f"--------------------------------CUA Step {step_no+1}--------------------------------") + response = call_openai_cua(client, history_inputs, cua_model, logger=logger, tts_step=tts_step, summary_info=[summary_info,curr_obs,instruction], client_password=client_password) + reasoning = "" + # 1 / iterative dialogue + while step_no < max_steps: + step_no += 1 + + # --- extract function calls and handle assistant content ------------- + calls: List[Dict[str, Any]] = [] + content_text = "" + buffer_history = [] + + # Collect function calls from chat completions tool_calls + for i, o in enumerate(response.output): + typ = o["type"] if isinstance(o, dict) else getattr(o, "type", None) + if typ == 'function_call': + buffer_history.append(o) + calls.append({ + 'call_id': o.call_id, + 'name': o.name, + 'arguments': json.loads(o.arguments), + }) + elif typ == 'message': + content_text = o.content[0].text + if os.getenv("VERBOSEDEBUG", None) is not None: + print(content_text) + buffer_history.append( + {"role": o.role, "content": o.content} + ) + assert len(calls) <= 1, f"Unexpected assistant content: {content_text} \n {calls}" + + history_inputs.extend(buffer_history) + for action_call in calls: + traj.append(action_call) + logger.info(f"[Action Call]: {action_call}") + py_cmd, tool_output = _tool_call_to_pyautogui(agent, action_call, curr_obs, request_vllm, logger=logger) + summary_info.append([curr_obs, action_call, content_text]) + # --- execute in VM --------------------------------------------------- + obs, *_ = env.step(py_cmd, sleep_after_execution) + + # --- send screenshot back ------------------------------------------- + with open(os.path.join(save_path, f"step_{step_no}.png"), "wb") as f: + f.write(obs["screenshot"]) + + history_inputs.append( + { + 'type': 'function_call_output', + 'call_id': action_call['call_id'], + 'output':tool_output.format(max_steps=max_steps, step_no=step_no) + } + ) + # Provide the screenshot as a separate user message so the model can actually see it + history_inputs.append( + message_formater.create_user_message( + text=f"Here is the screenshot after the {step_no}-th action (tool call) is executed.", + image=obs['screenshot'] + ) + ) + # Prune history to keep first image and at most N latest images + if purge_history_images > 0: + _prune_history_images(history_inputs, purge_history_images) + curr_obs = obs + # Handle plain assistant content string + content_text = response.output_text or '' + if isinstance(content_text, str) and content_text: + if 'TERMINATE' in content_text: + traj.append({"type": "TERMINATE"}) + logger.info(f"#Terminate message:\n{content_text}.") + step_no-=1 + env.step("DONE", sleep_after_execution) + return "DONE", traj + elif 'INFEASIBLE' in content_text: + traj.append({"type": "INFEASIBLE"}) + logger.info(f"Stop reason (unfinished):\n{content_text}.") + step_no-=1 + env.step("FAIL", sleep_after_execution) + return "FAIL", traj + else: + if len(calls) < 1: + step_no-=1 + remaining_steps = max_steps - step_no + if len(calls) < 1 or remaining_steps <= 1: + remind_terminate_message = "" + if remaining_steps <= 1: + remind_terminate_message = "\n\n\nThe maximum number of steps has been reached. Please check the screenshot. 
Return 'TERMINATE' if the task is completed, or reply with 'INFEASIBLE' if the task is impossible to complete due to environmental constraints." + history_inputs.append(message_formater.create_user_message(text=default_reply + remind_terminate_message)) + + assert len(calls) <= 1, f"Unexpected assistant content: {content_text} \n {calls}" + + logger.info(f"--------------------------------CUA Step {step_no+1}--------------------------------") + response = call_openai_cua(client, history_inputs, cua_model, logger=logger, tts_step=tts_step, summary_info=[summary_info,curr_obs,instruction], client_password=client_password) + traj.append({"type": "INFEASIBLE"}) + env.step("FAIL", sleep_after_execution) + return reasoning, traj \ No newline at end of file diff --git a/mm_agents/gta1_agent.py b/mm_agents/gta1/gta1_agent.py similarity index 99% rename from mm_agents/gta1_agent.py rename to mm_agents/gta1/gta1_agent.py index ca39c32..96e3b02 100644 --- a/mm_agents/gta1_agent.py +++ b/mm_agents/gta1/gta1_agent.py @@ -62,7 +62,7 @@ class LMMEngineOpenAI: self.model = model api_key = api_key or os.getenv("OPENAI_API_KEY") - if api_key is None: + if api_key is None and os.getenv("X_API_KEY") is None: raise ValueError( "An API Key needs to be provided in either the api_key parameter or as an environment variable named OPENAI_API_KEY" ) @@ -72,10 +72,10 @@ class LMMEngineOpenAI: self.api_key = api_key self.request_interval = 0 if rate_limit == -1 else 60.0 / rate_limit - if not self.base_url: + if api_key: self.llm_client = OpenAI(api_key=self.api_key) else: - self.llm_client = OpenAI(base_url=self.base_url, api_key=self.api_key) + self.llm_client = client = OpenAI(base_url=os.getenv("X_API_URL"), api_key="dummy", default_headers = {"X-Api-Key": os.getenv("X_API_KEY")}) @backoff.on_exception( backoff.expo, (APIConnectionError, APIError, RateLimitError), max_time=60 @@ -425,6 +425,7 @@ class OSWorldACI: platform: 'linux', width: int = 1920, height: int = 1080, + model: str = "o3", ): self.platform = ( platform # Dictates how the switch_applications agent action works. 
@@ -432,7 +433,7 @@ class OSWorldACI: engine_params_for_generation = engine_params = { "engine_type": 'openai', - "model": 'o3', + "model": model, "base_url": '', "api_key": os.environ.get("OPENAI_API_KEY", ""), } diff --git a/run_multienv_gta1.py b/run_multienv_gta1.py index 529331f..080ff07 100644 --- a/run_multienv_gta1.py +++ b/run_multienv_gta1.py @@ -14,7 +14,8 @@ from multiprocessing import Process, Manager from multiprocessing import current_process import lib_run_single from desktop_env.desktop_env import DesktopEnv -from mm_agents.gta1_agent import GTA1Agent +from mm_agents.gta1.gta1_agent import GTA1Agent +from mm_agents.gta1.gta15_agent import run_cua_gpt5gta1 # Global variables for signal handling active_environments = [] @@ -58,6 +59,8 @@ def config() -> argparse.Namespace: # lm config parser.add_argument("--model", type=str, default="o3") + parser.add_argument("--tts_step", type=int, default=8) + parser.add_argument("--purge_history_images", type=int, default=8) # example config parser.add_argument("--domain", type=str, default="all") @@ -156,7 +159,7 @@ def process_signal_handler(signum, frame, env_idx): sys.exit(0) -def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list): +def run_env_tasks_o3(task_queue: Queue, args: argparse.Namespace, shared_scores: list): active_environments = [] env = None try: @@ -200,9 +203,6 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") example_result_dir = os.path.join( args.result_dir, - args.action_space, - args.observation_type, - args.model, domain, example_id, ) @@ -228,7 +228,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li ) except Exception as rec_e: logger.error(f"Failed to end recording: {rec_e}") - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + with open(os.path.join(example_result_dir, "traj.jsonl"), "w") as f: f.write( json.dumps( {"Error": f"{domain}/{example_id} - {e}"} @@ -253,6 +253,106 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li logger.error(f"{current_process().name} error during environment cleanup: {e}") +def run_env_tasks_gpt5(task_queue: Queue, args: argparse.Namespace, shared_scores: list): + active_environments = [] + env = None + try: + if args.provider_name == "aws": + from desktop_env.providers.aws.manager import IMAGE_ID_MAP + REGION = args.region + screen_size = (args.screen_width, args.screen_height) + ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) + else: + REGION = None + ami_id = None + screen_size = (args.screen_width, args.screen_height) + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + region=REGION, + snapshot_name=ami_id, + screen_size=screen_size, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + enable_proxy=True, + client_password=args.client_password + ) + active_environments.append(env) + logger.info(f"Process {current_process().name} started.") + while True: + try: + item = task_queue.get(timeout=5) + except Exception: + break + domain, example_id = item + try: + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + 
logger.info(f"[{current_process().name}][Domain]: {domain}") + logger.info(f"[{current_process().name}][Example ID]: {example_id}") + logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") + example_result_dir = os.path.join( + args.result_dir, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + try: + env.reset(task_config=example) + time.sleep(15) + obs = env._get_obs() + + _, traj = run_cua_gpt5gta1( + env=env, + instruction=example["instruction"], + max_steps=args.max_steps, + save_path=example_result_dir, + sleep_after_execution=args.sleep_after_execution, + screen_width=args.screen_width, + screen_height=args.screen_height, + client_password=args.client_password, + tts_step=args.tts_step, + purge_history_images=args.purge_history_images, + cua_model=args.model, + logger=logger, + ) + time.sleep(15) + result = env.evaluate() + shared_scores.append(result) + + with open(os.path.join(example_result_dir, "traj.jsonl"), "w") as f: + json.dump(traj, f) + + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") + + except Exception as e: + import traceback + logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Task-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Process-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + logger.info(f"{current_process().name} cleaning up environment...") + try: + if env: + env.close() + logger.info(f"{current_process().name} environment closed successfully") + except Exception as e: + logger.error(f"{current_process().name} error during environment cleanup: {e}") + + def signal_handler(signum, frame): """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" global is_terminating, active_environments, processes @@ -313,7 +413,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: processes = [] for i in range(num_envs): p = Process( - target=run_env_tasks, + target=run_env_tasks_o3 if args.model == "o3" else run_env_tasks_gpt5, args=(task_queue, args, shared_scores), name=f"EnvProcess-{i+1}" ) @@ -328,7 +428,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: if not p.is_alive(): logger.warning(f"Process {p.name} died, restarting...") new_p = Process( - target=run_env_tasks, + target=run_env_tasks_o3 if args.model == "o3" else run_env_tasks_gpt5, args=(task_queue, args, shared_scores), name=f"EnvProcess-Restart-{idx+1}" ) @@ -367,7 +467,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: def get_unfinished( action_space, use_model, observation_type, result_dir, total_file_json ): - target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + target_dir = result_dir if not os.path.exists(target_dir): return total_file_json @@ -402,7 +502,7 @@ def get_unfinished( def get_result(action_space, use_model, observation_type, result_dir, total_file_json): - target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + target_dir = result_dir if not os.path.exists(target_dir): print("New experiment, no result yet.") return None @@ -446,18 +546,6 @@ if __name__ == "__main__": try: args = config() - - # save args to json in 
result_dir/action_space/observation_type/model/args.json - path_to_args = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - args.model, - "args.json", - ) - os.makedirs(os.path.dirname(path_to_args), exist_ok=True) - with open(path_to_args, "w", encoding="utf-8") as f: - json.dump(vars(args), f, indent=4) with open(args.test_all_meta_path, "r", encoding="utf-8") as f: test_all_meta = json.load(f)
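
For reference, each completion-style entry in `tools` flattens into the Responses-API shape exactly as the docstring of `to_response_api_tools` describes; a minimal check of that conversion, assuming the import path introduced by this patch:

# Illustrative only: verifies the flattening done by to_response_api_tools.
from mm_agents.gta1.cua_tool import tools, to_response_api_tools, response_api_tools

flat = to_response_api_tools(tools)
assert flat[0]["type"] == "function"
assert flat[0]["name"] == "click"               # was tools[0]["function"]["name"]
assert "parameters" in flat[0] and "function" not in flat[0]
assert flat == response_api_tools               # the module already exposes the converted list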
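FormatMessage emits Responses-API content blocks (`input_text` / `input_image` with a base64 data URL); a sketch of the opening turn as run_cua_gpt5gta1 builds it, where `screenshot_bytes` stands in for env.controller.get_screenshot():

# Illustrative sketch; screenshot_bytes is a placeholder for raw PNG bytes.
from mm_agents.gta1.format_message import FormatMessage

fmt = FormatMessage()
history_inputs = [
    fmt.create_system_message("system prompt text"),
    fmt.create_user_message(text="User task: ...", image=screenshot_bytes),
]
# The user message carries two blocks:
#   {"type": "input_text", "text": "User task: ..."}
#   {"type": "input_image", "image_url": "data:image/png;base64,...", "detail": "high"}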
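_prune_history_images then keeps only the earliest screenshot plus the `max_recent_images` most recent ones, which is how purge_history_images bounds the rolling context; continuing the sketch above:

# Illustrative only: shows which image-bearing messages survive pruning.
from mm_agents.gta1.gta15_agent import _prune_history_images

msgs = [fmt.create_user_message(text=f"step {i}", image=screenshot_bytes) for i in range(6)]
_prune_history_images(msgs, max_recent_images=2)
assert len(msgs) == 3   # first screenshot and the last two remain; the middle three are dropped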
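At the driver level, run_env_tasks_gpt5 is what calls run_cua_gpt5gta1 per example; a condensed single-task sketch, assuming OPENAI_API_KEY (or X_API_KEY/X_API_URL) and GTA1_SERVICE_URL are set, and treating the DesktopEnv provider settings and config path below as placeholders for a real setup:

# Illustrative driver mirroring run_env_tasks_gpt5 for one example.
import json, logging
from desktop_env.desktop_env import DesktopEnv
from mm_agents.gta1.gta15_agent import run_cua_gpt5gta1

logger = logging.getLogger("gta15")
env = DesktopEnv(provider_name="aws", action_space="pyautogui", headless=True,
                 os_type="Ubuntu", client_password="password")      # placeholder provider settings
with open("evaluation_examples/examples/chrome/example.json") as f:  # hypothetical task config
    example = json.load(f)
env.reset(task_config=example)

status, traj = run_cua_gpt5gta1(
    env=env,
    instruction=example["instruction"],
    max_steps=50,
    save_path="./results/chrome/example",
    client_password="password",
    cua_model="gpt-5",
    tts_step=8,                  # candidate plans sampled and judged per step
    purge_history_images=8,      # recent screenshots kept in the rolling context
    logger=logger,
)
result = env.evaluate()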