Merge branch 'main' of github.com:xlang-ai/OSWorld
This commit is contained in:
@@ -346,6 +346,8 @@ def read_cell_value(xlsx_file: str, sheet_name: str, coordinate: str) -> Any:
|
|||||||
return cell["c"]["v"]
|
return cell["c"]["v"]
|
||||||
if cell["c"]["@t"] == "inlineStr":
|
if cell["c"]["@t"] == "inlineStr":
|
||||||
return cell["c"]["is"]["t"]
|
return cell["c"]["is"]["t"]
|
||||||
|
if cell["c"]["@t"] == "e":
|
||||||
|
return cell["c"]["v"]
|
||||||
except (KeyError, ValueError):
|
except (KeyError, ValueError):
|
||||||
return None
|
return None
|
||||||
# }}} read_cell_value #
|
# }}} read_cell_value #
|
||||||
@@ -409,6 +411,43 @@ def _read_cell_style(style_name: str, cell: Union[Cell, MergedCell], diff_style:
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError("Unsupported Style: {:}".format(style_name))
|
raise NotImplementedError("Unsupported Style: {:}".format(style_name))
|
||||||
|
|
||||||
|
def _process_xlsx_cf_operator(operator: str, value: Any, ref: List[Any]) -> bool:
|
||||||
|
# function _process_xlsx_cf_operator {{{ #
|
||||||
|
# "containsText", "lessThanOrEqual", "notBetween", "lessThan", "notContains", "beginsWith", "equal", "greaterThanOrEqual", "between", "endsWith", "notEqual", "greaterThan"
|
||||||
|
try:
|
||||||
|
if operator=="lessThanOrEqual":
|
||||||
|
result: bool = value<=ref[0]
|
||||||
|
elif operator=="lessThan":
|
||||||
|
result: bool = value<ref[0]
|
||||||
|
elif operator=="equal":
|
||||||
|
result: bool = value==ref[0]
|
||||||
|
elif operator=="greaterThanOrEqual":
|
||||||
|
result: bool = value>=ref[0]
|
||||||
|
elif operator=="notEqual":
|
||||||
|
result: bool = value!=ref[0]
|
||||||
|
elif operator=="greaterThan":
|
||||||
|
result: bool = value>ref[0]
|
||||||
|
elif operator=="between":
|
||||||
|
small_one: float
|
||||||
|
large_one: float
|
||||||
|
small_one, large_one = min(ref), max(ref)
|
||||||
|
result: bool = value>=small_one and value<=large_one
|
||||||
|
elif operator=="notBetween":
|
||||||
|
small_one: float
|
||||||
|
large_one: float
|
||||||
|
small_one, large_one = min(ref), max(ref)
|
||||||
|
result: bool = value<small_one or value>large_one
|
||||||
|
else:
|
||||||
|
#raise NotImplementedError("Not Implemented CondFormat Operator: {:}".format(operator))
|
||||||
|
logger.exception("Not Implemented CondFormat Operator: {:}".format(operator))
|
||||||
|
return result
|
||||||
|
except TypeError:
|
||||||
|
logger.exception("Unmatched type of %s and %s. Auto to False", repr(value), repr(ref))
|
||||||
|
return False
|
||||||
|
except IndexError:
|
||||||
|
logger.exception("ref array doesn't have enough elements. Auto to False: %s", repr(ref))
|
||||||
|
return False
|
||||||
|
# }}} function _process_xlsx_cf_operator #
|
||||||
|
|
||||||
_absolute_range_pattern: Pattern[str] = re.compile(r"""\$(?P<col1>[A-Z]{1,3})\$(?P<row1>\d+) # coord1
|
_absolute_range_pattern: Pattern[str] = re.compile(r"""\$(?P<col1>[A-Z]{1,3})\$(?P<row1>\d+) # coord1
|
||||||
(?::
|
(?::
|
||||||
@@ -459,16 +498,23 @@ def load_xlsx_styles(xlsx_file: Workbook, sheet_name: str, book_name: str, **opt
|
|||||||
for fmt in conditional_formattings:
|
for fmt in conditional_formattings:
|
||||||
for r in fmt.rules:
|
for r in fmt.rules:
|
||||||
active_cells: List[Cell] = []
|
active_cells: List[Cell] = []
|
||||||
if r.type == "expression":
|
|
||||||
|
# Process CF Formulae {{{ #
|
||||||
|
formulae: List[Callable[[Any], Any]] = []
|
||||||
|
argument_lists: List[List[Any]] = []
|
||||||
|
has_error = False
|
||||||
|
for fml in r.formula:
|
||||||
try:
|
try:
|
||||||
condition: Callable[[str], bool] = formula_parser.ast("=" + r.formula[0])[1].compile()
|
formula_func: Callable[[Any], Any] =\
|
||||||
|
formula_parser.ast("=" + fml)[1].compile()
|
||||||
|
logger.debug("CondFormat rule formula: %s", fml)
|
||||||
except:
|
except:
|
||||||
logger.exception("Formula parsing error: %s. Skipping.", repr(r.formula[0]))
|
logger.exception("Formula parsing error: %s. Skipping.", repr(fml))
|
||||||
continue
|
has_error = True
|
||||||
logger.debug("Expression condition: %s", r.formula[0])
|
break
|
||||||
|
|
||||||
arguments: List[Any] = []
|
arguments: List[Any] = []
|
||||||
absolute_range_match: List[Tuple[str, str, str, str]] = _absolute_range_pattern.findall(r.formula[0])
|
absolute_range_match: List[Tuple[str, str, str, str]] = _absolute_range_pattern.findall(fml)
|
||||||
for m in absolute_range_match:
|
for m in absolute_range_match:
|
||||||
logger.debug("Absolute ranges: %s", repr(m))
|
logger.debug("Absolute ranges: %s", repr(m))
|
||||||
if m[2] is None and m[3] is None:
|
if m[2] is None and m[3] is None:
|
||||||
@@ -484,31 +530,65 @@ def load_xlsx_styles(xlsx_file: Workbook, sheet_name: str, book_name: str, **opt
|
|||||||
)
|
)
|
||||||
logger.debug("Absolute range arguments: %s", repr(arguments))
|
logger.debug("Absolute range arguments: %s", repr(arguments))
|
||||||
|
|
||||||
nb_contiguous_nothings = 0
|
formulae.append(formula_func)
|
||||||
for rge in fmt.cells:
|
argument_lists.append(arguments)
|
||||||
for c in rge.cells:
|
|
||||||
cell: Cell = worksheet.cell(row=c[0], column=c[1])
|
if has_error:
|
||||||
cell_value = read_cell_value(book_name, sheet_name
|
continue
|
||||||
, coordinate="{:}{:d}".format(get_column_letter(c[1])
|
# }}} Process CF Formulae #
|
||||||
, c[0]
|
|
||||||
)
|
# Process Condition According to Type {{{ #
|
||||||
)
|
if r.type in { "expression"
|
||||||
if cell_value is None:
|
, "containsText", "notContainsText"
|
||||||
nb_contiguous_nothings += 1
|
, "endsWith", "beginsWith"
|
||||||
if nb_contiguous_nothings>50:
|
, "containsErrors", "notContainsErrors"
|
||||||
break
|
}:
|
||||||
continue
|
condition: Callable[[Any], bool] = formulae[0]
|
||||||
else:
|
arguments: List[Any] = argument_lists[0]
|
||||||
try:
|
is_active: Callable[[Any], bool] = lambda v: condition(v, *arguments)
|
||||||
satisfies_condition: bool = condition(cell_value, *arguments)
|
elif r.type == "cellIs":
|
||||||
except:
|
operator: str = r.operator
|
||||||
logger.exception("Error in formula calculation with cell value %d", repr(cell_value))
|
try:
|
||||||
satisfies_condition = False
|
references: List[Any] = [fml() for fml in formulae]
|
||||||
if satisfies_condition:
|
except:
|
||||||
logger.debug("Active Cell %s(%s) for %s", repr(cell), repr(cell_value), r.formula[0])
|
logger.exception("Error occurs while calculating reference values for cellIs condition formatting.")
|
||||||
active_cells.append(cell)
|
continue
|
||||||
|
is_active: Callable[[Any], bool] =\
|
||||||
|
lambda v: _process_xlsx_cf_operator(operator, v, references)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Not Implemented Condition Type: {:}".format(r.type))
|
#raise NotImplementedError("Not Implemented Condition Type: {:}".format(r.type))
|
||||||
|
# e.g., type=top10 (rank=number, percent=bool, bottom=bool)
|
||||||
|
# type=aboveAverage (equalAverage=bool, aboveAverage=bool)
|
||||||
|
# type=duplicateValues / type=uniqueValues
|
||||||
|
logger.exception("Not Implemented Condition Type: {:}".format(r.type))
|
||||||
|
# }}} Process Condition According to Type #
|
||||||
|
|
||||||
|
|
||||||
|
# Test Each Cell {{{ #
|
||||||
|
nb_contiguous_nothings = 0
|
||||||
|
for rge in fmt.cells:
|
||||||
|
for c in rge.cells:
|
||||||
|
cell: Cell = worksheet.cell(row=c[0], column=c[1])
|
||||||
|
cell_value = read_cell_value(book_name, sheet_name
|
||||||
|
, coordinate="{:}{:d}".format(get_column_letter(c[1])
|
||||||
|
, c[0]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if cell_value is None:
|
||||||
|
nb_contiguous_nothings += 1
|
||||||
|
if nb_contiguous_nothings>50:
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
satisfies_condition: bool = is_active(cell_value)
|
||||||
|
except:
|
||||||
|
logger.exception("Error in formula calculation with cell value %d", repr(cell_value))
|
||||||
|
satisfies_condition = False
|
||||||
|
if satisfies_condition:
|
||||||
|
logger.debug("Active Cell %s(%s) for %s", repr(cell), repr(cell_value), r.formula[0])
|
||||||
|
active_cells.append(cell)
|
||||||
|
# }}} Test Each Cell #
|
||||||
|
|
||||||
for c in active_cells:
|
for c in active_cells:
|
||||||
style_dict[c.coordinate] = [_read_cell_style(st, c, r.dxf) for st in concerned_styles]
|
style_dict[c.coordinate] = [_read_cell_style(st, c, r.dxf) for st in concerned_styles]
|
||||||
|
|||||||
@@ -177,9 +177,7 @@ def run_single_example_opencua(agent, env, example, max_steps, instruction, args
|
|||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
logger.info("Step %d: %s", step_idx + 1, action)
|
logger.info("Step %d: %s", step_idx + 1, action)
|
||||||
|
|
||||||
obs, reward, done, info = env.step(action)
|
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||||
time.sleep(3)
|
|
||||||
obs = env._get_obs()
|
|
||||||
|
|
||||||
logger.info(f"Action {action} executed, reward: {reward}, done: {done}")
|
logger.info(f"Action {action} executed, reward: {reward}, done: {done}")
|
||||||
# Save screenshot and trajectory information
|
# Save screenshot and trajectory information
|
||||||
|
|||||||
@@ -571,11 +571,6 @@ class OpenCUAAgent:
|
|||||||
logger.info(f"========================== {self.model} ===================================")
|
logger.info(f"========================== {self.model} ===================================")
|
||||||
logger.info(f"Instruction: \n{instruction}")
|
logger.info(f"Instruction: \n{instruction}")
|
||||||
|
|
||||||
image_bytes = BytesIO(obs['screenshot'])
|
|
||||||
with Image.open(image_bytes) as img:
|
|
||||||
print("Actual screen size", img.size)
|
|
||||||
print("Logical screen size", self.screen_size)
|
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "system",
|
"role": "system",
|
||||||
@@ -598,7 +593,7 @@ class OpenCUAAgent:
|
|||||||
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
|
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
|
||||||
observation=self.cots[i].get('observation'),
|
observation=self.cots[i].get('observation'),
|
||||||
thought=self.cots[i].get('thought'),
|
thought=self.cots[i].get('thought'),
|
||||||
action=self.cots[i]['action']
|
action=self.cots[i].get('action')
|
||||||
)
|
)
|
||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -609,7 +604,7 @@ class OpenCUAAgent:
|
|||||||
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
|
history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
|
||||||
observation=self.cots[i].get('observation'),
|
observation=self.cots[i].get('observation'),
|
||||||
thought=self.cots[i].get('thought'),
|
thought=self.cots[i].get('thought'),
|
||||||
action=self.cots[i]['action']
|
action=self.cots[i].get('action')
|
||||||
)
|
)
|
||||||
history_step_texts.append(history_content)
|
history_step_texts.append(history_content)
|
||||||
if i == len(self.actions) - self.max_image_history_length:
|
if i == len(self.actions) - self.max_image_history_length:
|
||||||
@@ -640,7 +635,7 @@ class OpenCUAAgent:
|
|||||||
"temperature": self.temperature
|
"temperature": self.temperature
|
||||||
}, self.model)
|
}, self.model)
|
||||||
|
|
||||||
logger.info(f"Model Output: \n\n{response}")
|
logger.info(f"Model Output: \n{response}")
|
||||||
if not response:
|
if not response:
|
||||||
logger.error("No response found in the response.")
|
logger.error("No response found in the response.")
|
||||||
return "ERROR", [], {}
|
return "ERROR", [], {}
|
||||||
@@ -666,23 +661,23 @@ class OpenCUAAgent:
|
|||||||
self.cots.append(other_cot)
|
self.cots.append(other_cot)
|
||||||
|
|
||||||
# Print message structure if needed
|
# Print message structure if needed
|
||||||
logger.info(f"\nInstruction: {instruction}")
|
# messages_to_print = []
|
||||||
messages_to_print = []
|
# current_image = 1
|
||||||
current_image = 1
|
# for msg in messages:
|
||||||
for msg in messages:
|
# msg_copy = copy.deepcopy(msg)
|
||||||
msg_copy = copy.deepcopy(msg)
|
# if isinstance(msg_copy['content'], list):
|
||||||
if isinstance(msg_copy['content'], list):
|
# for content in msg_copy['content']:
|
||||||
for content in msg_copy['content']:
|
# if content['type'] == 'image_url':
|
||||||
if content['type'] == 'image_url':
|
# content['image_url']['url'] = f'Image {current_image}'
|
||||||
content['image_url']['url'] = f'Image {current_image}'
|
# current_image += 1
|
||||||
current_image += 1
|
# messages_to_print.append(msg_copy)
|
||||||
messages_to_print.append(msg_copy)
|
|
||||||
|
|
||||||
messages_to_print.append({
|
# messages_to_print.append({
|
||||||
"new_step_cot": other_cot,
|
# "new_step_cot": other_cot,
|
||||||
"response": response
|
# "response": response
|
||||||
})
|
# })
|
||||||
logger.info(json.dumps(messages_to_print, indent=2))
|
# logger.info(json.dumps(messages_to_print, indent=2))
|
||||||
|
logger.info(f"New step cot: {other_cot}")
|
||||||
|
|
||||||
return response, pyautogui_actions, {}
|
return response, pyautogui_actions, {}
|
||||||
|
|
||||||
@@ -720,4 +715,10 @@ class OpenCUAAgent:
|
|||||||
logger.error("Retrying...")
|
logger.error("Retrying...")
|
||||||
time.sleep(5)
|
time.sleep(5)
|
||||||
else:
|
else:
|
||||||
return response.json()['choices'][0]['message']['content']
|
response = response.json()
|
||||||
|
finish_reason = response["choices"][0].get("finish_reason")
|
||||||
|
if finish_reason is not None and finish_reason == "stop": # for most of the time, length will not exceed max_tokens
|
||||||
|
return response['choices'][0]['message']['content']
|
||||||
|
else:
|
||||||
|
logger.error("LLM did not finish properly, retrying...")
|
||||||
|
time.sleep(5)
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import os
|
|||||||
import re
|
import re
|
||||||
import base64
|
import base64
|
||||||
import requests
|
import requests
|
||||||
from typing import Optional, Dict, List, Tuple
|
import logging
|
||||||
|
from typing import Optional, Dict, List, Tuple, Union
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
@@ -573,7 +574,34 @@ finished(content='xxx') # Use escape characters \\', \\", and \\n in content par
|
|||||||
GROUNDING_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n\nAction: ...\n\n\n## Action Space\nclick(point='<point>x1 y1</point>'')\n\n## User Instruction
|
GROUNDING_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n\nAction: ...\n\n\n## Action Space\nclick(point='<point>x1 y1</point>'')\n\n## User Instruction
|
||||||
{instruction}"""
|
{instruction}"""
|
||||||
|
|
||||||
|
COMPUTER_USE_NO_THINKING = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
```
|
||||||
|
Thought: ...
|
||||||
|
Action: ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Action Space
|
||||||
|
|
||||||
|
click(point='<point>x1 y1</point>')
|
||||||
|
left_double(point='<point>x1 y1</point>')
|
||||||
|
right_single(point='<point>x1 y1</point>')
|
||||||
|
drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
|
||||||
|
hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
|
||||||
|
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||||
|
scroll(point='<point>x1 y1</point>', direction='down or up or right or left') # Show more information on the `direction` side.
|
||||||
|
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||||
|
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||||
|
|
||||||
|
|
||||||
|
## Note
|
||||||
|
- Use Chinese in `Thought` part.
|
||||||
|
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||||
|
|
||||||
|
## User Instruction
|
||||||
|
{instruction}
|
||||||
|
"""
|
||||||
|
|
||||||
class UITarsAgent:
|
class UITarsAgent:
|
||||||
"""
|
"""
|
||||||
@@ -638,9 +666,11 @@ class UITarsAgent:
|
|||||||
self.history_images = []
|
self.history_images = []
|
||||||
self.history_responses = []
|
self.history_responses = []
|
||||||
|
|
||||||
self.system_prompt = COMPUTER_USE_DOUBAO
|
if use_thinking:
|
||||||
|
self.system_prompt = COMPUTER_USE_DOUBAO
|
||||||
|
else:
|
||||||
|
self.system_prompt = COMPUTER_USE_NO_THINKING
|
||||||
|
|
||||||
|
|
||||||
self.action_parse_res_factor = 1000
|
self.action_parse_res_factor = 1000
|
||||||
self.model_type = "doubao"
|
self.model_type = "doubao"
|
||||||
self.history_n = 5
|
self.history_n = 5
|
||||||
@@ -648,6 +678,9 @@ class UITarsAgent:
|
|||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.platform = "ubuntu"
|
self.platform = "ubuntu"
|
||||||
|
self.use_thinking = use_thinking
|
||||||
|
|
||||||
|
self.inference_func = self.inference_with_thinking if use_thinking else self.inference_without_thinking
|
||||||
|
|
||||||
def reset(self, _logger=None):
|
def reset(self, _logger=None):
|
||||||
global logger
|
global logger
|
||||||
@@ -721,7 +754,36 @@ class UITarsAgent:
|
|||||||
"details": response.text
|
"details": response.text
|
||||||
}
|
}
|
||||||
|
|
||||||
def predict(self, task_instruction: str, obs: dict) -> Tuple[str, List]:
|
def inference_without_thinking(self, messages):
    """Request a completion for `messages` with the model's thinking disabled.

    The endpoint and the bearer token are read from the DOUBAO_API_URL and
    DOUBAO_API_KEY environment variables; a KeyError is raised if either is
    unset.

    Args:
        messages: Chat messages, placed unchanged in the request payload.

    Returns:
        On HTTP 200, the content string of the first choice in the response.
        On any other status, a dict with "error" and "details" keys describing
        the failed request.
    """
    api_key = os.environ['DOUBAO_API_KEY']
    api_url = os.environ['DOUBAO_API_URL']
    headers = {
        'Authorization': f'Bearer {api_key}',
        'Content-Type': 'application/json'
    }
    data = {
        "model": self.model,
        "messages": messages,
        "thinking": {"type": "disabled"},
        "max_tokens": self.max_tokens,
        "top_p": self.top_p,
        "temperature": self.temperature,
    }

    response = requests.post(api_url, headers=headers, json=data)

    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]

    print(f"Request failed with status code {response.status_code}")
    # Print the raw body: an error response is not guaranteed to be JSON, and
    # response.json() would raise before the error dict below is returned.
    print(response.text)
    return {
        "error": f"Request failed with status code {response.status_code}",
        "details": response.text
    }
|
||||||
|
|
||||||
|
def predict(self, task_instruction: str, obs: dict) -> Tuple[Union[str, Dict, None], List]:
|
||||||
"""Predict the next action based on the current observation."""
|
"""Predict the next action based on the current observation."""
|
||||||
|
|
||||||
self.task_instruction = task_instruction
|
self.task_instruction = task_instruction
|
||||||
@@ -793,7 +855,7 @@ class UITarsAgent:
|
|||||||
return prediction, ["FAIL"]
|
return prediction, ["FAIL"]
|
||||||
try:
|
try:
|
||||||
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
|
logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
|
||||||
prediction = self.inference_with_thinking(messages)
|
prediction = self.inference_func(messages)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error when fetching response from client, with error:\n{e}")
|
self.logger.error(f"Error when fetching response from client, with error:\n{e}")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import List, Dict
|
|||||||
import math
|
import math
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from multiprocessing import Process, Manager
|
from multiprocessing import Process, Manager
|
||||||
|
from multiprocessing import current_process
|
||||||
import lib_run_single
|
import lib_run_single
|
||||||
from desktop_env.desktop_env import DesktopEnv
|
from desktop_env.desktop_env import DesktopEnv
|
||||||
from mm_agents.opencua_agent import OpenCUAAgent
|
from mm_agents.opencua_agent import OpenCUAAgent
|
||||||
@@ -45,7 +46,7 @@ def config() -> argparse.Namespace:
|
|||||||
default="screenshot",
|
default="screenshot",
|
||||||
help="Observation type",
|
help="Observation type",
|
||||||
)
|
)
|
||||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
|
||||||
parser.add_argument("--max_steps", type=int, default=15)
|
parser.add_argument("--max_steps", type=int, default=15)
|
||||||
|
|
||||||
# evaluation config
|
# evaluation config
|
||||||
@@ -57,7 +58,7 @@ def config() -> argparse.Namespace:
|
|||||||
parser.add_argument("--model", type=str, default="opencua")
|
parser.add_argument("--model", type=str, default="opencua")
|
||||||
parser.add_argument("--temperature", type=float, default=0)
|
parser.add_argument("--temperature", type=float, default=0)
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
parser.add_argument("--max_tokens", type=int, default=8196)
|
||||||
parser.add_argument("--stop_token", type=str, default=None)
|
parser.add_argument("--stop_token", type=str, default=None)
|
||||||
|
|
||||||
# OpenCUAagent config
|
# OpenCUAagent config
|
||||||
@@ -133,32 +134,12 @@ logger.addHandler(stdout_handler)
|
|||||||
logger = logging.getLogger("desktopenv.experiment")
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
|
||||||
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||||
"""Distribute tasks evenly across environments."""
|
|
||||||
# Flatten the tasks into a single list
|
|
||||||
all_tasks = []
|
all_tasks = []
|
||||||
for domain, examples in test_all_meta.items():
|
for domain, examples in test_all_meta.items():
|
||||||
for example_id in examples:
|
for example_id in examples:
|
||||||
all_tasks.append((domain, example_id))
|
all_tasks.append((domain, example_id))
|
||||||
|
return all_tasks
|
||||||
# Calculate tasks per environment
|
|
||||||
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
|
|
||||||
|
|
||||||
# Distribute tasks
|
|
||||||
distributed_tasks = []
|
|
||||||
for i in range(num_envs):
|
|
||||||
env_tasks = {}
|
|
||||||
start_idx = i * tasks_per_env
|
|
||||||
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
|
|
||||||
|
|
||||||
for domain, example_id in all_tasks[start_idx:end_idx]:
|
|
||||||
if domain not in env_tasks:
|
|
||||||
env_tasks[domain] = []
|
|
||||||
env_tasks[domain].append(example_id)
|
|
||||||
|
|
||||||
distributed_tasks.append(env_tasks)
|
|
||||||
|
|
||||||
return distributed_tasks
|
|
||||||
|
|
||||||
|
|
||||||
def process_signal_handler(signum, frame, env_idx):
|
def process_signal_handler(signum, frame, env_idx):
|
||||||
@@ -182,51 +163,45 @@ def process_signal_handler(signum, frame, env_idx):
|
|||||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||||
def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
|
|
||||||
"""Run tasks for a single environment."""
|
|
||||||
# Each process has its own list of active environments
|
|
||||||
active_environments = []
|
active_environments = []
|
||||||
env = None
|
env = None
|
||||||
|
|
||||||
# Setup signal handlers for this process too
|
|
||||||
signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
|
|
||||||
signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
|
|
||||||
|
|
||||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
|
||||||
REGION = args.region
|
|
||||||
screen_size = (args.screen_width, args.screen_height)
|
|
||||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=args.path_to_vm,
|
|
||||||
action_space=args.action_space,
|
|
||||||
provider_name=args.provider_name,
|
|
||||||
region=REGION,
|
|
||||||
snapshot_name=ami_id,
|
|
||||||
screen_size=screen_size,
|
|
||||||
headless=args.headless,
|
|
||||||
os_type="Ubuntu",
|
|
||||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
|
||||||
enable_proxy=True,
|
|
||||||
client_password=args.client_password
|
|
||||||
)
|
|
||||||
active_environments.append(env)
|
|
||||||
|
|
||||||
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
|
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||||
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
|
REGION = args.region
|
||||||
|
screen_size = (args.screen_width, args.screen_height)
|
||||||
|
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||||
|
env = DesktopEnv(
|
||||||
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=args.action_space,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
region=REGION,
|
||||||
|
snapshot_name=ami_id,
|
||||||
|
screen_size=screen_size,
|
||||||
|
headless=args.headless,
|
||||||
|
os_type="Ubuntu",
|
||||||
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
|
enable_proxy=True,
|
||||||
|
client_password=args.client_password
|
||||||
|
)
|
||||||
|
active_environments.append(env)
|
||||||
|
|
||||||
|
logger.info(f"Process {current_process().name} started.")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
item = task_queue.get(timeout=5)
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
domain, example_id = item
|
||||||
|
try:
|
||||||
config_file = os.path.join(
|
config_file = os.path.join(
|
||||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||||
)
|
)
|
||||||
with open(config_file, "r", encoding="utf-8") as f:
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
example = json.load(f)
|
example = json.load(f)
|
||||||
|
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||||
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
|
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||||
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
|
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||||
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
|
|
||||||
|
|
||||||
example_result_dir = os.path.join(
|
example_result_dir = os.path.join(
|
||||||
args.result_dir,
|
args.result_dir,
|
||||||
args.action_space,
|
args.action_space,
|
||||||
@@ -236,7 +211,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
example_id,
|
example_id,
|
||||||
)
|
)
|
||||||
os.makedirs(example_result_dir, exist_ok=True)
|
os.makedirs(example_result_dir, exist_ok=True)
|
||||||
|
|
||||||
agent = OpenCUAAgent(
|
agent = OpenCUAAgent(
|
||||||
env=env,
|
env=env,
|
||||||
model=args.model,
|
model=args.model,
|
||||||
@@ -251,7 +225,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
coordinate_type=args.coordinate_type,
|
coordinate_type=args.coordinate_type,
|
||||||
max_image_history_length=args.max_image_history_length,
|
max_image_history_length=args.max_image_history_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lib_run_single.run_single_example_opencua(
|
lib_run_single.run_single_example_opencua(
|
||||||
agent,
|
agent,
|
||||||
@@ -265,7 +238,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
|
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
try:
|
try:
|
||||||
env.controller.end_recording(
|
env.controller.end_recording(
|
||||||
@@ -273,7 +246,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
)
|
)
|
||||||
except Exception as rec_e:
|
except Exception as rec_e:
|
||||||
logger.error(f"Failed to end recording: {rec_e}")
|
logger.error(f"Failed to end recording: {rec_e}")
|
||||||
|
|
||||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
f.write(
|
f.write(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
@@ -281,15 +253,23 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
finally:
|
finally:
|
||||||
# This ensures the environment is closed even if there's an exception
|
logger.info(f"{current_process().name} cleaning up environment...")
|
||||||
logger.info(f"Process {env_idx + 1} cleaning up environment...")
|
|
||||||
try:
|
try:
|
||||||
env.close()
|
if env:
|
||||||
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
env.close()
|
||||||
|
logger.info(f"{current_process().name} environment closed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
|
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(signum, frame):
|
def signal_handler(signum, frame):
|
||||||
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
|
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
|
||||||
@@ -328,8 +308,8 @@ def signal_handler(signum, frame):
|
|||||||
if p.is_alive():
|
if p.is_alive():
|
||||||
try:
|
try:
|
||||||
logger.info(f"Forcefully terminating process {p.name}...")
|
logger.info(f"Forcefully terminating process {p.name}...")
|
||||||
import signal
|
import signal as sig
|
||||||
os.kill(p.pid, signal.SIGKILL)
|
os.kill(p.pid, sig.SIGKILL)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error forcefully terminating process: {e}")
|
logger.error(f"Error forcefully terminating process: {e}")
|
||||||
|
|
||||||
@@ -340,38 +320,56 @@ def signal_handler(signum, frame):
|
|||||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||||
global processes
|
global processes
|
||||||
logger.info("Args: %s", args)
|
logger.info("Args: %s", args)
|
||||||
|
all_tasks = distribute_tasks(test_all_meta)
|
||||||
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
|
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||||
|
|
||||||
logger.info("All environments are ready. Starting parallel task execution...")
|
|
||||||
|
|
||||||
# Create a shared list for scores across processes
|
|
||||||
with Manager() as manager:
|
with Manager() as manager:
|
||||||
shared_scores = manager.list()
|
shared_scores = manager.list()
|
||||||
|
task_queue = manager.Queue()
|
||||||
# Create and start processes for each environment
|
for item in all_tasks:
|
||||||
|
task_queue.put(item)
|
||||||
|
num_envs = args.num_envs
|
||||||
processes = []
|
processes = []
|
||||||
for env_idx, env_tasks in enumerate(distributed_tasks):
|
for i in range(num_envs):
|
||||||
p = Process(
|
p = Process(
|
||||||
target=run_env_tasks,
|
target=run_env_tasks,
|
||||||
args=(env_idx, env_tasks, args, shared_scores)
|
args=(task_queue, args, shared_scores),
|
||||||
|
name=f"EnvProcess-{i+1}"
|
||||||
)
|
)
|
||||||
processes.append(p)
|
p.daemon = True
|
||||||
p.start()
|
p.start()
|
||||||
|
processes.append(p)
|
||||||
logger.info(f"Started process {p.name} with PID {p.pid}")
|
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Wait for all processes to complete
|
while True:
|
||||||
|
alive_count = 0
|
||||||
|
for idx, p in enumerate(processes):
|
||||||
|
if not p.is_alive():
|
||||||
|
logger.warning(f"Process {p.name} died, restarting...")
|
||||||
|
new_p = Process(
|
||||||
|
target=run_env_tasks,
|
||||||
|
args=(task_queue, args, shared_scores),
|
||||||
|
name=f"EnvProcess-Restart-{idx+1}"
|
||||||
|
)
|
||||||
|
new_p.daemon = True
|
||||||
|
new_p.start()
|
||||||
|
processes[idx] = new_p
|
||||||
|
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||||
|
else:
|
||||||
|
alive_count += 1
|
||||||
|
if task_queue.empty():
|
||||||
|
logger.info("All tasks finished.")
|
||||||
|
break
|
||||||
|
if alive_count == 0:
|
||||||
|
logger.error("All processes died, exiting.")
|
||||||
|
break
|
||||||
|
time.sleep(5)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
logger.info(f"Process {p.name} completed")
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||||
# Let the signal handler do the cleanup
|
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||||
# Ensure cleanup happens
|
|
||||||
for p in processes:
|
for p in processes:
|
||||||
if p.is_alive():
|
if p.is_alive():
|
||||||
try:
|
try:
|
||||||
@@ -380,10 +378,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
except Exception as term_e:
|
except Exception as term_e:
|
||||||
logger.error(f"Error terminating process {p.name}: {term_e}")
|
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Convert shared list to regular list
|
|
||||||
scores = list(shared_scores)
|
scores = list(shared_scores)
|
||||||
|
|
||||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||||
|
|
||||||
|
|
||||||
@@ -469,6 +464,18 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
args = config()
|
args = config()
|
||||||
|
|
||||||
|
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||||
|
path_to_args = os.path.join(
|
||||||
|
args.result_dir,
|
||||||
|
args.action_space,
|
||||||
|
args.observation_type,
|
||||||
|
args.model,
|
||||||
|
"args.json",
|
||||||
|
)
|
||||||
|
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||||
|
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(vars(args), f, indent=4)
|
||||||
|
|
||||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||||
test_all_meta = json.load(f)
|
test_all_meta = json.load(f)
|
||||||
|
|||||||
@@ -11,10 +11,29 @@ from typing import List, Dict
|
|||||||
import math
|
import math
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from multiprocessing import Process, Manager
|
from multiprocessing import Process, Manager
|
||||||
|
from multiprocessing import current_process
|
||||||
import lib_run_single
|
import lib_run_single
|
||||||
from desktop_env.desktop_env import DesktopEnv
|
from desktop_env.desktop_env import DesktopEnv
|
||||||
from mm_agents.uitars15_agent import UITarsAgent
|
from mm_agents.uitars15_agent import UITarsAgent
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
import os
|
||||||
|
|
||||||
|
# def clear_cache():
|
||||||
|
# cache_path = "cache"
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# if os.path.exists(cache_path):
|
||||||
|
# logger.info(f"Deleting cache directory: {cache_path}")
|
||||||
|
# shutil.rmtree(cache_path)
|
||||||
|
# logger.info(f"Cache directory deleted successfully")
|
||||||
|
# else:
|
||||||
|
# logger.info(f"Cache directory {cache_path} does not exist")
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"Error deleting cache directory: {e}")
|
||||||
|
|
||||||
|
# clear_cache()
|
||||||
|
|
||||||
# Global variables for signal handling
|
# Global variables for signal handling
|
||||||
active_environments = []
|
active_environments = []
|
||||||
processes = []
|
processes = []
|
||||||
@@ -45,7 +64,7 @@ def config() -> argparse.Namespace:
|
|||||||
default="screenshot",
|
default="screenshot",
|
||||||
help="Observation type",
|
help="Observation type",
|
||||||
)
|
)
|
||||||
parser.add_argument("--sleep_after_execution", type=float, default=0)
|
parser.add_argument("--sleep_after_execution", type=float, default=3.0)
|
||||||
parser.add_argument("--max_steps", type=int, default=15)
|
parser.add_argument("--max_steps", type=int, default=15)
|
||||||
|
|
||||||
# evaluation config
|
# evaluation config
|
||||||
@@ -58,6 +77,7 @@ def config() -> argparse.Namespace:
|
|||||||
parser.add_argument("--temperature", type=float, default=0)
|
parser.add_argument("--temperature", type=float, default=0)
|
||||||
parser.add_argument("--top_p", type=float, default=None)
|
parser.add_argument("--top_p", type=float, default=None)
|
||||||
parser.add_argument("--max_tokens", type=int, default=3000)
|
parser.add_argument("--max_tokens", type=int, default=3000)
|
||||||
|
parser.add_argument("--use_thinking", action="store_true", default=False)
|
||||||
|
|
||||||
# OpenCUAagent config
|
# OpenCUAagent config
|
||||||
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
|
parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
|
||||||
@@ -131,32 +151,12 @@ logger.addHandler(stdout_handler)
|
|||||||
logger = logging.getLogger("desktopenv.experiment")
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
|
||||||
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
|
||||||
"""Distribute tasks evenly across environments."""
|
|
||||||
# Flatten the tasks into a single list
|
|
||||||
all_tasks = []
|
all_tasks = []
|
||||||
for domain, examples in test_all_meta.items():
|
for domain, examples in test_all_meta.items():
|
||||||
for example_id in examples:
|
for example_id in examples:
|
||||||
all_tasks.append((domain, example_id))
|
all_tasks.append((domain, example_id))
|
||||||
|
return all_tasks
|
||||||
# Calculate tasks per environment
|
|
||||||
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
|
|
||||||
|
|
||||||
# Distribute tasks
|
|
||||||
distributed_tasks = []
|
|
||||||
for i in range(num_envs):
|
|
||||||
env_tasks = {}
|
|
||||||
start_idx = i * tasks_per_env
|
|
||||||
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
|
|
||||||
|
|
||||||
for domain, example_id in all_tasks[start_idx:end_idx]:
|
|
||||||
if domain not in env_tasks:
|
|
||||||
env_tasks[domain] = []
|
|
||||||
env_tasks[domain].append(example_id)
|
|
||||||
|
|
||||||
distributed_tasks.append(env_tasks)
|
|
||||||
|
|
||||||
return distributed_tasks
|
|
||||||
|
|
||||||
|
|
||||||
def process_signal_handler(signum, frame, env_idx):
|
def process_signal_handler(signum, frame, env_idx):
|
||||||
@@ -180,62 +180,55 @@ def process_signal_handler(signum, frame, env_idx):
|
|||||||
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||||
def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
|
|
||||||
"""Run tasks for a single environment."""
|
|
||||||
# Each process has its own list of active environments
|
|
||||||
active_environments = []
|
active_environments = []
|
||||||
env = None
|
env = None
|
||||||
|
|
||||||
# Setup signal handlers for this process too
|
|
||||||
signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
|
|
||||||
signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
|
|
||||||
|
|
||||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
|
||||||
REGION = args.region
|
|
||||||
screen_size = (args.screen_width, args.screen_height)
|
|
||||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
|
||||||
env = DesktopEnv(
|
|
||||||
path_to_vm=args.path_to_vm,
|
|
||||||
action_space=args.action_space,
|
|
||||||
provider_name=args.provider_name,
|
|
||||||
region=REGION,
|
|
||||||
snapshot_name=ami_id,
|
|
||||||
screen_size=screen_size,
|
|
||||||
headless=args.headless,
|
|
||||||
os_type="Ubuntu",
|
|
||||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
|
||||||
enable_proxy=True,
|
|
||||||
client_password=args.client_password
|
|
||||||
)
|
|
||||||
active_environments.append(env)
|
|
||||||
agent = UITarsAgent(
|
|
||||||
model=args.model,
|
|
||||||
max_tokens=args.max_tokens,
|
|
||||||
top_p=args.top_p,
|
|
||||||
temperature=args.temperature,
|
|
||||||
|
|
||||||
max_trajectory_length=args.max_trajectory_length,
|
|
||||||
max_image_history_length=args.max_image_history_length,
|
|
||||||
use_thinking=True,
|
|
||||||
language=args.language,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
|
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||||
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
|
REGION = args.region
|
||||||
|
screen_size = (args.screen_width, args.screen_height)
|
||||||
|
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||||
|
env = DesktopEnv(
|
||||||
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=args.action_space,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
region=REGION,
|
||||||
|
snapshot_name=ami_id,
|
||||||
|
screen_size=screen_size,
|
||||||
|
headless=args.headless,
|
||||||
|
os_type="Ubuntu",
|
||||||
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
|
enable_proxy=True,
|
||||||
|
client_password=args.client_password
|
||||||
|
)
|
||||||
|
active_environments.append(env)
|
||||||
|
agent = UITarsAgent(
|
||||||
|
model=args.model,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
top_p=args.top_p,
|
||||||
|
temperature=args.temperature,
|
||||||
|
|
||||||
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
|
max_image_history_length=args.max_image_history_length,
|
||||||
|
use_thinking=args.use_thinking,
|
||||||
|
language=args.language,
|
||||||
|
)
|
||||||
|
logger.info(f"Process {current_process().name} started.")
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
item = task_queue.get(timeout=5)
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
domain, example_id = item
|
||||||
|
try:
|
||||||
config_file = os.path.join(
|
config_file = os.path.join(
|
||||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||||
)
|
)
|
||||||
with open(config_file, "r", encoding="utf-8") as f:
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
example = json.load(f)
|
example = json.load(f)
|
||||||
|
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||||
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
|
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||||
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
|
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||||
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
|
|
||||||
|
|
||||||
example_result_dir = os.path.join(
|
example_result_dir = os.path.join(
|
||||||
args.result_dir,
|
args.result_dir,
|
||||||
args.action_space,
|
args.action_space,
|
||||||
@@ -258,7 +251,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
|
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
try:
|
try:
|
||||||
env.controller.end_recording(
|
env.controller.end_recording(
|
||||||
@@ -266,7 +259,6 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
)
|
)
|
||||||
except Exception as rec_e:
|
except Exception as rec_e:
|
||||||
logger.error(f"Failed to end recording: {rec_e}")
|
logger.error(f"Failed to end recording: {rec_e}")
|
||||||
|
|
||||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
f.write(
|
f.write(
|
||||||
json.dumps(
|
json.dumps(
|
||||||
@@ -274,14 +266,23 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
finally:
|
finally:
|
||||||
# This ensures the environment is closed even if there's an exception
|
logger.info(f"{current_process().name} cleaning up environment...")
|
||||||
logger.info(f"Process {env_idx + 1} cleaning up environment...")
|
|
||||||
try:
|
try:
|
||||||
env.close()
|
if env:
|
||||||
logger.info(f"Process {env_idx + 1} environment closed successfully")
|
env.close()
|
||||||
|
logger.info(f"{current_process().name} environment closed successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
|
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def signal_handler(signum, frame):
|
def signal_handler(signum, frame):
|
||||||
@@ -321,8 +322,8 @@ def signal_handler(signum, frame):
|
|||||||
if p.is_alive():
|
if p.is_alive():
|
||||||
try:
|
try:
|
||||||
logger.info(f"Forcefully terminating process {p.name}...")
|
logger.info(f"Forcefully terminating process {p.name}...")
|
||||||
import signal
|
import signal as sig
|
||||||
os.kill(p.pid, signal.SIGKILL)
|
os.kill(p.pid, sig.SIGKILL)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error forcefully terminating process: {e}")
|
logger.error(f"Error forcefully terminating process: {e}")
|
||||||
|
|
||||||
@@ -333,38 +334,56 @@ def signal_handler(signum, frame):
|
|||||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||||
global processes
|
global processes
|
||||||
logger.info("Args: %s", args)
|
logger.info("Args: %s", args)
|
||||||
|
all_tasks = distribute_tasks(test_all_meta)
|
||||||
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
|
logger.info(f"Total tasks: {len(all_tasks)}")
|
||||||
|
|
||||||
logger.info("All environments are ready. Starting parallel task execution...")
|
|
||||||
|
|
||||||
# Create a shared list for scores across processes
|
|
||||||
with Manager() as manager:
|
with Manager() as manager:
|
||||||
shared_scores = manager.list()
|
shared_scores = manager.list()
|
||||||
|
task_queue = manager.Queue()
|
||||||
# Create and start processes for each environment
|
for item in all_tasks:
|
||||||
|
task_queue.put(item)
|
||||||
|
num_envs = args.num_envs
|
||||||
processes = []
|
processes = []
|
||||||
for env_idx, env_tasks in enumerate(distributed_tasks):
|
for i in range(num_envs):
|
||||||
p = Process(
|
p = Process(
|
||||||
target=run_env_tasks,
|
target=run_env_tasks,
|
||||||
args=(env_idx, env_tasks, args, shared_scores)
|
args=(task_queue, args, shared_scores),
|
||||||
|
name=f"EnvProcess-{i+1}"
|
||||||
)
|
)
|
||||||
processes.append(p)
|
p.daemon = True
|
||||||
p.start()
|
p.start()
|
||||||
|
processes.append(p)
|
||||||
logger.info(f"Started process {p.name} with PID {p.pid}")
|
logger.info(f"Started process {p.name} with PID {p.pid}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Wait for all processes to complete
|
while True:
|
||||||
|
alive_count = 0
|
||||||
|
for idx, p in enumerate(processes):
|
||||||
|
if not p.is_alive():
|
||||||
|
logger.warning(f"Process {p.name} died, restarting...")
|
||||||
|
new_p = Process(
|
||||||
|
target=run_env_tasks,
|
||||||
|
args=(task_queue, args, shared_scores),
|
||||||
|
name=f"EnvProcess-Restart-{idx+1}"
|
||||||
|
)
|
||||||
|
new_p.daemon = True
|
||||||
|
new_p.start()
|
||||||
|
processes[idx] = new_p
|
||||||
|
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
|
||||||
|
else:
|
||||||
|
alive_count += 1
|
||||||
|
if task_queue.empty():
|
||||||
|
logger.info("All tasks finished.")
|
||||||
|
break
|
||||||
|
if alive_count == 0:
|
||||||
|
logger.error("All processes died, exiting.")
|
||||||
|
break
|
||||||
|
time.sleep(5)
|
||||||
for p in processes:
|
for p in processes:
|
||||||
p.join()
|
p.join()
|
||||||
logger.info(f"Process {p.name} completed")
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
|
||||||
# Let the signal handler do the cleanup
|
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
|
||||||
# Ensure cleanup happens
|
|
||||||
for p in processes:
|
for p in processes:
|
||||||
if p.is_alive():
|
if p.is_alive():
|
||||||
try:
|
try:
|
||||||
@@ -373,10 +392,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
except Exception as term_e:
|
except Exception as term_e:
|
||||||
logger.error(f"Error terminating process {p.name}: {term_e}")
|
logger.error(f"Error terminating process {p.name}: {term_e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# Convert shared list to regular list
|
|
||||||
scores = list(shared_scores)
|
scores = list(shared_scores)
|
||||||
|
|
||||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||||
|
|
||||||
|
|
||||||
@@ -462,6 +478,18 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
args = config()
|
args = config()
|
||||||
|
|
||||||
|
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||||
|
path_to_args = os.path.join(
|
||||||
|
args.result_dir,
|
||||||
|
args.action_space,
|
||||||
|
args.observation_type,
|
||||||
|
args.model,
|
||||||
|
"args.json",
|
||||||
|
)
|
||||||
|
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||||
|
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(vars(args), f, indent=4)
|
||||||
|
|
||||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||||
test_all_meta = json.load(f)
|
test_all_meta = json.load(f)
|
||||||
|
|||||||
Reference in New Issue
Block a user