fix(os_symphony) (#405)
This commit is contained in:
@@ -138,17 +138,67 @@ def run_env_tasks(
|
|||||||
try:
|
try:
|
||||||
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
|
# Use IMAGE_ID_MAP for AWS provider to get snapshot_name
|
||||||
snapshot_name = None
|
snapshot_name = None
|
||||||
region = getattr(args, "region", None)
|
region = getattr(args, "region", "us-east-1")
|
||||||
|
|
||||||
platform = 'linux'
|
platform = 'linux'
|
||||||
|
screen_size = (args.screen_width, args.screen_height)
|
||||||
|
|
||||||
if "osworld" in args.benchmark:
|
if "osworld" in args.benchmark:
|
||||||
env = OSWorldDesktopEnv(
|
if args.provider_name == "aws":
|
||||||
|
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||||
|
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
|
||||||
|
env = OSWorldDesktopEnv(
|
||||||
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=args.action_space,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
region=region,
|
||||||
|
snapshot_name=ami_id,
|
||||||
|
screen_size=screen_size,
|
||||||
|
headless=args.headless,
|
||||||
|
os_type="Ubuntu",
|
||||||
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
|
)
|
||||||
|
elif args.provider_name == "docker":
|
||||||
|
env = OSWorldDesktopEnv(
|
||||||
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=args.action_space,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
region=region,
|
||||||
|
snapshot_name=snapshot_name,
|
||||||
|
screen_size=screen_size,
|
||||||
|
headless=args.headless,
|
||||||
|
os_type="Ubuntu",
|
||||||
|
require_a11y_tree=args.observation_type
|
||||||
|
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
|
enable_proxy=True,
|
||||||
|
client_password=getattr(args, "client_password", "")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("Don't support other providers!")
|
||||||
|
|
||||||
|
env.start()
|
||||||
|
|
||||||
|
if args.provider_name == "aws":
|
||||||
|
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
|
||||||
|
ami_id = IMAGE_ID_MAP[region].get(screen_size, IMAGE_ID_MAP[region][(1920, 1080)])
|
||||||
|
search_env = OSWorldDesktopEnv(
|
||||||
|
path_to_vm=args.path_to_vm,
|
||||||
|
action_space=args.action_space,
|
||||||
|
provider_name=args.provider_name,
|
||||||
|
region=region,
|
||||||
|
snapshot_name=ami_id,
|
||||||
|
screen_size=screen_size,
|
||||||
|
headless=args.headless,
|
||||||
|
os_type="Ubuntu",
|
||||||
|
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
|
)
|
||||||
|
elif args.provider_name == "docker":
|
||||||
|
search_env = OSWorldDesktopEnv(
|
||||||
path_to_vm=args.path_to_vm,
|
path_to_vm=args.path_to_vm,
|
||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
provider_name=args.provider_name,
|
provider_name=args.provider_name,
|
||||||
region=region,
|
region=region,
|
||||||
snapshot_name=snapshot_name,
|
snapshot_name=snapshot_name,
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
screen_size=screen_size,
|
||||||
headless=args.headless,
|
headless=args.headless,
|
||||||
os_type="Ubuntu",
|
os_type="Ubuntu",
|
||||||
require_a11y_tree=args.observation_type
|
require_a11y_tree=args.observation_type
|
||||||
@@ -156,25 +206,9 @@ def run_env_tasks(
|
|||||||
enable_proxy=True,
|
enable_proxy=True,
|
||||||
client_password=getattr(args, "client_password", "")
|
client_password=getattr(args, "client_password", "")
|
||||||
)
|
)
|
||||||
env.start()
|
else:
|
||||||
|
raise Exception("Don't support other providers!")
|
||||||
platform = "linux"
|
|
||||||
|
|
||||||
search_env = OSWorldDesktopEnv(
|
|
||||||
path_to_vm=args.searcher_path_to_vm,
|
|
||||||
action_space=args.action_space,
|
|
||||||
provider_name=args.provider_name,
|
|
||||||
region=region,
|
|
||||||
snapshot_name=snapshot_name,
|
|
||||||
screen_size=(args.searcher_screen_width, args.searcher_screen_height),
|
|
||||||
headless=args.headless,
|
|
||||||
os_type="Ubuntu",
|
|
||||||
require_a11y_tree=args.observation_type
|
|
||||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
|
||||||
enable_proxy=True,
|
|
||||||
client_password=getattr(args, "client_password", "")
|
|
||||||
)
|
|
||||||
|
|
||||||
engine_params_for_ocr = copy.deepcopy(engine_params_for_orchestrator)
|
engine_params_for_ocr = copy.deepcopy(engine_params_for_orchestrator)
|
||||||
engine_params_for_ocr["agent_name"] = "ocr"
|
engine_params_for_ocr["agent_name"] = "ocr"
|
||||||
os_aci = OSACI(
|
os_aci = OSACI(
|
||||||
|
|||||||
Reference in New Issue
Block a user