diff --git a/run_autoglm.py b/run_autoglm.py index 343be5d..5b3a947 100644 --- a/run_autoglm.py +++ b/run_autoglm.py @@ -99,6 +99,14 @@ def config() -> argparse.Namespace: parser.add_argument("--domain", type=str, default="all") parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json") + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + # logging related parser.add_argument("--result_dir", type=str, default="./results") args = parser.parse_args() @@ -324,6 +332,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: env = DesktopEnv( provider_name=args.provider_name, + region=args.region, + client_password=args.client_password, path_to_vm=args.path_to_vm, action_space=args.action_space, screen_size=(args.screen_width, args.screen_height), @@ -335,6 +345,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: action_space=args.action_space, observation_type=args.observation_type, max_trajectory_length=args.max_trajectory_length, + client_password=args.client_password, gen_func=call_llm, ) @@ -454,6 +465,13 @@ if __name__ == "__main__": ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" args = config() + if args.client_password == "": + if args.provider_name == "aws": + args.client_password = "osworld-public-evaluation" + else: + args.client_password = "password" + else: + args.client_password = args.client_password # save args to json in result_dir/action_space/observation_type/model/args.json path_to_args = os.path.join( diff --git a/run_multienv_autoglm.py b/run_multienv_autoglm.py index 77c7471..f6d74c6 100644 --- a/run_multienv_autoglm.py +++ b/run_multienv_autoglm.py @@ -82,6 +82,14 @@ def config() -> argparse.Namespace: "--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json" ) + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + # logging related parser.add_argument("--result_dir", type=str, default="./results") parser.add_argument("--num_envs", type=int, default=20, help="Number of environments to run in parallel") @@ -91,6 +99,13 @@ def config() -> argparse.Namespace: return args args = config() # Get command line arguments first +if args.client_password == "": + if args.provider_name == "aws": + args.client_password = "osworld-public-evaluation" + else: + args.client_password = "password" +else: + args.client_password = args.client_password logger = logging.getLogger() log_level = getattr(logging, args.log_level.upper()) @@ -186,6 +201,8 @@ def run_env_tasks(task_queue, args, shared_scores): env = DesktopEnv( provider_name=args.provider_name, + region=args.region, + client_password=args.client_password, path_to_vm=args.path_to_vm, action_space=args.action_space, screen_size=(args.screen_width, args.screen_height), @@ -198,6 +215,7 @@ def run_env_tasks(task_queue, args, shared_scores): action_space=args.action_space, observation_type=args.observation_type, max_trajectory_length=args.max_trajectory_length, + client_password=args.client_password, gen_func=call_llm, ) logger.info(f"Process {current_process().name} started.")