init public release (#350)
This commit is contained in:
@@ -14,7 +14,8 @@ from multiprocessing import Process, Manager
|
||||
from multiprocessing import current_process
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.gta1_agent import GTA1Agent
|
||||
from mm_agents.gta1.gta1_agent import GTA1Agent
|
||||
from mm_agents.gta1.gta15_agent import run_cua_gpt5gta1
|
||||
|
||||
# Global variables for signal handling
|
||||
active_environments = []
|
||||
@@ -58,6 +59,8 @@ def config() -> argparse.Namespace:
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="o3")
|
||||
parser.add_argument("--tts_step", type=int, default=8)
|
||||
parser.add_argument("--purge_history_images", type=int, default=8)
|
||||
|
||||
# example config
|
||||
parser.add_argument("--domain", type=str, default="all")
|
||||
@@ -156,7 +159,7 @@ def process_signal_handler(signum, frame, env_idx):
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||
def run_env_tasks_o3(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
@@ -200,9 +203,6 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
@@ -228,7 +228,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
)
|
||||
except Exception as rec_e:
|
||||
logger.error(f"Failed to end recording: {rec_e}")
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "w") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{"Error": f"{domain}/{example_id} - {e}"}
|
||||
@@ -253,6 +253,106 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
|
||||
def run_env_tasks_gpt5(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
|
||||
active_environments = []
|
||||
env = None
|
||||
try:
|
||||
if args.provider_name == "aws":
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
REGION = args.region
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
|
||||
else:
|
||||
REGION = None
|
||||
ami_id = None
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=REGION,
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=args.client_password
|
||||
)
|
||||
active_environments.append(env)
|
||||
logger.info(f"Process {current_process().name} started.")
|
||||
while True:
|
||||
try:
|
||||
item = task_queue.get(timeout=5)
|
||||
except Exception:
|
||||
break
|
||||
domain, example_id = item
|
||||
try:
|
||||
config_file = os.path.join(
|
||||
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
|
||||
)
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
logger.info(f"[{current_process().name}][Domain]: {domain}")
|
||||
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
|
||||
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
|
||||
example_result_dir = os.path.join(
|
||||
args.result_dir,
|
||||
domain,
|
||||
example_id,
|
||||
)
|
||||
os.makedirs(example_result_dir, exist_ok=True)
|
||||
try:
|
||||
env.reset(task_config=example)
|
||||
time.sleep(15)
|
||||
obs = env._get_obs()
|
||||
|
||||
_, traj = run_cua_gpt5gta1(
|
||||
env=env,
|
||||
instruction=example["instruction"],
|
||||
max_steps=args.max_steps,
|
||||
save_path=example_result_dir,
|
||||
sleep_after_execution=args.sleep_after_execution,
|
||||
screen_width=args.screen_width,
|
||||
screen_height=args.screen_height,
|
||||
client_password=args.client_password,
|
||||
tts_step=args.tts_step,
|
||||
purge_history_images=args.purge_history_images,
|
||||
cua_model=args.model,
|
||||
logger=logger,
|
||||
)
|
||||
time.sleep(15)
|
||||
result = env.evaluate()
|
||||
shared_scores.append(result)
|
||||
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "w") as f:
|
||||
json.dump(traj, f)
|
||||
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Task-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
except Exception as e:
|
||||
logger.error(f"Process-level error in {current_process().name}: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
finally:
|
||||
logger.info(f"{current_process().name} cleaning up environment...")
|
||||
try:
|
||||
if env:
|
||||
env.close()
|
||||
logger.info(f"{current_process().name} environment closed successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"{current_process().name} error during environment cleanup: {e}")
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
|
||||
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
|
||||
global is_terminating, active_environments, processes
|
||||
@@ -313,7 +413,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
processes = []
|
||||
for i in range(num_envs):
|
||||
p = Process(
|
||||
target=run_env_tasks,
|
||||
target=run_env_tasks_o3 if args.model == "o3" else run_env_tasks_gpt5,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-{i+1}"
|
||||
)
|
||||
@@ -328,7 +428,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
if not p.is_alive():
|
||||
logger.warning(f"Process {p.name} died, restarting...")
|
||||
new_p = Process(
|
||||
target=run_env_tasks,
|
||||
target=run_env_tasks_o3 if args.model == "o3" else run_env_tasks_gpt5,
|
||||
args=(task_queue, args, shared_scores),
|
||||
name=f"EnvProcess-Restart-{idx+1}"
|
||||
)
|
||||
@@ -367,7 +467,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
def get_unfinished(
|
||||
action_space, use_model, observation_type, result_dir, total_file_json
|
||||
):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
target_dir = result_dir
|
||||
|
||||
if not os.path.exists(target_dir):
|
||||
return total_file_json
|
||||
@@ -402,7 +502,7 @@ def get_unfinished(
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||
target_dir = result_dir
|
||||
if not os.path.exists(target_dir):
|
||||
print("New experiment, no result yet.")
|
||||
return None
|
||||
@@ -446,18 +546,6 @@ if __name__ == "__main__":
|
||||
|
||||
try:
|
||||
args = config()
|
||||
|
||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||
path_to_args = os.path.join(
|
||||
args.result_dir,
|
||||
args.action_space,
|
||||
args.observation_type,
|
||||
args.model,
|
||||
"args.json",
|
||||
)
|
||||
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
|
||||
with open(path_to_args, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(args), f, indent=4)
|
||||
|
||||
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
|
||||
test_all_meta = json.load(f)
|
||||
|
||||
Reference in New Issue
Block a user