fix(os_symphony) (#405)

This commit is contained in:
Bowen Yang
2025-12-30 22:43:47 +08:00
committed by GitHub
parent 662826f57e
commit 02a35be067

View File

@@ -138,17 +138,67 @@ def run_env_tasks(
try:
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
snapshot_name = None
region = getattr(args, "region", None)
region = getattr(args, "region", "us-east-1")
platform = 'linux'
screen_size = (args.screen_width, args.screen_height)
if "osworld" in args.benchmark:
env = OSWorldDesktopEnv(
if args.provider_name == "aws":
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
)
elif args.provider_name == "docker":
env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=snapshot_name,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=getattr(args, "client_password", "")
)
else:
raise Exception("Don't support other providers!")
env.start()
if args.provider_name == "aws":
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
search_env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
)
elif args.provider_name == "docker":
search_env = OSWorldDesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=snapshot_name,
screen_size=(args.screen_width, args.screen_height),
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type
@@ -156,25 +206,9 @@ def run_env_tasks(
enable_proxy=True,
client_password=getattr(args, "client_password", "")
)
env.start()
platform = "linux"
search_env = OSWorldDesktopEnv(
path_to_vm=args.searcher_path_to_vm,
action_space=args.action_space,
provider_name=args.provider_name,
region=region,
snapshot_name=snapshot_name,
screen_size=(args.searcher_screen_width, args.searcher_screen_height),
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=getattr(args, "client_password", "")
)
else:
raise Exception("Don't support other providers!")
engine_params_for_ocr = copy.deepcopy(engine_params_for_orchestrator)
engine_params_for_ocr["agent_name"] = "ocr"
os_aci = OSACI(