update aworldguiAgent code (#342)
70
mm_agents/aworldguiagent/README.md
Normal file
@@ -0,0 +1,70 @@
# aworldGUIAgent-v1

aworldGUIAgent-v1 is built on the [AWorld Framework](https://github.com/inclusionAI/AWorld) and is specifically designed to tackle complex desktop automation tasks within the [OSWorld-verified](https://os-world.github.io/) benchmark.

The core logic for our agent's perception and reasoning is adapted from the great work of the [Agent-S project](https://github.com/simular-ai/Agent-S). We have built upon their foundation by introducing a suite of new executable tools that enhance the agent's ability to interact with the OS environment.

## Quick Start

Follow these steps to set up the environment and reproduce our results.

1. **Create Environment & Set Up OSWorld**:
    * First, create a dedicated Conda environment with **Python 3.11**.
        ```bash
        conda create -n osworld_env python=3.11
        conda activate osworld_env
        ```
    * Next, follow the official setup guide in the [OSWorld README](https://github.com/xlang-ai/OSWorld) to install OSWorld and its dependencies.

2. **Install AWorld Framework**:
    * Install the specific version of the AWorld Framework into the **same environment**.
        ```bash
        # Make sure your osworld_env is still activated
        git clone https://github.com/inclusionAI/AWorld.git
        cd AWorld
        git checkout osworld_benchmark
        python setup.py install
        ```

3. **Run the Evaluation Script**:
    * Our results were achieved using `openai/o3` for reasoning and `bytedance/ui-tars-1.5-7b` for visual grounding, both accessed via OpenRouter.
    * Remember to replace placeholders like `YOUR_BASE_URL`, `YOUR_API_KEY`, and `/path/to/your/vm/Ubuntu.vmx` with your actual credentials and paths.

    ```bash
    # Activate your OSWorld conda environment (e.g., osworld_env)
    conda activate osworld_env

    # Run the evaluation with the recommended settings
    python run_multienv_aworldguiagent.py \
        --headless \
        --ground_url YOUR_BASE_URL \
        --ground_api_key YOUR_API_KEY \
        --ground_model bytedance/ui-tars-1.5-7b \
        --ground_provider open_router \
        --model_url YOUR_BASE_URL \
        --model_api_key YOUR_API_KEY \
        --model_temperature 1.0 \
        --provider_name vmware \
        --path_to_vm /path/to/your/vm/Ubuntu.vmx \
        --max_steps 50 \
        --model_provider open_router \
        --model openai/o3 \
        --grounding_width 1920 \
        --grounding_height 1080 \
        --test_all_meta_path evaluation_examples/test_all.json \
        --result_dir ./results \
        --observation_type screenshot \
        --num_envs 1 \
        --region us-east-1 \
        --client_password osworld-public-evaluation
    ```
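
If you prefer not to paste credentials into the shell, the same invocation can be assembled from Python with values read from environment variables. The sketch below is only an illustration: the helper `build_eval_command` and the variable names `OSWORLD_BASE_URL`, `OSWORLD_API_KEY`, and `OSWORLD_VM_PATH` are hypothetical, and it assumes the grounding and reasoning models share one endpoint and key, as in the command above. The flags themselves are copied from that command.

```python
import os
import shlex


def build_eval_command(base_url: str, api_key: str, vm_path: str) -> list[str]:
    """Return the evaluation command as an argument list (flags as in the README)."""
    return [
        "python", "run_multienv_aworldguiagent.py",
        "--headless",
        "--ground_url", base_url,
        "--ground_api_key", api_key,
        "--ground_model", "bytedance/ui-tars-1.5-7b",
        "--ground_provider", "open_router",
        "--model_url", base_url,
        "--model_api_key", api_key,
        "--model_temperature", "1.0",
        "--provider_name", "vmware",
        "--path_to_vm", vm_path,
        "--max_steps", "50",
        "--model_provider", "open_router",
        "--model", "openai/o3",
        "--grounding_width", "1920",
        "--grounding_height", "1080",
        "--test_all_meta_path", "evaluation_examples/test_all.json",
        "--result_dir", "./results",
        "--observation_type", "screenshot",
        "--num_envs", "1",
        "--region", "us-east-1",
        "--client_password", "osworld-public-evaluation",
    ]


if __name__ == "__main__":
    # Hypothetical variable names; fall back to the README placeholders.
    cmd = build_eval_command(
        base_url=os.environ.get("OSWORLD_BASE_URL", "YOUR_BASE_URL"),
        api_key=os.environ.get("OSWORLD_API_KEY", "YOUR_API_KEY"),
        vm_path=os.environ.get("OSWORLD_VM_PATH", "/path/to/your/vm/Ubuntu.vmx"),
    )
    # Print a shell-quoted command line; pass `cmd` to subprocess.run to execute it.
    print(shlex.join(cmd))
```

Printing the command first makes it easy to check the substituted values before launching a long evaluation run.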

## Acknowledgements

This work would not have been possible without building upon the foundations of several incredible open-source projects.

- **AWorld Framework**: We thank the developers of the [AWorld Framework](https://github.com/inclusionAI/AWorld) for providing a powerful and flexible platform for agent development. The AWorld Framework is designed for agent training and is especially suited for complex multi-agent scenarios. If you have requirements for designing or experimenting with multi-agent systems, we highly recommend you explore the AWorld Framework further.

- **Agent-S**: We extend our sincere gratitude to the creators of the [Agent-S project](https://github.com/simular-ai/Agent-S). The core agent logic in our implementation is adapted and enhanced from their codebase. We built upon their work by adding a suite of executable tools to improve the agent's interaction with the OS environment, which effectively boosted the stability and capability of our CUA Agent.

- **OSWorld Benchmark**: We are grateful to the creators of the [OSWorld Benchmark](https://os-world.github.io/) for developing a challenging and comprehensive testbed for GUI agents.
99
mm_agents/aworldguiagent/agent.py
Normal file
@@ -0,0 +1,99 @@
"""
This code is adapted from AgentS2 (https://github.com/simular-ai/Agent-S)
with modifications to suit specific requirements.
"""
import logging
import platform
from typing import Dict, List, Tuple

from mm_agents.aworldguiagent.grounding import ACI
from mm_agents.aworldguiagent.workflow import Worker

logger = logging.getLogger("desktopenv.agent")


class UIAgent:
    """Base class for UI automation agents"""

    def __init__(
        self,
        engine_params: Dict,
        grounding_agent: ACI,
        platform: str = platform.system().lower(),
    ):
        """Initialize UIAgent

        Args:
            engine_params: Configuration parameters for the LLM engine
            grounding_agent: Instance of ACI class for UI interaction
            platform: Operating system platform (darwin, linux, windows)
        """
        self.engine_params = engine_params
        self.grounding_agent = grounding_agent
        self.platform = platform

    def reset(self) -> None:
        """Reset agent state"""
        pass

    def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
        """Generate next action prediction

        Args:
            instruction: Natural language instruction
            observation: Current UI state observation

        Returns:
            Tuple containing agent info dictionary and list of actions
        """
        pass


class AworldGUIAgent(UIAgent):
    """Agent that uses no hierarchy for less inference time"""

    def __init__(
        self,
        engine_params: Dict,
        grounding_agent: ACI,
        platform: str = platform.system().lower(),
        max_trajectory_length: int = 8,
        enable_reflection: bool = True,
    ):
        """Initialize a minimalist AgentS2 without hierarchy

        Args:
            engine_params: Configuration parameters for the LLM engine
            grounding_agent: Instance of ACI class for UI interaction
            platform: Operating system platform (darwin, linux, windows)
            max_trajectory_length: Maximum number of image turns to keep
            enable_reflection: Creates a reflection agent to assist the worker agent
        """

        super().__init__(engine_params, grounding_agent, platform)
        self.max_trajectory_length = max_trajectory_length
        self.enable_reflection = enable_reflection
        self.reset()

    def reset(self) -> None:
        """Reset agent state and initialize components"""
        self.executor = Worker(
            engine_params=self.engine_params,
            grounding_agent=self.grounding_agent,
            platform=self.platform,
            max_trajectory_length=self.max_trajectory_length,
            enable_reflection=self.enable_reflection,
        )

    def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
        # Ask the worker for its info dictionary and the next actions
        executor_info, actions = self.executor.generate_next_action(
            instruction=instruction, obs=observation
        )

        # Copy the worker's info dictionary (empty if the worker returned None)
        info = dict(executor_info or {})

        return info, actions
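
The `reset`/`predict` contract of `AworldGUIAgent` can be exercised without a VM or an LLM by swapping in a stub executor. The following is a minimal sketch under stated assumptions, not the project's evaluation loop: `StubWorker`, `SketchAgent`, `run_episode`, the placeholder observation, and the use of the string `"DONE"` to end an episode are all hypothetical stand-ins. Only the call shape `generate_next_action(instruction=..., obs=...)` and the `(info, actions)` return value are taken from the code above.

```python
from typing import Dict, List, Optional, Tuple


class StubWorker:
    """Hypothetical stand-in for Worker: emits canned actions, then "DONE"."""

    def __init__(self) -> None:
        self.turn = 0

    def generate_next_action(
        self, instruction: str, obs: Dict
    ) -> Tuple[Optional[Dict], List[str]]:
        self.turn += 1
        if self.turn >= 3:
            # Return None for info to exercise the empty-info path.
            return None, ["DONE"]
        return {"plan": f"step {self.turn}: {instruction}"}, [f"agent.wait({self.turn})"]


class SketchAgent:
    """Mirrors the reset/predict shape of AworldGUIAgent with a stub executor."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # A fresh executor per episode, as AworldGUIAgent.reset does with Worker.
        self.executor = StubWorker()

    def predict(self, instruction: str, observation: Dict) -> Tuple[Dict, List[str]]:
        executor_info, actions = self.executor.generate_next_action(
            instruction=instruction, obs=observation
        )
        # A None info value becomes an empty dict.
        return dict(executor_info or {}), actions


def run_episode(agent: SketchAgent, instruction: str, max_steps: int = 50) -> List[str]:
    """Call predict until "DONE" appears or max_steps is reached."""
    agent.reset()
    trace: List[str] = []
    for _ in range(max_steps):
        observation = {"screenshot": b""}  # placeholder; real runs pass a screenshot
        _info, actions = agent.predict(instruction, observation)
        trace.extend(actions)
        if "DONE" in actions:
            break
    return trace


if __name__ == "__main__":
    print(run_episode(SketchAgent(), "open the terminal"))
```

Because `reset` rebuilds the executor, calling it at the start of every episode avoids carrying trajectory state from one task into the next.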
5252
mm_agents/aworldguiagent/grounding.py
Normal file
File diff suppressed because it is too large
947
mm_agents/aworldguiagent/prompt.py
Normal file
@@ -0,0 +1,947 @@
"""
This code is adapted from AgentS2 (https://github.com/simular-ai/Agent-S)
with modifications to suit specific requirements.
"""
GENERATOR_SYS_PROMPT = """You are an expert in graphical user interfaces and Python code. You are responsible for executing the task: `TASK_DESCRIPTION`.
You are working in Ubuntu.
You are provided with:
1. A screenshot of the current time step.
2. The history of your previous interactions with the UI.
3. Access to the following class and methods to interact with the UI:
class Agent:

    def click(self, element_description: str, num_clicks: int = 1, button_type: str = 'left', hold_keys: List = []):
        '''Click on the element
        Args:
            element_description:str, a detailed description of which element to click on. This description should be at least a full sentence.
            num_clicks:int, number of times to click the element
            button_type:str, which mouse button to press; can be "left", "middle", or "right"
            hold_keys:List, list of keys to hold while clicking
        '''

    def done(self, return_value: Union[Dict, str, List, Tuple, int, float, bool, NoneType] = None):
        '''End the current task with a success and the required return value'''

    def drag_and_drop(self, starting_description: str, ending_description: str, hold_keys: List = []):
        '''Drag from the starting description to the ending description
        Args:
            starting_description:str, a very detailed description of where to start the drag action. This description should be at least a full sentence.
            ending_description:str, a very detailed description of where to end the drag action. This description should be at least a full sentence.
            hold_keys:List, list of keys to hold while dragging
        '''

    def fail(self):
        '''End the current task with a failure, and replan the whole task.'''

    def hold_and_press(self, hold_keys: List, press_keys: List):
        '''Hold a list of keys and press a list of keys
        Args:
            hold_keys:List, list of keys to hold
            press_keys:List, list of keys to press in a sequence
        '''

    def hotkey(self, keys: List):
        '''Press a hotkey combination
        Args:
            keys:List, the keys to press in combination in a list format (e.g. ['ctrl', 'c'])
        '''

    def open(self, app_or_filename: str):
        '''Open any application or file with name app_or_filename. Use this action to open applications or files on the desktop, do not open manually.
        Args:
            app_or_filename:str, the name of the application or filename to open
        '''

    def save_to_knowledge(self, text: List[str]):
        '''Save facts, elements, texts, etc. to a long-term knowledge bank for reuse during this task. Can be used for copy-pasting text, saving elements, etc.
        Args:
            text:List[str], the text to save to the knowledge
        '''

    def scroll(self, element_description: str, clicks: int, shift: bool = False):
        '''Scroll the element in the specified direction
        Args:
            element_description:str, a very detailed description of which element to scroll in. This description should be at least a full sentence.
            clicks:int, the number of clicks to scroll; can be positive (up) or negative (down).
            shift:bool, whether to use shift+scroll for horizontal scrolling
        '''

    def set_cell_values(self, cell_values: Dict[str, Any], app_name: str, sheet_name: str):
        '''Use this to set individual cell values in a spreadsheet. For example, setting A2 to "hello" would be done by passing {"A2": "hello"} as cell_values. The sheet must be opened before this command can be used.
        Args:
            cell_values: Dict[str, Any], A dictionary of cell values to set in the spreadsheet. The keys are the cell coordinates in the format "A1", "B2", etc.
                Supported value types include: float, int, string, bool, formulas.
            app_name: str, The name of the spreadsheet application. For example, "Some_sheet.xlsx".
            sheet_name: str, The name of the sheet in the spreadsheet. For example, "Sheet1".
        '''

    def switch_applications(self, app_code):
        '''Switch to a different application that is already open
        Args:
            app_code:str, the code name of the application to switch to from the provided list of open applications
        '''

    def type(self, element_description: str, text: str = '', overwrite: bool = False, enter: bool = False):
        '''Type text into a specific element
        Args:
            element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
            text:str, the text to type
            overwrite:bool, Assign it to True if the text should overwrite the existing text, otherwise assign it to False. Using this argument clears all text in an element.
            enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
        '''

    def wait(self, time: float):
        '''Wait for a specified amount of time
        Args:
            time:float, the amount of time to wait in seconds
        '''

    def code_launch_vscode(self, path):
        '''Launches Visual Studio Code with the specified file path or directory.
        Opens the file or directory in an existing window.

        Args:
            path (str): File path or directory.'''

    def code_compare_files(self, file1, file2):
        '''Compares two files in VSCode.

        Args:
            file1 (str): Path to the first file.
            file2 (str): Path to the second file.'''

    def code_add_folder(self, folder):
        '''Adds a folder to the last active window in VSCode.

        Args:
            folder (str): Folder path.'''

    def code_goto_file(self, file_path, line=1, character=1):
        '''Opens a file at a specific line and character position.

        Args:
            file_path (str): File path.
            line (int): Line number.
            character (int): Character position.'''

    def code_perform_merge(self, path1, path2, base, result):
        '''Perform a three-way merge.

        Args:
            path1 (str): Path to the first version of the file.
            path2 (str): Path to the second version of the file.
            base (str): Path to the base version of the file.
            result (str): Path where the merged result is saved.'''

    def code_remove_folder(self, folder):
        '''Removes a folder from the last active window in VSCode.

        Args:
            folder (str): Folder path.'''

    def code_install_extension(self, extension_id, pre_release=False):
        '''Installs an extension or updates it in VSCode.

        Args:
            extension_id (str): Identifier of the extension.
            pre_release (bool): Whether to install the pre-release version.'''

    def code_uninstall_extension(self, extension_id):
        '''Uninstalls an extension from VSCode.

        Args:
            extension_id (str): Identifier of the extension.'''

    def code_list_extensions(self, show_versions=False, category=None):
        '''Lists installed extensions in VSCode.

        Args:
            show_versions (bool): Whether to show extension versions.
            category (str): Filter extensions by category.'''

    def code_update_extensions(self):
        '''Updates all installed extensions in VSCode to the latest version.'''

    def code_disable_extension(self, extension_id):
        '''Disables a specific extension for the next instance of VSCode.

        Args:
            extension_id (str): Identifier of the extension.'''

    def code_toggle_sync(self, state):
        '''Toggles synchronization on or off in VSCode.

        Args:
            state (str): 'on' or 'off', to turn synchronization on or off.'''


    def libreoffice_calc_save(self):
        '''Save the current workbook to its current location

        Returns:
            bool: True if save successful, False otherwise'''

    def libreoffice_calc_get_workbook_info(self):
        '''Get workbook information

        Args:
            None

        Returns:
            dict: Workbook information, including file path, file name, sheets and active sheet'''

    def libreoffice_calc_get_column_data(self, column_name):
        '''Get data from the specified column

        Args:
            column_name (str): Name of the column to read

        Returns:
            list: List of values in the specified column'''

    def libreoffice_calc_set_column_as_text(self, column_name):
        '''
        Set the specified column format as text type.
        This will convert all numeric values in the column to text format and apply text formatting.

        Args:
            column_name (str): The column name to format as text (e.g., 'A', 'B', 'C')

        Returns:
            str: Success message or error description

        Example:
            "Successfully set column A as text format"
        '''

    def libreoffice_calc_get_active_sheet_data(self):
        '''
        Get all data from the currently active sheet with detailed coordinate information.
        Returns data with cell addresses, values, row/column info, and empty cell indicators.

        Returns:
            dict: Complete sheet data with detailed cell information

        Example:
            {
                "data": [
                    [
                        {"address": "A1", "value": "", "row": 1, "col": 1, "col_name": "A", "is_empty": true},
                        {"address": "B1", "value": "Age", "row": 1, "col": 2, "col_name": "B", "is_empty": false}
                    ],
                    [
                        {"address": "A2", "value": "Ryan", "row": 2, "col": 1, "col_name": "A", "is_empty": false},
                        {"address": "B2", "value": 5.0, "row": 2, "col": 2, "col_name": "B", "is_empty": false}
                    ],
                    [
                        {"address": "A3", "value": "Jack", "row": 3, "col": 1, "col_name": "A", "is_empty": false},
                        {"address": "B3", "value": 6.0, "row": 3, "col": 2, "col_name": "B", "is_empty": false}
                    ]
                ],
                "rows": 3,
                "columns": 2,
                "range": "A1:B3"
            }
        '''

    def libreoffice_calc_switch_active_sheet(self, sheet_name):
        '''Switch to the specified sheet and make it active, creating it if it does not exist

        Args:
            sheet_name (str): Name of the sheet to switch to or create

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_set_column_values(self, column_name, data, start_index=2):
        '''Set data to the specified column

        Args:
            column_name (str): Name of the column to write
            data (list): List of values to write to the column
            start_index (int): The index of the first row to write to, default is 2 (skip the first row)

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_highlight_range(self, range_str, color=0xFF0000):
        '''Highlight the specified range with the specified color

        Args:
            range_str (str): Range to highlight, in the format of "A1:B10"
            color (int): Color to highlight with, as a hex RGB value; default is 0xFF0000 (red)

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_transpose_range(self, source_range, target_cell):
        '''Transpose the specified range and paste it to the target cell

        Args:
            source_range (str): Range to transpose, in the format of "A1:B10"
            target_cell (str): Target cell to paste the transposed data, in the format of "A1"

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_export_to_csv(self):
        '''Export the current document to a CSV file

        Args:
            None

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_sort_column(self, column_name, ascending=True, start_index=2):
        '''Sorts the data in the specified column in ascending or descending order

        Args:
            column_name (str): The name of the column to sort (e.g. 'A') or the title
            ascending (bool): Whether to sort in ascending order (default True)
            start_index (int): The index of the first row to sort, default is 2

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_set_validation_list(self, column_name, values):
        '''Set a validation list for the specified column

        Args:
            column_name (str): The name of the column to set the validation list for
            values (list): The list of values to use for the validation list

        Returns:
            None'''

    def libreoffice_calc_hide_row_data(self, value="N/A"):
        '''Hide rows that contain the specified value

        Args:
            value (str): The value to hide rows for, default is 'N/A'

        Returns:
            None'''

    def libreoffice_calc_reorder_columns(self, column_order):
        '''Reorder the columns in the sheet according to the specified order

        Args:
            column_order (list): A list of column names in the desired order

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_create_pivot_table(self,
        source_sheet,
        table_name,
        row_fields=None,
        col_fields=None,
        value_fields=None,
        aggregation_function="sum",
        target_cell="A1",
    ):
        '''Create a pivot table in the active worksheet based on data from the active sheet.'''

    def libreoffice_calc_merge_cells(self, sheet_name, range_str):
        '''Merges a specified range of cells within a specific worksheet.

        This function connects to a running LibreOffice Calc instance,
        selects a worksheet by its name, and merges the cells defined
        by the given range string.

        Args:
            sheet_name (str): The name of the worksheet where the cells will be
                merged, e.g., 'Sheet1' or 'Q4_Report'.
            range_str (str): The cell range to merge, specified in A1 notation,
                e.g., 'A1:B10'.

        Returns:
            bool: True if the cells were successfully merged, False if an
                error occurred.
        '''

    def libreoffice_calc_set_cell_value(self, cell, value):
        '''Set a value to a specific cell in the active worksheet.

        Args:
            cell (str): Cell reference (e.g., 'A1')
            value (str): Value to set in the cell

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_format_range(self, range_str, background_color=None, font_color=None, bold=None, alignment=None):
        '''Apply formatting to the specified range in the active worksheet

        Args:
            range_str (str): Range to format, in the format of 'A1:B10'
            background_color (str, optional): Background color in hex format (e.g., '#0000ff')
            font_color (str, optional): Font color in hex format (e.g., '#ffffff')
            bold (bool, optional): Whether to make the text bold
            italic (bool, optional): Whether to make the text italic
            alignment (str, optional): Text alignment (left, center, right)

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_freeze_panes(self, rows=0, columns=0):
        '''Freeze rows and/or columns in the active worksheet

        Args:
            rows (int): Number of rows to freeze, counted from the top
            columns (int): Number of columns to freeze, counted from the left

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_rename_sheet(self, old_name, new_name):
        '''Rename a worksheet

        Args:
            old_name (str): Current name of the worksheet to rename
            new_name (str): New name for the worksheet

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_copy_sheet(self, source_sheet, new_sheet_name=None):
        '''Create a copy of an existing worksheet in the workbook

        Args:
            source_sheet (str): Name of the worksheet to copy
            new_sheet_name (str, optional): Name for the copy; generated automatically if not provided

        Returns:
            str: Name of the newly created worksheet, or None on failure'''

    def libreoffice_calc_reorder_sheets(self, sheet_name, position):
        '''Change the position of a worksheet within the workbook

        Args:
            sheet_name (str): Name of the worksheet to move
            position (int): Position to move it to (0-based index)

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_set_chart_legend_position(self, position):
        '''Set the position of the legend in a chart in the active worksheet.

        Args:
            position (str): Position of the legend ('top', 'bottom', 'left', 'right', 'none')

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_set_number_format(self, range_str, format_type, decimal_places=None):
        '''Apply a specific number format to a range of cells in the active worksheet.

        Args:
            range_str (str): Range to format, in the format of 'A1:B10'
            format_type (str): Type of number format to apply
            decimal_places (int, optional): Number of decimal places to display

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_adjust_column_width(self, columns, width=None, autofit=False):
        '''Adjust the width of the specified columns in the active worksheet

        Args:
            columns (str): Column range to adjust, e.g. 'A:C' for columns A through C
            width (float, optional): Width to set (in characters)
            autofit (bool, optional): Whether to automatically fit the column width to the content

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_adjust_row_height(self, rows, height=None, autofit=False):
        '''Adjust the height of the specified rows in the active worksheet

        Args:
            rows (str): Row range to adjust, e.g. '1:10' for rows 1 through 10
            height (float, optional): Height to set (in points)
            autofit (bool, optional): Whether to automatically fit the row height to the content

        Returns:
            bool: True if the operation succeeds, False otherwise'''

    def libreoffice_calc_export_to_pdf(self, file_path=None, sheets=None, open_after_export=False):
        '''Export the current document or the specified worksheets to a PDF file

        Args:
            file_path (str, optional): Path to save the PDF file; the current document's path is used if not specified
            sheets (list, optional): Names of the worksheets to include in the PDF; all worksheets are included if not specified
            open_after_export (bool, optional): Whether to open the PDF file after exporting

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_calc_set_zoom_level(self, zoom_percentage):
        '''Adjust the zoom level of the current worksheet so that cells appear larger or smaller

        Args:
            zoom_percentage (int): Zoom level as a percentage (e.g., 75 for 75%, 100 for normal size, 150 for zoomed in).
                The valid range is typically 10-400.

        Returns:
            bool: True if successful, False otherwise'''


    def libreoffice_impress_save(self):
        '''Save the document to its current location'''

    def libreoffice_impress_go_to_slide(self, slide_index):
        '''Navigates to a specific slide in the presentation based on its index.

        Args:
            slide_index (int): The index of the slide to navigate to (1-based indexing)

        Returns:
            bool: True if navigation was successful, False otherwise'''

    def libreoffice_impress_get_slide_count(self):
        '''Gets the total number of slides in the current presentation.
        :return: The total number of slides as an integer'''

    def libreoffice_impress_duplicate_slide(self, slide_index):
        '''Creates a duplicate of a specific slide and places it at the end of the presentation.

        :param slide_index: The index of the slide to duplicate (1-based indexing)
        :return: True if successful, False otherwise'''

    def libreoffice_impress_set_slide_font(self, slide_index, font_name):
        '''Sets the font style for all text elements in a specific slide, including the title.

        Args:
            slide_index (int): The index of the slide to modify (1-based indexing)
            font_name (str): The name of the font to apply (e.g., 'Arial', 'Times New Roman', 'Calibri')

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_impress_write_text(self, content, page_index, box_index, bold=False, italic=False, size=None, append=False):
        '''Writes text to a specific textbox on a slide

        :param content: The text content to add
        :param page_index: The index of the slide (1-based indexing)
        :param box_index: The index of the textbox to modify (0-based indexing)
        :param bold: Whether to make the text bold, default is False
        :param italic: Whether to make the text italic, default is False
        :param size: The size of the text. If None, uses the box's current font size.
        :param append: Whether to append the text, default is False. Set it to True if you want to preserve existing formatting (such as a bullet at the beginning) or keep the original text.
        :return: True if successful, False otherwise'''

    def libreoffice_impress_set_style(self, slide_index, box_index, bold=None, italic=None, underline=None):
        '''Sets the style properties for the specified textbox on a slide.

        :param slide_index: The index of the slide to modify (1-based indexing)
        :param box_index: The index of the textbox to modify (0-based indexing)
        :param bold: Whether to make the text bold
        :param italic: Whether to make the text italic
        :param underline: Whether to underline the text
        :return: True if successful, False otherwise'''

    def libreoffice_impress_configure_auto_save(self, enabled, interval_minutes):
        '''Enables or disables auto-save functionality for the current document and sets the auto-save interval.

        :param enabled: Whether to enable (True) or disable (False) auto-save
        :param interval_minutes: The interval in minutes between auto-saves (minimum 1 minute)
        :return: True if successful, False otherwise'''

    def libreoffice_impress_set_background_color(self, slide_index, box_index, color):
        '''Sets the background color for the specified textbox on a slide.

        Args:
            slide_index (int): The index of the slide containing the textbox (1-based indexing)
            box_index (int): The index of the textbox to modify (0-based indexing)
            color (str): The color to apply to the textbox (e.g., 'red', 'green', 'blue', 'yellow', or hex color code)

        Returns:
            bool: True if successful, False otherwise'''

    def libreoffice_impress_set_text_color(self, slide_index, box_index, color):
        '''Sets the text color for the specified textbox on a slide.

        Args:
            slide_index (int): The index of the slide to modify (1-based indexing)
            box_index (int): The index of the textbox to modify (0-based indexing)
            color (str): The color to apply to the text (e.g., 'red', 'green', 'blue', 'black', or hex color code)

        Returns:
            bool: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_delete_content(self, slide_index, box_index):
|
||||||
|
'''Deletes the specified textbox from a slide.
|
||||||
|
|
||||||
|
:param slide_index: The index of the slide to modify (1-based indexing)
|
||||||
|
:param box_index: The index of the textbox to modify (0-based indexing)
|
||||||
|
:return: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_set_slide_orientation(self, orientation):
|
||||||
|
'''Changes the orientation of slides in the presentation between portrait (upright) and landscape (sideways).
|
||||||
|
|
||||||
|
:param orientation: The desired orientation for the slides ('portrait' or 'landscape')
|
||||||
|
:return: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_position_box(self, slide_index, box_index, position):
|
||||||
|
'''Positions a textbox or image on a slide at a specific location or predefined position.
|
||||||
|
|
||||||
|
:param slide_index: The index of the slide containing the box (1-based indexing)
|
||||||
|
:param box_index: The index of the box to position (0-based indexing)
|
||||||
|
:param position: Predefined position on the slide (left, right, center, top, bottom, etc.)
|
||||||
|
:return: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_insert_file(self, file_path, slide_index=None, position=None, size=None, autoplay=False):
|
||||||
|
'''Inserts a video file into the current or specified slide in the presentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): The full path to the video file to be inserted
|
||||||
|
slide_index (int, optional): The index of the slide to insert the video into (1-based indexing).
|
||||||
|
If not provided, inserts into the current slide.
|
||||||
|
position (dict, optional): The position coordinates for the video as percentages of slide dimensions
|
||||||
|
{'x': float, 'y': float}
|
||||||
|
size (dict, optional): The size dimensions for the video as percentages of slide dimensions
|
||||||
|
{'width': float, 'height': float}
|
||||||
|
autoplay (bool, optional): Whether the video should automatically play when the slide is shown
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_set_slide_background(self, slide_index=None, color=None, image_path=None):
|
||||||
|
'''Sets the background color or image for a specific slide or all slides.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slide_index (int, optional): The index of the slide to modify (1-based indexing).
|
||||||
|
If not provided, applies to all slides.
|
||||||
|
color (str, optional): The background color to apply (e.g., 'red', 'green', 'blue', or hex color code)
|
||||||
|
image_path (str, optional): Path to an image file to use as background. If provided, overrides color.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_save_as(self, file_path, overwrite=False):
|
||||||
|
'''Saves the current document to a specified location with a given filename.
|
||||||
|
|
||||||
|
:param file_path: The full path where the file should be saved, including the filename and extension
|
||||||
|
:param overwrite: Whether to overwrite the file if it already exists (default: False)
|
||||||
|
:return: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_insert_image(self, slide_index, image_path, width=None, height=None, position=None):
|
||||||
|
'''Inserts an image into a specific slide in the presentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slide_index (int): The index of the slide to add the image to (1-based indexing)
|
||||||
|
image_path (str): The full path to the image file to be added
|
||||||
|
width (float, optional): The width of the image in centimeters
|
||||||
|
height (float, optional): The height of the image in centimeters
|
||||||
|
position (dict, optional): The position coordinates for the image as percentages
|
||||||
|
{
|
||||||
|
'x': float, # The x-coordinate as a percentage of slide width
|
||||||
|
'y': float # The y-coordinate as a percentage of slide height
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_configure_display_settings(self, use_presenter_view=None, primary_monitor_only=None, monitor_for_presentation=None
|
||||||
|
):
|
||||||
|
'''Configures the display settings for LibreOffice Impress presentations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_presenter_view (bool, optional): Whether to use presenter view. Set to false to disable presenter view.
|
||||||
|
primary_monitor_only (bool, optional): Whether to use only the primary monitor for the presentation.
|
||||||
|
monitor_for_presentation (int, optional): Specify which monitor to use (1 for primary, 2 for secondary, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if settings were successfully applied, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_set_text_strikethrough(self, slide_index, box_index, line_numbers, apply):
|
||||||
|
'''Applies or removes strike-through formatting to specific text content in a slide.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slide_index (int): The index of the slide containing the text (1-based indexing)
|
||||||
|
box_index (int): The index of the textbox containing the text (0-based indexing)
|
||||||
|
line_numbers (list): The line numbers to apply strike-through formatting to (1-based indexing)
|
||||||
|
apply (bool): Whether to apply (true) or remove (false) strike-through formatting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_set_textbox_alignment(self, slide_index, box_index, alignment):
|
||||||
|
'''Sets the text alignment for the specified textbox on a slide.
|
||||||
|
|
||||||
|
:param slide_index: The index of the slide to modify (1-based indexing)
|
||||||
|
:param box_index: The index of the textbox to modify (0-based indexing)
|
||||||
|
:param alignment: The text alignment to apply ('left', 'center', 'right', or 'justify')
|
||||||
|
:return: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_set_slide_number_color(self, color):
|
||||||
|
'''Sets the color of the slide number in the presentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
color (str): The color to apply to slide numbers (e.g., 'red', 'green', 'blue', 'black', or hex color code)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_impress_export_to_image(self, file_path, format, slide_index=None):
|
||||||
|
'''Exports the current presentation or a specific slide to an image file format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path (str): The full path where the image file should be saved, including the filename and extension
|
||||||
|
format (str): The image format to export to (e.g., 'png', 'jpeg', 'gif')
|
||||||
|
slide_index (int, optional): The index of the specific slide to export (1-based indexing).
|
||||||
|
If not provided, exports the entire presentation as a series of images.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if export was successful, False otherwise'''
|
||||||
|
|
||||||
|
|
||||||
|
def libreoffice_writer_save(self):
|
||||||
|
'''Saves the document to its current location.'''
|
||||||
|
|
||||||
|
def libreoffice_writer_write_text(self, text, bold=False, italic=False, size=None):
|
||||||
|
'''Writes text to the document.'''
|
||||||
|
|
||||||
|
def libreoffice_writer_set_color(self, pattern, color, paragraph_indices=None):
|
||||||
|
'''Changes the color of matched text in the document for specified paragraphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern (str): Regular expression pattern to match text
|
||||||
|
color (int): Hex color code (e.g., 0x000000 for black)
|
||||||
|
paragraph_indices (list, optional): List of paragraph indices to modify (0-based).
|
||||||
|
If None, applies to all paragraphs.'''
|
||||||
|
|
||||||
|
def libreoffice_writer_find_and_replace(self, pattern, replacement, paragraph_indices=None):
|
||||||
|
'''Finds all occurrences of a specified text pattern and replaces them with another text in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern (str): The pattern to match in the document, should be a regular expression
|
||||||
|
replacement (str): The text to replace the found text with
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message with number of replacements made'''
|
||||||
|
|
||||||
|
def libreoffice_writer_set_font(self, font_name, paragraph_indices=None):
|
||||||
|
'''Changes the font of text in the document or specified paragraphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
font_name (str): The name of the font to apply (e.g., 'Times New Roman', 'Arial', 'Calibri')
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||||
|
If not provided, applies to all paragraphs.'''
|
||||||
|
|
||||||
|
def libreoffice_writer_set_line_spacing(self, spacing_value, paragraph_indices=None):
|
||||||
|
'''Sets the line spacing for specified paragraphs in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
spacing_value (float): The line spacing value to apply (1.0 for single spacing, 2.0 for double spacing, etc.)
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||||
|
If not provided, applies to all paragraphs.'''
|
||||||
|
|
||||||
|
def libreoffice_writer_remove_highlighting(self, paragraph_indices=None):
|
||||||
|
'''Removes ALL highlighting from text in the document for specified paragraphs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||||
|
If not provided, applies to all paragraphs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message or error message'''
|
||||||
|
|
||||||
|
def libreoffice_writer_find_highlighted_text(self, highlight_color):
|
||||||
|
'''Finds all text in the document that has a specific highlight color applied to it.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
highlight_color (str): The highlight color to search for. Can be a color name (e.g., 'yellow', 'green') or hex code.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of strings containing all text segments with the specified highlight color.'''
|
||||||
|
|
||||||
|
def libreoffice_writer_insert_formula_at_cursor(self, formula):
|
||||||
|
'''Inserts a formula at the current cursor position in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
formula (str): The formula to insert at the current cursor position.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successful, False otherwise'''
|
||||||
|
|
||||||
|
def libreoffice_writer_insert_image_at_cursor(self, image_path, width=None, height=None):
|
||||||
|
'''Inserts an image at the current cursor position in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path (str): Full path to the image file to insert
|
||||||
|
width (int, optional): Width to display the image in pixels
|
||||||
|
height (int, optional): Height to display the image in pixels
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message or error message'''
|
||||||
|
|
||||||
|
def libreoffice_writer_set_strikethrough(self, pattern, paragraph_indices=None):
|
||||||
|
'''Sets the strikethrough formatting for text matching the specified pattern in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern (str): The regular expression pattern to match in the document
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||||
|
If not provided, applies to all paragraphs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message or error information'''
|
||||||
|
|
||||||
|
def libreoffice_writer_set_font_size(self, font_size, pattern, paragraph_indices=None):
|
||||||
|
'''Changes the font size of specified text in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
font_size (float): The font size to apply (in points).
|
||||||
|
pattern (str): The pattern to match in the document, should be a regular expression.
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||||
|
If not provided, applies to all paragraphs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Result message indicating success or failure.'''
|
||||||
|
|
||||||
|
def libreoffice_writer_export_to_pdf(self, output_path=None, output_filename=None, include_comments=False, quality="standard"):
|
||||||
|
'''Exports the current document to PDF format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_path (str, optional): The full path where the PDF should be saved.
|
||||||
|
If not provided, uses the same location as the original document.
|
||||||
|
output_filename (str, optional): The filename to use for the PDF.
|
||||||
|
If not provided, uses the original document's filename with .pdf extension.
|
||||||
|
include_comments (bool, optional): Whether to include comments in the exported PDF.
|
||||||
|
Defaults to False.
|
||||||
|
quality (str, optional): The quality of the PDF export ('standard', 'high', 'print').
|
||||||
|
Defaults to 'standard'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Path to the exported PDF file or error message'''
|
||||||
|
|
||||||
|
def libreoffice_writer_set_paragraph_alignment(self, alignment, paragraph_indices=None):
|
||||||
|
'''Sets the text alignment for specified paragraphs in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
alignment (str): The alignment to apply ('left', 'center', 'right', 'justify').
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||||
|
If not provided, applies to all paragraphs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message or error message'''
|
||||||
|
|
||||||
|
def libreoffice_writer_capitalize_words(self, paragraph_indices=None):
|
||||||
|
'''Capitalizes the first letter of each word for specified paragraphs in the document.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paragraph_indices (list, optional): Indices of paragraphs to modify (0-based indexing).
|
||||||
|
If not provided, applies to all paragraphs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message or error message'''
|
||||||
|
|
||||||
|
def libreoffice_writer_set_default_font(self, font_name, font_size=None):
|
||||||
|
'''Sets the default font for new text in the document without changing existing text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
font_name (str): The name of the font to set as default (e.g., 'Times New Roman', 'Arial', 'Calibri')
|
||||||
|
font_size (float, optional): The default font size in points.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message or error message'''
|
||||||
|
|
||||||
|
def libreoffice_writer_add_page_numbers(self, position, start_number=1, format=None):
|
||||||
|
'''Adds page numbers to the document at the specified position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
position (str): Position of the page numbers ('bottom_left', 'bottom_center', 'bottom_right',
|
||||||
|
'top_left', 'top_center', 'top_right')
|
||||||
|
start_number (int, optional): The starting page number. Defaults to 1.
|
||||||
|
format (str, optional): Format of the page numbers (e.g., '1', 'Page 1', '1 of N').
|
||||||
|
Defaults to simple number format.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Success message or error message'''
|
||||||
|
|
||||||
|
def libreoffice_writer_insert_page_break(self, position="at_cursor"):
|
||||||
|
'''Inserts a page break at the specified position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
position (str): Where to insert the page break: 'at_cursor' for current cursor position,
|
||||||
|
'end_of_document' for end of document. Defaults to 'at_cursor'.'''
|
||||||
|
|
||||||
|
Your response should be formatted like this:
|
||||||
|
(Previous action verification)
|
||||||
|
Carefully analyze based on the screenshot if the previous action was successful. If the previous action was not successful, provide a reason for the failure.
|
||||||
|
|
||||||
|
(Screenshot Analysis)
|
||||||
|
Closely examine and describe the current state of the desktop along with the currently open applications.
|
||||||
|
|
||||||
|
(Next Action)
|
||||||
|
Based on the current screenshot and the history of your previous interaction with the UI, decide on the next action in natural language to accomplish the given task.
|
||||||
|
|
||||||
|
(Grounded Action)
|
||||||
|
Translate the next action into code using the provided API methods. Format the code like this:
|
||||||
|
```python
|
||||||
|
agent.click("The menu button at the top right of the window", 1, "left")
|
||||||
|
```
|
||||||
|
Note for the code:
|
||||||
|
1. Only perform one action at a time.
|
||||||
|
2. Do not put anything other than python code in the block. You can only use one function call at a time. Do not put more than one function call in the block.
|
||||||
|
3. You must use only the available methods provided above to interact with the UI, do not invent new methods.
|
||||||
|
4. Only return one code block every time. There must be a single line of code in the code block.
|
||||||
|
5. Do not do anything other than the exact specified task. Return with `agent.done()` immediately after the subtask is completed or `agent.fail()` if it cannot be completed.
|
||||||
|
6. Whenever possible, your grounded action should use hot-keys with the agent.hotkey() action instead of clicking or dragging.
|
||||||
|
7. My computer's password is 'osworld-public-evaluation', feel free to use it when you need sudo rights.
|
||||||
|
8. Before performing any calculations on elements in a table or inserting charts, always use libreoffice_calc_get_column_data or libreoffice_calc_get_active_sheet_data to obtain accurate column coordinates and element values from the table, ensuring precise execution of subsequent calculations or chart insertions.
|
||||||
|
9. Generate agent.fail() as your grounded action if you get exhaustively stuck on the task and believe it is impossible.
|
||||||
|
10. Generate agent.done() as your grounded action when you believe the task is fully complete.
|
||||||
|
11. Do not use the "command" + "tab" hotkey on MacOS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
REFLECTION_SYS_PROMPT = """
|
||||||
|
You are an expert computer use agent designed to reflect on the trajectory of a task and provide feedback on what has happened so far.
|
||||||
|
You have access to the Task Description and the Current Trajectory of another computer agent. The Current Trajectory is a sequence of a desktop image, chain-of-thought reasoning, and a desktop action for each time step. The last image is the screen's display after the last action.
|
||||||
|
Your task is to generate a reflection. Your generated reflection must fall under one of the cases listed below:
|
||||||
|
|
||||||
|
**Your judgment must be based solely on a critical comparison between the agent's stated plan/reasoning and the visual evidence presented in the screenshot history.** Do not take the agent's claims of success at face value. **If there is no visual proof in the screenshot, the action did not happen.**
|
||||||
|
|
||||||
|
Case 1. The trajectory is not going according to plan. This occurs when there is a mismatch between the intended action and the visual outcome, when the agent hallucinates information, or when it is stuck. You must trigger Case 1 if you detect any of the following:
|
||||||
|
Failed Action: The previous action did not produce its expected visual change on the screen (e.g., a window failed to open, text was not pasted).
|
||||||
|
Unsupported Conclusion (Hallucination): The agent makes a claim or states a result (like a number or a fact) that is not visibly supported by the current or any previous screenshot. This is a critical failure.
|
||||||
|
Repetitive Cycle: The agent is repeating actions without making meaningful progress.
|
||||||
|
Case 2. The trajectory is going according to plan. In this case, simply tell the agent to continue proceeding as planned. DO NOT encourage a specific action in particular.
|
||||||
|
Case 3. You believe the current task has been completed. In this case, tell the agent that the task has been successfully completed.
|
||||||
|
|
||||||
|
To be successful, you must follow the rules below:
|
||||||
|
- **Your output MUST be based on one of the case options above**.
|
||||||
|
- DO NOT suggest any specific future plans or actions. Your only goal is to provide a reflection, not an actual plan or action.
|
||||||
|
- Any response that falls under Case 1 should explain why the trajectory is not going according to plan. You should especially look out for cycles of actions that are continually repeated with no progress.
|
||||||
|
- Any response that falls under Case 2 should be concise, since you only need to tell the agent to continue with the current trajectory.
|
||||||
|
"""
|
||||||
194
mm_agents/aworldguiagent/utils.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""
|
||||||
|
This code is adapted from AgentS2 (https://github.com/simular-ai/Agent-S)
|
||||||
|
with modifications to suit specific requirements.
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
import base64
|
||||||
|
from aworld.core.common import Observation, ActionModel
|
||||||
|
from aworld.models.model_response import ModelResponse
|
||||||
|
from aworld.core.agent.base import AgentResult
|
||||||
|
from aworld.memory.main import InMemoryMemoryStore
|
||||||
|
|
||||||
|
def encode_image(image_content):
    # A str is treated as a path to an image file; anything else is assumed to be raw image bytes.
    if isinstance(image_content, str):
        with open(image_content, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")
    else:
        return base64.b64encode(image_content).decode("utf-8")
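As a rough illustration of how the base64 string from `encode_image` is typically consumed, the sketch below wraps raw bytes in a data URL. The helper name `to_data_url` and the `image/png` MIME type are assumptions made for this example; they do not appear in the commit.

```python
import base64


def to_data_url(image_bytes, mime="image/png"):
    # Hypothetical helper: same encoding step as encode_image, plus a data-URL prefix.
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime};base64,{encoded}"


# The 8-byte PNG signature stands in for a real screenshot here.
png_signature = b"\x89PNG\r\n\x1a\n"
url = to_data_url(png_signature)

# Decoding the payload gives back the original bytes.
payload = url.split(",", 1)[1]
print(base64.b64decode(payload) == png_signature)
```

The round trip at the end is only a sanity check that the encoding is lossless.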
|
||||||
|
|
||||||
|
|
||||||
|
def extract_first_agent_function(code_string):
    # Match an `agent.<method>(...)` call. Parentheses may appear inside quoted string
    # arguments, but nested (unquoted) parentheses are not supported by this pattern.
    pattern = r'agent\.[a-zA-Z_]+\((?:[^()\'"]|\'[^\']*\'|"[^"]*")*\)'

    # Find all matches in the string
    matches = re.findall(pattern, code_string)

    # Return the first match if found, otherwise return None
    return matches[0] if matches else None
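A small sketch of what the pattern accepts and rejects may help here. The pattern is copied from the function above so the example runs on its own; the name `first_agent_call` and the sample strings are made up for illustration.

```python
import re

# Same pattern as extract_first_agent_function, repeated so this sketch is self-contained.
AGENT_CALL = r'agent\.[a-zA-Z_]+\((?:[^()\'"]|\'[^\']*\'|"[^"]*")*\)'


def first_agent_call(text):
    matches = re.findall(AGENT_CALL, text)
    return matches[0] if matches else None


# Parentheses inside a quoted argument are fine.
sample = 'Plan: agent.click("The (OK) button", 1, "left") then agent.done()'
print(first_agent_call(sample))

# An unquoted nested call is not matched at all.
print(first_agent_call("agent.type(text=str(1))"))
```

The second call shows the limitation: arguments that contain unquoted parentheses make the whole call invisible to the pattern, so callers should expect `None` in that case.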
|
||||||
|
|
||||||
|
|
||||||
|
def parse_single_code_from_string(input_string):
    input_string = input_string.strip()
    commands = ["WAIT", "DONE", "FAIL"]  # FIXME: update this list when more commands are added
    if input_string in commands:
        return input_string

    # This regular expression will match both ```code``` and ```python code```
    # and capture the `code` part. It uses a non-greedy match for the content inside.
    pattern = r"```(?:\w+\s+)?(.*?)```"

    # Find all non-overlapping matches in the string.
    # The `re.DOTALL` flag allows the dot `.` to match newline characters as well,
    # so the code inside backticks can span multiple lines.
    matches = re.findall(pattern, input_string, re.DOTALL)

    # Collect the captured code snippets
    codes = []
    for match in matches:
        match = match.strip()
        lines = match.split("\n")

        if match in commands:
            codes.append(match)
        elif lines[-1] in commands:
            # Code followed by a trailing command: keep them as separate entries
            if len(lines) > 1:
                codes.append("\n".join(lines[:-1]))
            codes.append(lines[-1])
        else:
            codes.append(match)

    # No fenced block was found
    if len(codes) <= 0:
        return "fail"
    return codes[0]
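The fence extraction can be tried in isolation with the sketch below, which feeds a response in the "(Next Action) / (Grounded Action)" layout through the same regular expression. The helper name `first_fenced_block` is an assumption for this example, and the fence string is built programmatically only so that no literal fence appears inside the example itself.

```python
import re

FENCE = "`" * 3  # three backticks, assembled at runtime
PATTERN = FENCE + r"(?:\w+\s+)?(.*?)" + FENCE


def first_fenced_block(response):
    # Return the content of the first fenced block, without the language tag.
    blocks = re.findall(PATTERN, response, re.DOTALL)
    return blocks[0].strip() if blocks else None


response = "\n".join([
    "(Next Action)",
    "Open the application menu.",
    "(Grounded Action)",
    FENCE + "python",
    'agent.click("The menu button at the top right of the window", 1, "left")',
    FENCE,
])

print(first_fenced_block(response))
print(first_fenced_block("no code block in this reply"))
```

Note that the optional `\w+\s+` group also swallows the first word of an untagged block when that word is immediately followed by whitespace, which is one reason the prompt asks for a single `agent.` call per block.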
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_code(code):
    # If the code spans multiple lines, convert the first double-quoted string into a
    # triple-quoted one so that embedded newlines remain valid Python.
    if "\n" in code:
        pattern = r'(".*?")'
        # Find all matches in the text
        matches = re.findall(pattern, code, flags=re.DOTALL)
        if matches:
            # Replace the first occurrence only
            first_match = matches[0]
            code = code.replace(first_match, f'"""{first_match[1:-1]}"""', 1)
    return code
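To see why this rewrite matters, the sketch below checks whether an action string compiles before and after sanitizing. The function body mirrors `sanitize_code` above; the `compiles` helper and the sample `agent.type(...)` string are illustrative assumptions.

```python
import re


def sanitize_code(code):
    # Mirror of the function above, repeated so the sketch runs on its own.
    if "\n" in code:
        matches = re.findall(r'(".*?")', code, flags=re.DOTALL)
        if matches:
            first = matches[0]
            code = code.replace(first, f'"""{first[1:-1]}"""', 1)
    return code


def compiles(source):
    # True if the string parses as a Python expression (nothing is executed).
    try:
        compile(source, "<action>", "eval")
        return True
    except SyntaxError:
        return False


# A generated action whose string argument contains a real newline.
raw = 'agent.type("line one\nline two")'
fixed = sanitize_code(raw)

print(compiles(raw), compiles(fixed))
```

Only the first double-quoted string is rewritten, so a call with several multi-line arguments would still need extra handling.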
|
||||||
|
|
||||||
|
def prune_image_messages(memory_store: InMemoryMemoryStore, max_trajectory_length: int):
    """
    Check the messages in memory_store and keep images only in the latest
    max_trajectory_length messages that contain them.
    For older image-bearing messages, the image parts are removed from their content.

    Args:
        memory_store (InMemoryMemoryStore): The memory store instance.
        max_trajectory_length (int): The maximum number of image-bearing messages to keep.
    """
    # Step 1: Fetch all messages via the memory_store's get_all method
    all_items = memory_store.get_all()

    # Step 2: Collect every message that contains image content
    image_messages = []
    for item in all_items:
        if isinstance(item.content, list):
            if any(isinstance(part, dict) and part.get('type') == 'image_url' for part in item.content):
                image_messages.append(item)

    # Step 3: Check whether the number of image messages exceeds the limit
    if len(image_messages) <= max_trajectory_length:
        print("Number of image messages does not exceed the limit. No pruning needed.")
        return

    # Step 4: Determine which older messages should lose their images.
    # get_all() returns items in insertion order, so the items at the front are the oldest.
    num_to_prune = len(image_messages) - max_trajectory_length
    messages_to_prune = image_messages[:num_to_prune]

    print(f"Found {len(image_messages)} image messages. Pruning the oldest {num_to_prune}.")

    # Step 5: Update the content of each message to prune and save it via the store's update method
    for item_to_prune in messages_to_prune:

        # Build a new content list containing only the non-image parts
        new_content = [
            part for part in item_to_prune.content
            if not (isinstance(part, dict) and part.get('type') == 'image_url')
        ]

        # Optional: if only a single text element remains, simplify it to a plain string
        if len(new_content) == 1 and isinstance(new_content[0], dict) and new_content[0].get('type') == 'text':
            final_content = new_content[0].get('text', '')
        else:
            final_content = new_content

        # Update the content attribute of the message object
        item_to_prune.content = final_content

        # Persist the change to the store via the memory_store's update method
        memory_store.update(item_to_prune)

        print(f"Pruned image from message with ID: {item_to_prune.id}")
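The pruning steps can be exercised without the AWorld framework using the sketch below. `FakeItem` and `FakeStore` are hypothetical stand-ins that only imitate the `get_all()` and `update()` calls used above; they are not the real `InMemoryMemoryStore` API, and `prune` is a condensed restatement of the same logic.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeItem:
    # Hypothetical stand-in for a memory item: just an id and a content field.
    id: int
    content: Any


class FakeStore:
    # Hypothetical stand-in exposing only the two methods the pruning logic calls.
    def __init__(self, items):
        self._items = list(items)

    def get_all(self):
        return list(self._items)

    def update(self, item):
        pass  # items are mutated in place, so there is nothing to persist here


def is_image(part):
    return isinstance(part, dict) and part.get("type") == "image_url"


def prune(store, keep):
    with_images = [m for m in store.get_all()
                   if isinstance(m.content, list) and any(is_image(p) for p in m.content)]
    excess = max(len(with_images) - keep, 0)
    for msg in with_images[:excess]:  # the oldest messages come first
        rest = [p for p in msg.content if not is_image(p)]
        if len(rest) == 1 and isinstance(rest[0], dict) and rest[0].get("type") == "text":
            msg.content = rest[0].get("text", "")
        else:
            msg.content = rest
        store.update(msg)


def make_message(i):
    return FakeItem(i, [{"type": "text", "text": f"step {i}"},
                        {"type": "image_url", "image_url": {"url": "data:..."}}])


store = FakeStore([make_message(i) for i in range(4)])
prune(store, keep=2)
contents = [m.content for m in store.get_all()]
print(contents[0], contents[1], len(contents[2]), len(contents[3]))
```

With four image messages and `keep=2`, the two oldest collapse to plain text while the two newest keep both their text and image parts.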
|
||||||
|
|
||||||
|
def reps_action_result(resp: ModelResponse) -> AgentResult:
    try:
        full_response = resp.content
        # Extract thoughts section
        thoughts_match = re.search(
            r"<thoughts>(.*?)</thoughts>", full_response, re.DOTALL
        )
        thoughts = thoughts_match.group(1).strip()
        # Extract answer section
        answer_match = re.search(r"<answer>(.*?)</answer>", full_response, re.DOTALL)
        answer = answer_match.group(1).strip()
        action = ActionModel(action_name=answer, policy_info=thoughts)
        return AgentResult(actions=[action], current_state=None)
    except Exception:
        # Fall back to the raw response when either tag is missing (calling .group()
        # on a None match raises AttributeError) or the content cannot be parsed.
        action = ActionModel(action_name=resp.content, policy_info="")
        return AgentResult(actions=[action], current_state=None)
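The tag extraction and its fallback can be sketched without the AWorld types as follows. The function name `split_thoughts_and_answer` and the tuple return value are assumptions for this example; the real function wraps the same two strings in `ActionModel` and `AgentResult`.

```python
import re


def split_thoughts_and_answer(text):
    # Returns (thoughts, answer); falls back to ("", text) when a tag is missing.
    thoughts = re.search(r"<thoughts>(.*?)</thoughts>", text, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    if thoughts is None or answer is None:
        return "", text
    return thoughts.group(1).strip(), answer.group(1).strip()


tagged = "<thoughts>\nThe dialog is open.\n</thoughts>\n<answer>DONE</answer>"
print(split_thoughts_and_answer(tagged))
print(split_thoughts_and_answer("agent.done()"))
```

Checking for `None` explicitly, as done here, is one alternative to relying on the broad `except` clause for the missing-tag case.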
|
||||||
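The fence-parsing behaviour of `parse_single_code_from_string` can be tried in isolation. The block below is a minimal, self-contained sketch that mirrors the regex and the special-command handling from the diff above. The helper name `first_code_block`, the `FENCE` constant, and the sample input are illustrative assumptions and are not part of the commit.

```python
import re

COMMANDS = ("WAIT", "DONE", "FAIL")
FENCE = "`" * 3  # built programmatically so this example contains no nested fences
# Same pattern as in the diff: an optional language tag, then a non-greedy body.
PATTERN = re.compile(FENCE + r"(?:\w+\s+)?(.*?)" + FENCE, re.DOTALL)


def first_code_block(text: str) -> str:
    """Hypothetical condensed mirror of parse_single_code_from_string."""
    text = text.strip()
    if text in COMMANDS:
        return text

    codes = []
    for match in PATTERN.findall(text):
        match = match.strip()
        lines = match.split("\n")
        if match in COMMANDS:
            codes.append(match)
        elif lines[-1] in COMMANDS:
            # Split a trailing WAIT/DONE/FAIL off from the code before it.
            if len(lines) > 1:
                codes.append("\n".join(lines[:-1]))
            codes.append(lines[-1])
        else:
            codes.append(match)

    # As in the diff, a lowercase "fail" is returned when nothing was found.
    return codes[0] if codes else "fail"


if __name__ == "__main__":
    sample = FENCE + "python\nagent.click(1)\nDONE\n" + FENCE
    print(first_code_block(sample))
```

One consequence of the optional language tag: a tag-less block whose first token is a word followed by whitespace has that word consumed as if it were the tag, so a block containing only `import time` yields `time`.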
230
mm_agents/aworldguiagent/workflow.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""
|
||||||
|
This code is adapted from AgentS2 (https://github.com/simular-ai/Agent-S)
|
||||||
|
with modifications to suit specific requirements.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import textwrap
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from aworld.config.conf import AgentConfig
|
||||||
|
from aworld.agents.llm_agent import Agent
|
||||||
|
from aworld.core.common import Observation
|
||||||
|
|
||||||
|
from aworld.core.task import Task
|
||||||
|
from aworld.core.context.base import Context
|
||||||
|
from aworld.core.event.base import Message
|
||||||
|
from aworld.models.llm import get_llm_model
|
||||||
|
from aworld.utils.common import sync_exec
|
||||||
|
|
||||||
|
from mm_agents.aworldguiagent.grounding import ACI
|
||||||
|
from mm_agents.aworldguiagent.prompt import GENERATOR_SYS_PROMPT, REFLECTION_SYS_PROMPT
|
||||||
|
from mm_agents.aworldguiagent.utils import encode_image, extract_first_agent_function, parse_single_code_from_string, sanitize_code
|
||||||
|
from mm_agents.aworldguiagent.utils import prune_image_messages, reps_action_result
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
|
|
||||||
|
class Worker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
engine_params: Dict,
|
||||||
|
grounding_agent: ACI,
|
||||||
|
platform: str = "ubuntu",
|
||||||
|
max_trajectory_length: int = 16,
|
||||||
|
enable_reflection: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Worker receives the main task and generates actions directly, without the need for hierarchical planning.
|
||||||
|
Args:
|
||||||
|
engine_params: Dict
|
||||||
|
Parameters for the multimodal engine
|
||||||
|
grounding_agent: ACI
|
||||||
|
The grounding agent to use
|
||||||
|
platform: str
|
||||||
|
OS platform the agent runs on (darwin, linux, windows)
|
||||||
|
max_trajectory_length: int
|
||||||
|
The number of image turns to keep
|
||||||
|
enable_reflection: bool
|
||||||
|
Whether to enable reflection
|
||||||
|
"""
|
||||||
|
# super().__init__(engine_params, platform)
|
||||||
|
|
||||||
|
self.grounding_agent = grounding_agent
|
||||||
|
self.max_trajectory_length = max_trajectory_length
|
||||||
|
self.enable_reflection = enable_reflection
|
||||||
|
self.use_thinking = engine_params.get("model", "") in [
|
||||||
|
"claude-3-7-sonnet-20250219"
|
||||||
|
]
|
||||||
|
|
||||||
|
self.generator_agent_config = AgentConfig(
|
||||||
|
llm_provider=engine_params.get("engine_type", "openai"),
|
||||||
|
llm_model_name=engine_params.get("model", "openai/o3"),
|
||||||
|
llm_temperature=engine_params.get("temperature", 1.0),
|
||||||
|
llm_base_url=engine_params.get("base_url", "https://openrouter.ai/api/v1"),
|
||||||
|
llm_api_key=engine_params.get("api_key", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
|
||||||
|
self.generator_agent = Agent(
|
||||||
|
name="generator_agent",
|
||||||
|
conf=self.generator_agent_config,
|
||||||
|
system_prompt=GENERATOR_SYS_PROMPT,
|
||||||
|
resp_parse_func=reps_action_result
|
||||||
|
)
|
||||||
|
|
||||||
|
self.reflection_agent = Agent(
|
||||||
|
name="reflection_agent",
|
||||||
|
conf=self.generator_agent_config,
|
||||||
|
system_prompt=REFLECTION_SYS_PROMPT,
|
||||||
|
resp_parse_func=reps_action_result
|
||||||
|
)
|
||||||
|
|
||||||
|
self.turn_count = 0
|
||||||
|
self.worker_history = []
|
||||||
|
self.reflections = []
|
||||||
|
self.cost_this_turn = 0
|
||||||
|
self.screenshot_inputs = []
|
||||||
|
|
||||||
|
self.dummy_task = Task()
|
||||||
|
self.dummy_context = Context()
|
||||||
|
self.dummy_context.set_task(self.dummy_task)
|
||||||
|
self.dummy_message = Message(headers={'context': self.dummy_context})
|
||||||
|
|
||||||
|
self.planning_model = get_llm_model(self.generator_agent_config)
|
||||||
|
|
||||||
|
self.first_done = False
|
||||||
|
self.first_image = None
|
||||||
|
|
||||||
|
def generate_next_action(
|
||||||
|
self,
|
||||||
|
instruction: str,
|
||||||
|
obs: Dict,
|
||||||
|
) -> Tuple[Dict, List]:
|
||||||
|
"""
|
||||||
|
Predict the next action(s) based on the current observation.
|
||||||
|
"""
|
||||||
|
agent = self.grounding_agent
|
||||||
|
generator_message = (
|
||||||
|
""
|
||||||
|
if self.turn_count > 0
|
||||||
|
else "The initial screen is provided. No action has been taken yet."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the task into the system prompt
|
||||||
|
if self.turn_count == 0:
|
||||||
|
self.generator_agent.system_prompt = self.generator_agent.system_prompt.replace(
|
||||||
|
"TASK_DESCRIPTION", instruction)
|
||||||
|
|
||||||
|
# Get the per-step reflection
|
||||||
|
reflection = None
|
||||||
|
reflection_thoughts = None
|
||||||
|
if self.enable_reflection:
|
||||||
|
# Load the initial message
|
||||||
|
if self.turn_count == 0:
|
||||||
|
text_content = textwrap.dedent(
|
||||||
|
f"""
|
||||||
|
Task Description: {instruction}
|
||||||
|
Current Trajectory below:
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
updated_sys_prompt = (
|
||||||
|
self.reflection_agent.system_prompt + "\n" + text_content
|
||||||
|
)
|
||||||
|
self.reflection_agent.system_prompt = updated_sys_prompt
|
||||||
|
|
||||||
|
image_content = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"The initial screen is provided. No action has been taken yet."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "data:image/png;base64," + encode_image(obs["screenshot"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
self.reflection_agent._init_context(context=self.dummy_context)
|
||||||
|
|
||||||
|
sync_exec(
|
||||||
|
self.reflection_agent._add_human_input_to_memory,
|
||||||
|
image_content,
|
||||||
|
self.dummy_context,
|
||||||
|
"message"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load the latest action
|
||||||
|
else:
|
||||||
|
|
||||||
|
image = "data:image/png;base64," + encode_image(obs["screenshot"])
|
||||||
|
reflection_message = self.worker_history[-1] + "\n" + f"Here is function execute result: {obs['action_response']}.\n"
|
||||||
|
|
||||||
|
reflection_observation = Observation(content=reflection_message, image=image)
|
||||||
|
|
||||||
|
self.reflection_agent._init_context(context=self.dummy_context)
|
||||||
|
reflection_actions = self.reflection_agent.policy(reflection_observation, message=self.dummy_message)
|
||||||
|
|
||||||
|
reflection = reflection_actions[0].action_name
|
||||||
|
reflection_thoughts = reflection_actions[0].policy_info
|
||||||
|
|
||||||
|
self.reflections.append(reflection)
|
||||||
|
|
||||||
|
generator_message += f"Here is your function execute result: {obs['action_response']}.\n"
|
||||||
|
|
||||||
|
generator_message += f"REFLECTION: You may use this reflection on the previous action and overall trajectory:\n{reflection}\n"
|
||||||
|
logger.info("REFLECTION: %s", reflection)
|
||||||
|
|
||||||
|
if self.first_done:
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Add finalized message to conversation
|
||||||
|
generator_message += f"\nCurrent Text Buffer = [{','.join(agent.notes)}]\n"
|
||||||
|
|
||||||
|
image = "data:image/png;base64," + encode_image(obs["screenshot"])
|
||||||
|
generator_observation = Observation(content=generator_message, image=image)
|
||||||
|
|
||||||
|
self.generator_agent._init_context(context=self.dummy_context)
|
||||||
|
generator_actions = self.generator_agent.policy(generator_observation, message=self.dummy_message)
|
||||||
|
|
||||||
|
plan = generator_actions[0].action_name
|
||||||
|
plan_thoughts = generator_actions[0].policy_info
|
||||||
|
|
||||||
|
prune_image_messages(self.generator_agent.memory.memory_store, self.max_trajectory_length)
|
||||||
|
prune_image_messages(self.reflection_agent.memory.memory_store, self.max_trajectory_length)
|
||||||
|
|
||||||
|
self.worker_history.append(plan)
|
||||||
|
|
||||||
|
logger.info("FULL PLAN:\n %s", plan)
|
||||||
|
|
||||||
|
# self.generator_agent.add_message(plan, role="assistant")
|
||||||
|
# Use the grounding agent to convert agent_action("desc") into agent_action([x, y])
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent.assign_coordinates(plan, obs)
|
||||||
|
plan_code = parse_single_code_from_string(plan.split("Grounded Action")[-1])
|
||||||
|
plan_code = sanitize_code(plan_code)
|
||||||
|
plan_code = extract_first_agent_function(plan_code)
|
||||||
|
exec_code = eval(plan_code)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in parsing plan code: %s", e)
|
||||||
|
plan_code = "agent.wait(1.0)"
|
||||||
|
exec_code = eval(plan_code)
|
||||||
|
|
||||||
|
executor_info = {
|
||||||
|
"full_plan": plan,
|
||||||
|
"executor_plan": plan,
|
||||||
|
"plan_thoughts": plan_thoughts,
|
||||||
|
"plan_code": plan_code,
|
||||||
|
"reflection": reflection,
|
||||||
|
"reflection_thoughts": reflection_thoughts,
|
||||||
|
}
|
||||||
|
self.turn_count += 1
|
||||||
|
|
||||||
|
self.screenshot_inputs.append(obs["screenshot"])
|
||||||
|
|
||||||
|
return executor_info, [exec_code]
|
||||||
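The final `try`/`except` in `generate_next_action` turns the parsed plan into an executable action and falls back to a one-second wait when parsing or evaluation fails. The block below is a minimal sketch of that fallback pattern only. `StubAgent`, its return strings, and `plan_to_exec_code` are hypothetical stand-ins and do not reproduce `OSWorldACI`. Unlike the diff, which calls `eval` with the surrounding scope, this sketch exposes only the name `agent` to the evaluated expression; that narrows the namespace but is not a security sandbox.

```python
import logging

logger = logging.getLogger("sketch.fallback")


class StubAgent:
    """Hypothetical stand-in for the grounding agent (not OSWorldACI)."""

    def wait(self, seconds: float) -> str:
        return f"import time; time.sleep({seconds})"

    def click(self, x: int, y: int) -> str:
        return f"import pyautogui; pyautogui.click({x}, {y})"


def plan_to_exec_code(plan_code: str, agent: StubAgent) -> str:
    """Evaluate one agent call; fall back to a short wait on any error."""
    # Only `agent` is reachable by name inside the evaluated expression.
    namespace = {"__builtins__": {}, "agent": agent}
    try:
        return eval(plan_code, namespace)
    except Exception as exc:  # SyntaxError, AttributeError, TypeError, ...
        logger.error("Error in parsing plan code: %s", exc)
        return agent.wait(1.0)


if __name__ == "__main__":
    stub = StubAgent()
    print(plan_to_exec_code("agent.click(10, 20)", stub))
    print(plan_to_exec_code("agent.click(", stub))  # malformed, so the wait is used
```

Falling back to a wait keeps the episode alive for another turn, so the next screenshot and reflection can recover from a malformed plan instead of aborting the task.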
740
run_multienv_aworldguiagent.py
Normal file
@@ -0,0 +1,740 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import math
|
||||||
|
from tqdm import tqdm
|
||||||
|
from multiprocessing import Process, Manager, Queue
|
||||||
|
from multiprocessing import current_process
|
||||||
|
import lib_run_single
|
||||||
|
from desktop_env.desktop_env import DesktopEnv, _fix_pyautogui_less_than_bug
|
||||||
|
from mm_agents.aworldguiagent.agent import AworldGUIAgent
|
||||||
|
from mm_agents.aworldguiagent.grounding import OSWorldACI
|
||||||
|
|
||||||
|
MAX_RETRIES = 5 # Maximum retries for environment setup
|
||||||
|
|
||||||
|
# Global variables for signal handling
|
||||||
|
active_environments = []
|
||||||
|
processes = []
|
||||||
|
is_terminating = False
|
||||||
|
|
||||||
|
# import wandb
|
||||||
|
|
||||||
|
# load the environment variables from .env file
|
||||||
|
if os.path.exists(".env"):
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
# Logger Configs {{{ #
|
||||||
|
def config() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run end-to-end evaluation on the benchmark"
|
||||||
|
)
|
||||||
|
|
||||||
|
# environment config
|
||||||
|
parser.add_argument("--path_to_vm", type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
"--headless", action="store_true", help="Run in headless machine"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--action_space", type=str, default="pyautogui", help="Action type"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--observation_type",
|
||||||
|
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
|
default="screenshot",
|
||||||
|
help="Observation type",
|
||||||
|
)
|
||||||
|
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||||
|
parser.add_argument("--max_steps", type=int, default=15)
|
||||||
|
|
||||||
|
# agent config
|
||||||
|
parser.add_argument(
|
||||||
|
"--test_config_base_dir", type=str, default="evaluation_examples"
|
||||||
|
)
|
||||||
|
|
||||||
|
# lm config
|
||||||
|
parser.add_argument("--model", type=str, default="o3")
|
||||||
|
|
||||||
|
# example config
|
||||||
|
parser.add_argument("--domain", type=str, default="all")
|
||||||
|
parser.add_argument(
|
||||||
|
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
# logging related
|
||||||
|
parser.add_argument("--result_dir", type=str, default="./results")
|
||||||
|
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||||
|
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||||
|
default='INFO', help="Set the logging level")
|
||||||
|
# aws config
|
||||||
|
parser.add_argument(
|
||||||
|
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"],
|
||||||
|
help="Provider name"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--client_password", type=str, default="", help="Client password"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--screen_width", type=int, default=1920, help="Screen width"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--screen_height", type=int, default=1080, help="Screen height"
|
||||||
|
)
|
||||||
|
|
||||||
|
# agent S2 config
|
||||||
|
|
||||||
|
parser.add_argument("--model_provider", type=str, default="openai")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_url",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The URL of the main generation model API.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_api_key",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The API key of the main generation model.",
|
||||||
|
)
|
||||||
|
parser.add_argument("--model_temperature", type=float, default=None,
|
||||||
|
help="Temperature to fix the generation model at (e.g. o3 can only be run with 1.0)")
|
||||||
|
|
||||||
|
parser.add_argument("--ground_provider", type=str, required=True, help="The provider for the grounding model")
|
||||||
|
parser.add_argument("--ground_url", type=str, required=True, help="The URL of the grounding model")
|
||||||
|
parser.add_argument(
|
||||||
|
"--ground_api_key",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="The API key of the grounding model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ground_model", type=str, required=True, help="The model name for the grounding model"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grounding_width",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="Width of screenshot image after processor rescaling",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grounding_height",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="Height of screenshot image after processor rescaling",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
args = config() # Get command line arguments first
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
log_level = getattr(logging, args.log_level.upper())
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
|
|
||||||
|
os.makedirs("logs", exist_ok=True)  # FileHandler does not create missing directories
|
||||||
|
file_handler = logging.FileHandler(
|
||||||
|
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||||
|
)
|
||||||
|
debug_handler = logging.FileHandler(
|
||||||
|
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||||
|
)
|
||||||
|
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
|
||||||
|
file_handler.setLevel(logging.INFO)
|
||||||
|
debug_handler.setLevel(logging.DEBUG)
|
||||||
|
stdout_handler.setLevel(log_level)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
debug_handler.setFormatter(formatter)
|
||||||
|
stdout_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||||
|
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(debug_handler)
|
||||||
|
logger.addHandler(stdout_handler)
|
||||||
|
# }}} Logger Configs #
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDesktopEnv(DesktopEnv):
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
logger.info("CustomDesktopEnv class initialized.")
|
||||||
|
|
||||||
|
def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
|
||||||
|
|
||||||
|
# Reset to certain task in OSWorld
|
||||||
|
logger.info("Resetting environment...")
|
||||||
|
logger.info("Switching task...")
|
||||||
|
logger.info("Setting counters...")
|
||||||
|
self._traj_no += 1
|
||||||
|
self._step_no = 0
|
||||||
|
self.action_history.clear()
|
||||||
|
|
||||||
|
for attempt in range(MAX_RETRIES):
|
||||||
|
# Only revert to snapshot if environment has been used (step/setup)
|
||||||
|
# This optimization is especially important for cloud providers like AWS
|
||||||
|
# where unnecessary snapshot operations are costly and time-consuming
|
||||||
|
|
||||||
|
if task_config is not None:
|
||||||
|
# Only consider task proxy requirement if proxy is enabled at system level
|
||||||
|
task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
|
||||||
|
if not self.enable_proxy and task_config.get("proxy", False):
|
||||||
|
logger.info(
|
||||||
|
"Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
|
||||||
|
|
||||||
|
if task_use_proxy != self.current_use_proxy:
|
||||||
|
# keep this because get_info_from_website depends on it
|
||||||
|
self.current_use_proxy = task_use_proxy
|
||||||
|
|
||||||
|
if self.is_environment_used:
|
||||||
|
logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
|
||||||
|
self._revert_to_snapshot()
|
||||||
|
logger.info("Starting emulator...")
|
||||||
|
self._start_emulator()
|
||||||
|
logger.info("Emulator started.")
|
||||||
|
# Reset the usage flag after reverting
|
||||||
|
self.is_environment_used = False
|
||||||
|
else:
|
||||||
|
logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
|
||||||
|
|
||||||
|
if task_config is not None:
|
||||||
|
if task_config.get("proxy", False) and self.enable_proxy:
|
||||||
|
# If using proxy and proxy is enabled, set up the proxy configuration
|
||||||
|
self.setup_controller._proxy_setup(self.client_password)
|
||||||
|
self._set_task_info(task_config)
|
||||||
|
self.setup_controller.reset_cache_dir(self.cache_dir)
|
||||||
|
logger.info("Clearing browser cache and browsing data...")
|
||||||
|
try:
|
||||||
|
self.setup_controller._delete_all_browsing_data_chromium_setup()
|
||||||
|
logger.info("Browser cache cleared successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to clear browser cache: {e}")
|
||||||
|
logger.info("Setting up environment...")
|
||||||
|
success = self.setup_controller.setup(self.config,
|
||||||
|
task_config.get("proxy", False) and self.enable_proxy)
|
||||||
|
if success:
|
||||||
|
# Mark environment as used when setup is successfully executed
|
||||||
|
if self.config: # Only mark as used if there were actual setup operations
|
||||||
|
self.is_environment_used = True
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Environment setup failed, retrying (%d/%d)...",
|
||||||
|
attempt + 1,
|
||||||
|
MAX_RETRIES,
|
||||||
|
)
|
||||||
|
time.sleep(5)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("Environment setup complete.")
|
||||||
|
|
||||||
|
# start soffice service for office tools
|
||||||
|
self.setup_controller._launch_setup(
|
||||||
|
'soffice --headless --accept="socket,host=localhost,port=2002;urp;" --norestore --nologo --nodefault', shell=True)
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
observation = self._get_obs()
|
||||||
|
return observation
|
||||||
|
|
||||||
|
def step(self, action, pause=2):
|
||||||
|
self._step_no += 1
|
||||||
|
self.action_history.append(action)
|
||||||
|
|
||||||
|
# Mark environment as used when step is called
|
||||||
|
self.is_environment_used = True
|
||||||
|
|
||||||
|
reward = 0 # todo: Define reward calculation for each example
|
||||||
|
done = False # todo: Define episode termination condition for each example
|
||||||
|
response = None
|
||||||
|
info = {}
|
||||||
|
logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
|
||||||
|
# handle the special actions
|
||||||
|
if action in ['WAIT', 'FAIL', 'DONE'] or (
|
||||||
|
isinstance(action, dict) and action['action_type'] in ['WAIT', 'FAIL', 'DONE']):
|
||||||
|
if action == 'WAIT':
|
||||||
|
time.sleep(pause)
|
||||||
|
elif action == 'FAIL':
|
||||||
|
done = True
|
||||||
|
info = {"fail": True}
|
||||||
|
elif action == 'DONE':
|
||||||
|
done = True
|
||||||
|
info = {"done": True}
|
||||||
|
|
||||||
|
if self.action_space == "computer_13":
|
||||||
|
# the set of all possible actions defined in the action representation
|
||||||
|
self.controller.execute_action(action)
|
||||||
|
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
|
||||||
|
if action in ['WAIT', 'FAIL', 'DONE']:
|
||||||
|
self.controller.execute_action(action)
|
||||||
|
else:
|
||||||
|
# the set of all possible Python commands inside `pyautogui`
|
||||||
|
if isinstance(action, str):
|
||||||
|
# Fix PyAutoGUI '<' character bug before execution
|
||||||
|
fixed_command = _fix_pyautogui_less_than_bug(action)
|
||||||
|
response = self.controller.execute_python_command(fixed_command)
|
||||||
|
|
||||||
|
elif isinstance(action, dict):
|
||||||
|
# Fix PyAutoGUI '<' character bug before execution
|
||||||
|
fixed_command = _fix_pyautogui_less_than_bug(action['command'])
|
||||||
|
response = self.controller.execute_python_command(fixed_command)
|
||||||
|
|
||||||
|
time.sleep(pause)
|
||||||
|
observation = self._get_obs()
|
||||||
|
observation["action_response"] = response
|
||||||
|
return observation, reward, done, info
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||||
|
all_tasks = []
|
||||||
|
for domain, examples in test_all_meta.items():
|
||||||
|
for example_id in examples:
|
||||||
|
all_tasks.append((domain, example_id))
|
||||||
|
return all_tasks
|
||||||
|
|
||||||
|
|
||||||
|
def process_signal_handler(signum, frame, env_idx):
|
||||||
|
"""Signal handler for child processes to gracefully shut down their environments."""
|
||||||
|
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
|
||||||
|
|
||||||
|
# Get the active_environments from the caller's frame
|
||||||
|
local_vars = frame.f_locals
|
||||||
|
active_environments = local_vars.get('active_environments', [])
|
||||||
|
|
||||||
|
# Close environment in the current process context
|
||||||
|
for env in active_environments:
|
||||||
|
if env is not None:
|
||||||
|
try:
|
||||||
|
logger.info(f"Process {env_idx + 1} closing environment...")
|
||||||
|
env.close()
|
||||||
|
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
|
||||||
|
|
||||||
|
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||||
|
active_environments = []
|
||||||
|
env = None
|
||||||
|
try:
|
||||||
|
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||||
|
REGION = args.region
|
||||||
|
screen_size = (args.screen_width, args.screen_height)
|
||||||
|
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||||
|
env = CustomDesktopEnv(
|
||||||
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=args.action_space,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
region=REGION,
|
||||||
|
# snapshot_name=ami_id,
|
||||||
|
screen_size=screen_size,
|
||||||
|
headless=args.headless,
|
||||||
|
os_type="Ubuntu",
|
||||||
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
|
enable_proxy=False,
|
||||||
|
client_password=args.client_password
|
||||||
|
)
|
||||||
|
active_environments.append(env)
|
||||||
|
|
||||||
|
# AgentS2 configuration
|
||||||
|
engine_params = {
|
||||||
|
"engine_type": args.model_provider,
|
||||||
|
"model": args.model,
|
||||||
|
"base_url": getattr(args, 'model_url', ''),
|
||||||
|
"api_key": getattr(args, 'model_api_key', ''),
|
||||||
|
"temperature": getattr(args, 'model_temperature', None),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
engine_params_for_grounding = {
|
||||||
|
"engine_type": args.ground_provider,
|
||||||
|
"model": args.ground_model,
|
||||||
|
"base_url": getattr(args, 'ground_url', ''),
|
||||||
|
"api_key": getattr(args, 'ground_api_key', ''),
|
||||||
|
"grounding_width": args.grounding_width,
|
||||||
|
"grounding_height": args.grounding_height,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create grounding agent
|
||||||
|
grounding_agent = OSWorldACI(
|
||||||
|
platform="linux",
|
||||||
|
engine_params_for_generation=engine_params,
|
||||||
|
engine_params_for_grounding=engine_params_for_grounding,
|
||||||
|
width=args.screen_width,
|
||||||
|
height=args.screen_height,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the AworldGUIAgent worker (adapted from AgentS2)
|
||||||
|
agent = AworldGUIAgent(
|
||||||
|
engine_params,
|
||||||
|
grounding_agent,
|
||||||
|
platform="linux",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Process {current_process().name} started.")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
item = task_queue.get(timeout=5)
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
domain, example_id = item
|
||||||
|
try:
|
||||||
|
config_file = os.path.join(
|
||||||
|
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||||
|
)
|
||||||
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
|
example = json.load(f)
|
||||||
|
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||||
|
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||||
|
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||||
|
example_result_dir = os.path.join(
|
||||||
|
args.result_dir,
|
||||||
|
args.action_space,
|
||||||
|
args.observation_type,
|
||||||
|
args.model,
|
||||||
|
domain,
|
||||||
|
example_id,
|
||||||
|
)
|
||||||
|
os.makedirs(example_result_dir, exist_ok=True)
|
||||||
|
try:
|
||||||
|
lib_run_single.run_single_example(
|
||||||
|
agent,
|
||||||
|
env,
|
||||||
|
example,
|
||||||
|
args.max_steps,
|
||||||
|
example["instruction"],
|
||||||
|
args,
|
||||||
|
example_result_dir,
|
||||||
|
shared_scores,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
try:
|
||||||
|
env.controller.end_recording(
|
||||||
|
os.path.join(example_result_dir, "recording.mp4")
|
||||||
|
)
|
||||||
|
except Exception as rec_e:
|
||||||
|
logger.error(f"Failed to end recording: {rec_e}")
|
||||||
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{"Error": f"{domain}/{example_id} - {e}"}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
f.write("\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
finally:
|
||||||
|
logger.info(f"{current_process().name} cleaning up environment...")
|
||||||
|
try:
|
||||||
|
if env:
|
||||||
|
env.close()
|
||||||
|
logger.info(f"{current_process().name} environment closed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def signal_handler(signum, frame):
|
||||||
|
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
|
||||||
|
global is_terminating, active_environments, processes
|
||||||
|
|
||||||
|
# Avoid duplicate handling
|
||||||
|
if is_terminating:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_terminating = True
|
||||||
|
logger.info(f"Received signal {signum}. Gracefully shutting down...")
|
||||||
|
|
||||||
|
# Close all registered environments in the main process
|
||||||
|
for env in active_environments:
|
||||||
|
try:
|
||||||
|
logger.info("Closing environment...")
|
||||||
|
env.close()
|
||||||
|
logger.info("Environment closed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error closing environment: {e}")
|
||||||
|
|
||||||
|
# Send termination signal to all child processes first
|
||||||
|
for p in processes:
|
||||||
|
if p.is_alive():
|
||||||
|
try:
|
||||||
|
logger.info(f"Sending termination signal to process {p.name}...")
|
||||||
|
p.terminate()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error sending termination signal to process: {e}")
|
||||||
|
|
||||||
|
# Allow a short time for processes to handle their own cleanup
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Forcefully terminate any processes that didn't exit
|
||||||
|
for p in processes:
|
||||||
|
if p.is_alive():
|
||||||
|
try:
|
||||||
|
logger.info(f"Forcefully terminating process {p.name}...")
|
||||||
|
import signal as sig
|
||||||
|
os.kill(p.pid, sig.SIGKILL)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error forcefully terminating process: {e}")
|
||||||
|
|
||||||
|
logger.info("Shutdown complete. Exiting.")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i + 1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx + 1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        scores = list(shared_scores)
    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # empty all files under example_id
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # Read the recorded score; count unreadable or malformed results as 0.0
                        try:
                            with open(
                                os.path.join(example_path, "result.txt"), "r"
                            ) as f:
                                all_result.append(float(f.read()))
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers for graceful termination
    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal

    try:
        args = config()

        # save args to json in result_dir/action_space/observation_type/model/args.json
        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            args.model,
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)

        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}

        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")

        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
        # Signal handler will take care of cleanup
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        # Also trigger cleanup for unhandled exceptions
        signal_handler(signal.SIGTERM, None)
    finally:
        # Final cleanup in case any environments or processes remain
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info("Closing environment in final cleanup...")
                    env.close()
                    logger.info("Environment closed successfully in final cleanup")
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        # First try gentle termination
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Terminating process {p.name}...")
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        # Wait a moment for processes to terminate
        time.sleep(1)

        # Then force kill if needed
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Force killing process {p.name}...")
                    os.kill(p.pid, signal.SIGKILL)
                    logger.info(f"Process {p.name} force killed")
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")