From 81580a1bbce9e23684fafb18297c44e4eccff115 Mon Sep 17 00:00:00 2001 From: rhythmcao Date: Fri, 15 Mar 2024 22:09:24 +0800 Subject: [PATCH] fix incompatible errors in main.py (temporarily fixup, will be dropped in future after snapshot download is ok) --- main.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 93282ec..06debec 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ import logging import os import sys import time - +import argparse from desktop_env.envs.desktop_env import DesktopEnv # Logger Configs {{{ # @@ -46,19 +46,29 @@ def human_agent(): """ Runs the Gym environment with human input. """ + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.") + parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.") + parser.add_argument('-e', '--example', type=str, help="Path to the example json file.") + args = parser.parse_args(sys.argv[1:]) - with open("evaluation_examples/examples/multi_apps/4c26e3f3-3a14-4d86-b44a-d3cedebbb487.json", "r", encoding="utf-8") as f: + example_path = args.example if args.example is not None and os.path.exists(args.example) else \ + 'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json' + with open(example_path, "r", encoding="utf-8") as f: example = json.load(f) - example["snapshot"] = "exp_v5" + if args.snapshot is not None: + example['snapshot'] = args.snapshot + assert os.path.exists(args.path), "The specified path to the .vmx file does not exist." env = DesktopEnv( - path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", - action_space="computer_13", - task_config=example + path_to_vm=args.path, + snapshot_name=args.snapshot, + action_space="computer_13" ) # reset the environment to certain snapshot - observation = env.reset() + observation = env.reset(task_config=example) done = False + logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"]) trajectory = [ {