feat: add flexible recording control and improve execution logging
This commit is contained in:
35
run.py
35
run.py
@@ -86,6 +86,7 @@ def config() -> argparse.Namespace:
|
||||
parser.add_argument("--screen_height", type=int, default=1080)
|
||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||
parser.add_argument("--max_steps", type=int, default=15)
|
||||
parser.add_argument("--enable_recording", action="store_true", help="Enable video recording (disabled by default)")
|
||||
|
||||
# agent config
|
||||
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
||||
@@ -94,10 +95,10 @@ def config() -> argparse.Namespace:
|
||||
)
|
||||
|
||||
# lm config
|
||||
parser.add_argument("--model", type=str, default="gpt-4o")
|
||||
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_p", type=float, default=0.9)
|
||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||
parser.add_argument("--max_tokens", type=int, default=16384)
|
||||
parser.add_argument("--stop_token", type=str, default=None)
|
||||
|
||||
# example config
|
||||
@@ -147,6 +148,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
action_space=args.action_space,
|
||||
observation_type=args.observation_type,
|
||||
max_trajectory_length=args.max_trajectory_length,
|
||||
screen_width=args.screen_width,
|
||||
screen_height=args.screen_height,
|
||||
)
|
||||
|
||||
env = DesktopEnv(
|
||||
@@ -155,11 +158,31 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
action_space=agent.action_space,
|
||||
screen_size=(args.screen_width, args.screen_height),
|
||||
headless=args.headless,
|
||||
os_type = "Ubuntu",
|
||||
os_type = "Windows",
|
||||
require_a11y_tree=args.observation_type
|
||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||
)
|
||||
|
||||
# get actual VM screen size after environment initialization
|
||||
try:
|
||||
actual_screen_size = env.vm_screen_size
|
||||
if actual_screen_size and 'width' in actual_screen_size and 'height' in actual_screen_size:
|
||||
actual_width = actual_screen_size['width']
|
||||
actual_height = actual_screen_size['height']
|
||||
logger.info(f"Actual VM screen size: {actual_width}x{actual_height}")
|
||||
|
||||
# update agent's screen size if different
|
||||
if actual_width != args.screen_width or actual_height != args.screen_height:
|
||||
logger.warning(f"Screen size mismatch! Expected: {args.screen_width}x{args.screen_height}, Actual: {actual_width}x{actual_height}")
|
||||
agent.screen_width = actual_width
|
||||
agent.screen_height = actual_height
|
||||
# replace the "(width, height)" screen-size text in the system message as well
|
||||
agent.system_message = agent.system_message.replace(
|
||||
f"({args.screen_width}, {args.screen_height})",
|
||||
f"({actual_width}, {actual_height})"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to get actual VM screen size: {e}")
|
||||
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
|
||||
config_file = os.path.join(
|
||||
@@ -204,8 +227,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||
# Only attempt to end recording if controller exists (not Docker provider)
|
||||
if hasattr(env, 'controller') and env.controller is not None:
|
||||
# Only attempt to end recording if controller exists (not Docker provider) and recording is enabled
|
||||
if args.enable_recording and hasattr(env, 'controller') and env.controller is not None:
|
||||
env.controller.end_recording(
|
||||
os.path.join(example_result_dir, "recording.mp4")
|
||||
)
|
||||
@@ -217,7 +240,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
||||
)
|
||||
f.write("\n")
|
||||
|
||||
env.close()
|
||||
# env.close()
|
||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user