edit prompt
This commit is contained in:
@@ -690,10 +690,10 @@ class OpenAICUAAgent:
|
|||||||
state_correct = False
|
state_correct = False
|
||||||
# if action_exit and thought_exit:
|
# if action_exit and thought_exit:
|
||||||
# state_correct = True
|
# state_correct = True
|
||||||
if action_exit and not message_exit:
|
# if action_exit and not message_exit:
|
||||||
state_correct = True
|
# state_correct = True
|
||||||
# if action_exit:
|
if action_exit:
|
||||||
# state_correct = True
|
state_correct = True
|
||||||
if not state_correct:
|
if not state_correct:
|
||||||
logger.warning("The state of the agent is not correct, action_exit: %s, thought_exit: %s, message_exit: %s", action_exit, thought_exit, message_exit)
|
logger.warning("The state of the agent is not correct, action_exit: %s, thought_exit: %s, message_exit: %s", action_exit, thought_exit, message_exit)
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
# Do not write any secret keys or sensitive information here.
|
# Do not write any secret keys or sensitive information here.
|
||||||
|
|
||||||
# Monitor configuration
|
# Monitor configuration
|
||||||
TASK_CONFIG_PATH=../evaluation_examples/test_small.json
|
TASK_CONFIG_PATH=../evaluation_examples/test_all.json
|
||||||
EXAMPLES_BASE_PATH=../evaluation_examples/examples
|
EXAMPLES_BASE_PATH=../evaluation_examples/examples
|
||||||
RESULTS_BASE_PATH=../results_small_endmethod_ifmessage
|
RESULTS_BASE_PATH=../results_all_ifmessage_promptnochange
|
||||||
ACTION_SPACE=pyautogui
|
ACTION_SPACE=pyautogui
|
||||||
OBSERVATION_TYPE=screenshot
|
OBSERVATION_TYPE=screenshot
|
||||||
MODEL_NAME=computer-use-preview
|
MODEL_NAME=computer-use-preview
|
||||||
|
|||||||
@@ -28,35 +28,6 @@ if os.path.exists(".env"):
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Logger Configs {{{ #
|
# Logger Configs {{{ #
|
||||||
logger = logging.getLogger()
|
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
|
||||||
|
|
||||||
file_handler = logging.FileHandler(
|
|
||||||
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
|
||||||
)
|
|
||||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
|
||||||
|
|
||||||
|
|
||||||
file_handler.setLevel(logging.INFO)
|
|
||||||
stdout_handler.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
formatter = logging.Formatter(
|
|
||||||
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
|
||||||
)
|
|
||||||
file_handler.setFormatter(formatter)
|
|
||||||
stdout_handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
|
||||||
|
|
||||||
logger.addHandler(file_handler)
|
|
||||||
logger.addHandler(stdout_handler)
|
|
||||||
# }}} Logger Configs #
|
|
||||||
|
|
||||||
logger = logging.getLogger("desktopenv.experiment")
|
|
||||||
|
|
||||||
|
|
||||||
def config() -> argparse.Namespace:
|
def config() -> argparse.Namespace:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Run end-to-end evaluation on the benchmark"
|
description="Run end-to-end evaluation on the benchmark"
|
||||||
@@ -103,6 +74,8 @@ def config() -> argparse.Namespace:
|
|||||||
# logging related
|
# logging related
|
||||||
parser.add_argument("--result_dir", type=str, default="./results")
|
parser.add_argument("--result_dir", type=str, default="./results")
|
||||||
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
|
||||||
|
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||||
|
default='INFO', help="Set the logging level")
|
||||||
# aws config
|
# aws config
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
"--region", type=str, default="us-east-1", help="AWS region for the VM"
|
||||||
@@ -110,6 +83,42 @@ def config() -> argparse.Namespace:
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
args = config() # Get command line arguments first
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
log_level = getattr(logging, args.log_level.upper())
|
||||||
|
logger.setLevel(log_level)
|
||||||
|
|
||||||
|
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(
|
||||||
|
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||||
|
)
|
||||||
|
debug_handler = logging.FileHandler(
|
||||||
|
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
|
||||||
|
)
|
||||||
|
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
|
||||||
|
file_handler.setLevel(logging.INFO)
|
||||||
|
debug_handler.setLevel(logging.DEBUG)
|
||||||
|
stdout_handler.setLevel(log_level)
|
||||||
|
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
|
||||||
|
)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
debug_handler.setFormatter(formatter)
|
||||||
|
stdout_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
stdout_handler.addFilter(logging.Filter("desktopenv"))
|
||||||
|
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(debug_handler)
|
||||||
|
logger.addHandler(stdout_handler)
|
||||||
|
# }}} Logger Configs #
|
||||||
|
|
||||||
|
logger = logging.getLogger("desktopenv.experiment")
|
||||||
|
|
||||||
|
|
||||||
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
|
||||||
"""Distribute tasks evenly across environments."""
|
"""Distribute tasks evenly across environments."""
|
||||||
|
|||||||
@@ -6,4 +6,4 @@ python run_multienv_openaicua.py \
|
|||||||
--test_all_meta_path evaluation_examples/test_all.json \
|
--test_all_meta_path evaluation_examples/test_all.json \
|
||||||
--region us-east-1 \
|
--region us-east-1 \
|
||||||
--max_steps 150 \
|
--max_steps 150 \
|
||||||
--num_envs 10
|
--num_envs 25
|
||||||
|
|||||||
Reference in New Issue
Block a user