fix(os_symphony) (#405)
This commit is contained in:
@@ -138,17 +138,67 @@ def run_env_tasks(
|
||||
try:
|
||||
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
|
||||
snapshot_name = None
|
||||
region = getattr(args, "region", None)
|
||||
|
||||
region = getattr(args, "region", "us-east-1")
|
||||
platform = 'linux'
|
||||
screen_size = (args.screen_width, args.screen_height)
|
||||
|
||||
if "osworld" in args.benchmark:
|
||||
env = OSWorldDesktopEnv(
|
||||
if args.provider_name == "aws":
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
|
||||
env = OSWorldDesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=region,
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
)
|
||||
elif args.provider_name == "docker":
|
||||
env = OSWorldDesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=region,
|
||||
snapshot_name=snapshot_name,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=getattr(args, "client_password", "")
|
||||
)
|
||||
else:
|
||||
raise Exception("Don't support other providers!")
|
||||
|
||||
env.start()
|
||||
|
||||
if args.provider_name == "aws":
|
||||
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
|
||||
search_env = OSWorldDesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=region,
|
||||
snapshot_name=ami_id,
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
)
|
||||
elif args.provider_name == "docker":
|
||||
search_env = OSWorldDesktopEnv(
|
||||
path_to_vm=args.path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=region,
|
||||
snapshot_name=snapshot_name,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
screen_size=screen_size,
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
@@ -156,25 +206,9 @@ def run_env_tasks(
|
||||
enable_proxy=True,
|
||||
client_password=getattr(args, "client_password", "")
|
||||
)
|
||||
env.start()
|
||||
|
||||
platform = "linux"
|
||||
|
||||
search_env = OSWorldDesktopEnv(
|
||||
path_to_vm=args.searcher_path_to_vm,
|
||||
action_space=args.action_space,
|
||||
provider_name=args.provider_name,
|
||||
region=region,
|
||||
snapshot_name=snapshot_name,
|
||||
screen_size=(args.searcher_screen_width, args.searcher_screen_height),
|
||||
headless=args.headless,
|
||||
os_type="Ubuntu",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
enable_proxy=True,
|
||||
client_password=getattr(args, "client_password", "")
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Don't support other providers!")
|
||||
|
||||
engine_params_for_ocr = copy.deepcopy(engine_params_for_orchestrator)
|
||||
engine_params_for_ocr["agent_name"] = "ocr"
|
||||
os_aci = OSACI(
|
||||
|
||||
Reference in New Issue
Block a user