Add AWS config for autoglm-os agent script (#311)
* Add AWS config for autoglm-os agent script * update default password
This commit is contained in:
committed by
GitHub
parent
2664eba23b
commit
deff1fe385
@@ -99,6 +99,14 @@ def config() -> argparse.Namespace:
|
|||||||
parser.add_argument("--domain", type=str, default="all")
|
parser.add_argument("--domain", type=str, default="all")
|
||||||
parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json")
|
parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json")
|
||||||
|
|
||||||
|
# aws config
|
||||||
|
parser.add_argument(
|
||||||
|
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--client_password", type=str, default="", help="Client password"
|
||||||
|
)
|
||||||
|
|
||||||
# logging related
|
# logging related
|
||||||
parser.add_argument("--result_dir", type=str, default="./results")
|
parser.add_argument("--result_dir", type=str, default="./results")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -324,6 +332,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
|
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
provider_name=args.provider_name,
|
provider_name=args.provider_name,
|
||||||
|
region=args.region,
|
||||||
|
client_password=args.client_password,
|
||||||
path_to_vm=args.path_to_vm,
|
path_to_vm=args.path_to_vm,
|
||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
screen_size=(args.screen_width, args.screen_height),
|
||||||
@@ -335,6 +345,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
observation_type=args.observation_type,
|
observation_type=args.observation_type,
|
||||||
max_trajectory_length=args.max_trajectory_length,
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
|
client_password=args.client_password,
|
||||||
gen_func=call_llm,
|
gen_func=call_llm,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -454,6 +465,13 @@ if __name__ == "__main__":
|
|||||||
####### The complete version of the list of examples #######
|
####### The complete version of the list of examples #######
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
args = config()
|
args = config()
|
||||||
|
if args.client_password == "":
|
||||||
|
if args.provider_name == "aws":
|
||||||
|
args.client_password = "osworld-public-evaluation"
|
||||||
|
else:
|
||||||
|
args.client_password = "password"
|
||||||
|
else:
|
||||||
|
args.client_password = args.client_password
|
||||||
|
|
||||||
# save args to json in result_dir/action_space/observation_type/model/args.json
|
# save args to json in result_dir/action_space/observation_type/model/args.json
|
||||||
path_to_args = os.path.join(
|
path_to_args = os.path.join(
|
||||||
|
|||||||
@@ -82,6 +82,14 @@ def config() -> argparse.Namespace:
|
|||||||
"--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
|
"--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# aws config
|
||||||
|
parser.add_argument(
|
||||||
|
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--client_password", type=str, default="", help="Client password"
|
||||||
|
)
|
||||||
|
|
||||||
# logging related
|
# logging related
|
||||||
parser.add_argument("--result_dir", type=str, default="./results")
|
parser.add_argument("--result_dir", type=str, default="./results")
|
||||||
parser.add_argument("--num_envs", type=int, default=20, help="Number of environments to run in parallel")
|
parser.add_argument("--num_envs", type=int, default=20, help="Number of environments to run in parallel")
|
||||||
@@ -91,6 +99,13 @@ def config() -> argparse.Namespace:
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
args = config() # Get command line arguments first
|
args = config() # Get command line arguments first
|
||||||
|
if args.client_password == "":
|
||||||
|
if args.provider_name == "aws":
|
||||||
|
args.client_password = "osworld-public-evaluation"
|
||||||
|
else:
|
||||||
|
args.client_password = "password"
|
||||||
|
else:
|
||||||
|
args.client_password = args.client_password
|
||||||
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
log_level = getattr(logging, args.log_level.upper())
|
log_level = getattr(logging, args.log_level.upper())
|
||||||
@@ -186,6 +201,8 @@ def run_env_tasks(task_queue, args, shared_scores):
|
|||||||
|
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
provider_name=args.provider_name,
|
provider_name=args.provider_name,
|
||||||
|
region=args.region,
|
||||||
|
client_password=args.client_password,
|
||||||
path_to_vm=args.path_to_vm,
|
path_to_vm=args.path_to_vm,
|
||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
screen_size=(args.screen_width, args.screen_height),
|
||||||
@@ -198,6 +215,7 @@ def run_env_tasks(task_queue, args, shared_scores):
|
|||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
observation_type=args.observation_type,
|
observation_type=args.observation_type,
|
||||||
max_trajectory_length=args.max_trajectory_length,
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
|
client_password=args.client_password,
|
||||||
gen_func=call_llm,
|
gen_func=call_llm,
|
||||||
)
|
)
|
||||||
logger.info(f"Process {current_process().name} started.")
|
logger.info(f"Process {current_process().name} started.")
|
||||||
|
|||||||
Reference in New Issue
Block a user