init public release (#350)

This commit is contained in:
Yan98
2025-10-07 01:16:31 +11:00
committed by GitHub
parent 5eff00a9e3
commit ddb8372a6c
5 changed files with 1106 additions and 26 deletions

View File

@@ -14,7 +14,8 @@ from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.gta1_agent import GTA1Agent
from mm_agents.gta1.gta1_agent import GTA1Agent
from mm_agents.gta1.gta15_agent import run_cua_gpt5gta1
# Global variables for signal handling
active_environments = []
@@ -58,6 +59,8 @@ def config() -> argparse.Namespace:
# lm config
parser.add_argument("--model", type=str, default="o3")
parser.add_argument("--tts_step", type=int, default=8)
parser.add_argument("--purge_history_images", type=int, default=8)
# example config
parser.add_argument("--domain", type=str, default="all")
@@ -156,7 +159,7 @@ def process_signal_handler(signum, frame, env_idx):
sys.exit(0)
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
def run_env_tasks_o3(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
@@ -200,9 +203,6 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
@@ -228,7 +228,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
with open(os.path.join(example_result_dir, "traj.jsonl"), "w") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
@@ -253,6 +253,106 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def run_env_tasks_gpt5(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
active_environments = []
env = None
try:
if args.provider_name == "aws":
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
REGION = args.region
screen_size = (args.screen_width, args.screen_height)
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
else:
REGION = None
ami_id = None
screen_size = (args.screen_width, args.screen_height)
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
logger.info(f"Process {current_process().name} started.")
while True:
try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
env.reset(task_config=example)
time.sleep(15)
obs = env._get_obs()
_, traj = run_cua_gpt5gta1(
env=env,
instruction=example["instruction"],
max_steps=args.max_steps,
save_path=example_result_dir,
sleep_after_execution=args.sleep_after_execution,
screen_width=args.screen_width,
screen_height=args.screen_height,
client_password=args.client_password,
tts_step=args.tts_step,
purge_history_images=args.purge_history_images,
cua_model=args.model,
logger=logger,
)
time.sleep(15)
result = env.evaluate()
shared_scores.append(result)
with open(os.path.join(example_result_dir, "traj.jsonl"), "w") as f:
json.dump(traj, f)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
@@ -313,7 +413,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
processes = []
for i in range(num_envs):
p = Process(
target=run_env_tasks,
target=run_env_tasks_o3 if args.model == "o3" else run_env_tasks_gpt5,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
)
@@ -328,7 +428,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
target=run_env_tasks_o3 if args.model == "o3" else run_env_tasks_gpt5,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
@@ -367,7 +467,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
target_dir = result_dir
if not os.path.exists(target_dir):
return total_file_json
@@ -402,7 +502,7 @@ def get_unfinished(
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
target_dir = result_dir
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
@@ -446,18 +546,6 @@ if __name__ == "__main__":
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)