Add AWS config for autoglm-os agent script (#311)

* Add AWS config for autoglm-os agent script

* update default password
This commit is contained in:
Adam Yanxiao Zhao
2025-08-17 22:54:23 +08:00
committed by GitHub
parent 2664eba23b
commit deff1fe385
2 changed files with 36 additions and 0 deletions

View File

@@ -99,6 +99,14 @@ def config() -> argparse.Namespace:
parser.add_argument("--domain", type=str, default="all")
parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
args = parser.parse_args()
@@ -324,6 +332,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
env = DesktopEnv(
provider_name=args.provider_name,
region=args.region,
client_password=args.client_password,
path_to_vm=args.path_to_vm,
action_space=args.action_space,
screen_size=(args.screen_width, args.screen_height),
@@ -335,6 +345,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
client_password=args.client_password,
gen_func=call_llm,
)
@@ -454,6 +465,13 @@ if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
if args.client_password == "":
if args.provider_name == "aws":
args.client_password = "osworld-public-evaluation"
else:
args.client_password = "password"
else:
args.client_password = args.client_password
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(

View File

@@ -82,6 +82,14 @@ def config() -> argparse.Namespace:
"--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
)
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=20, help="Number of environments to run in parallel")
@@ -91,6 +99,13 @@ def config() -> argparse.Namespace:
return args
args = config() # Get command line arguments first
if args.client_password == "":
if args.provider_name == "aws":
args.client_password = "osworld-public-evaluation"
else:
args.client_password = "password"
else:
args.client_password = args.client_password
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
@@ -186,6 +201,8 @@ def run_env_tasks(task_queue, args, shared_scores):
env = DesktopEnv(
provider_name=args.provider_name,
region=args.region,
client_password=args.client_password,
path_to_vm=args.path_to_vm,
action_space=args.action_space,
screen_size=(args.screen_width, args.screen_height),
@@ -198,6 +215,7 @@ def run_env_tasks(task_queue, args, shared_scores):
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
client_password=args.client_password,
gen_func=call_llm,
)
logger.info(f"Process {current_process().name} started.")