feat: add flexible recording control and improve execution logging
This commit is contained in:
@@ -101,7 +101,7 @@ class DesktopEnv(gym.Env):
|
|||||||
provider_name: str = "vmware",
|
provider_name: str = "vmware",
|
||||||
region: str = None,
|
region: str = None,
|
||||||
path_to_vm: str = None,
|
path_to_vm: str = None,
|
||||||
snapshot_name: str = "init_state",
|
snapshot_name: str = "snapshot",
|
||||||
action_space: str = "pyautogui",
|
action_space: str = "pyautogui",
|
||||||
cache_dir: str = "cache",
|
cache_dir: str = "cache",
|
||||||
screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
|
screen_size: Tuple[int] = (int(os.environ.get("SCREEN_WIDTH", 1920)), int(os.environ.get("SCREEN_HEIGHT", 1080))),
|
||||||
@@ -117,7 +117,7 @@ class DesktopEnv(gym.Env):
|
|||||||
provider_name (str): virtualization provider name, default to "vmware"
|
provider_name (str): virtualization provider name, default to "vmware"
|
||||||
region (str): the region in which to allocate machines; applies to cloud services, default to "us-east-1"
|
region (str): the region in which to allocate machines; applies to cloud services, default to "us-east-1"
|
||||||
path_to_vm (str): path to .vmx file
|
path_to_vm (str): path to .vmx file
|
||||||
snapshot_name (str): snapshot name to revert to, default to "init_state"
|
snapshot_name (str): snapshot name to revert to, default to "snapshot"
|
||||||
action_space (str): "computer_13" | "pyautogui"
|
action_space (str): "computer_13" | "pyautogui"
|
||||||
cache_dir (str): cache directory for task-related files, such as the
|
cache_dir (str): cache directory for task-related files, such as the
|
||||||
reference file for evaluation
|
reference file for evaluation
|
||||||
@@ -265,7 +265,7 @@ class DesktopEnv(gym.Env):
|
|||||||
self.current_use_proxy = task_use_proxy
|
self.current_use_proxy = task_use_proxy
|
||||||
|
|
||||||
if self.is_environment_used:
|
if self.is_environment_used:
|
||||||
logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
|
logger.info("Environment has been used, reverting to snapshot: {}...".format(self.snapshot_name))
|
||||||
self._revert_to_snapshot()
|
self._revert_to_snapshot()
|
||||||
logger.info("Starting emulator...")
|
logger.info("Starting emulator...")
|
||||||
self._start_emulator()
|
self._start_emulator()
|
||||||
@@ -402,6 +402,7 @@ class DesktopEnv(gym.Env):
|
|||||||
|
|
||||||
if self.action_space == "computer_13":
|
if self.action_space == "computer_13":
|
||||||
# the set of all possible actions defined in the action representation
|
# the set of all possible actions defined in the action representation
|
||||||
|
logger.info(f"======executing here======{self.action_space}========================")
|
||||||
self.controller.execute_action(action)
|
self.controller.execute_action(action)
|
||||||
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
|
elif self.action_space == "pyautogui" or self.action_space == "claude_computer_use":
|
||||||
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
|
if action in ['WAIT', 'FAIL', 'DONE'] or (type(action) == dict and action.get('action_type') in ['WAIT', 'FAIL', 'DONE']):
|
||||||
@@ -411,6 +412,8 @@ class DesktopEnv(gym.Env):
|
|||||||
if type(action) == str:
|
if type(action) == str:
|
||||||
# Fix PyAutoGUI '<' character bug before execution
|
# Fix PyAutoGUI '<' character bug before execution
|
||||||
fixed_command = _fix_pyautogui_less_than_bug(action)
|
fixed_command = _fix_pyautogui_less_than_bug(action)
|
||||||
|
logger.info(f"======executing here======{self.action_space}========================")
|
||||||
|
logger.info(f"Fixed command: {fixed_command}")
|
||||||
self.controller.execute_python_command(fixed_command)
|
self.controller.execute_python_command(fixed_command)
|
||||||
elif type(action) == dict:
|
elif type(action) == dict:
|
||||||
# Fix PyAutoGUI '<' character bug before execution
|
# Fix PyAutoGUI '<' character bug before execution
|
||||||
|
|||||||
@@ -13,30 +13,44 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
|||||||
runtime_logger = setup_logger(example, example_result_dir)
|
runtime_logger = setup_logger(example, example_result_dir)
|
||||||
|
|
||||||
# Reset environment first to get fresh VM IP
|
# Reset environment first to get fresh VM IP
|
||||||
env.reset(task_config=example)
|
# env.reset(task_config=example)
|
||||||
|
# logger.info("=======Environment reset completed=======")
|
||||||
|
|
||||||
# Reset agent with fresh VM IP (for snapshot reverts)
|
# # Reset agent with fresh VM IP (for snapshot reverts)
|
||||||
try:
|
# try:
|
||||||
agent.reset(runtime_logger, vm_ip=env.vm_ip)
|
# agent.reset(runtime_logger, vm_ip=env.vm_ip)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
agent.reset(vm_ip=env.vm_ip)
|
# agent.reset(vm_ip=env.vm_ip)
|
||||||
|
|
||||||
time.sleep(60) # Wait for the environment to be ready
|
# time.sleep(10) # Wait for the environment to be ready
|
||||||
|
|
||||||
|
# get initial observation
|
||||||
|
logger.info("Getting initial observation...")
|
||||||
obs = env._get_obs() # Get the initial observation
|
obs = env._get_obs() # Get the initial observation
|
||||||
|
logger.info("Initial observation obtained.")
|
||||||
done = False
|
done = False
|
||||||
step_idx = 0
|
step_idx = 0
|
||||||
env.controller.start_recording()
|
if getattr(args, 'enable_recording', False):
|
||||||
|
env.controller.start_recording()
|
||||||
while not done and step_idx < max_steps:
|
while not done and step_idx < max_steps:
|
||||||
|
logger.info(f"Step {step_idx + 1} prediction...")
|
||||||
response, actions = agent.predict(
|
response, actions = agent.predict(
|
||||||
instruction,
|
instruction,
|
||||||
obs
|
obs
|
||||||
)
|
)
|
||||||
|
logger.info(f"Response: {response}")
|
||||||
|
logger.info(f"Actions: {actions}")
|
||||||
|
|
||||||
|
logger.info(f"Executing actions...")
|
||||||
for action in actions:
|
for action in actions:
|
||||||
# Capture the timestamp before executing the action
|
# Capture the timestamp before executing the action
|
||||||
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
|
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S%f")
|
||||||
logger.info("Step %d: %s", step_idx + 1, action)
|
logger.info("Step %d: %s", step_idx + 1, action)
|
||||||
|
|
||||||
|
logger.info("执行动作中...")
|
||||||
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
obs, reward, done, info = env.step(action, args.sleep_after_execution)
|
||||||
|
logger.info("动作执行完成。")
|
||||||
|
|
||||||
logger.info("Reward: %.2f", reward)
|
logger.info("Reward: %.2f", reward)
|
||||||
logger.info("Done: %s", done)
|
logger.info("Done: %s", done)
|
||||||
# Save screenshot and trajectory information
|
# Save screenshot and trajectory information
|
||||||
@@ -69,7 +83,8 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
|||||||
# Log task completion to results.json
|
# Log task completion to results.json
|
||||||
log_task_completion(example, result, example_result_dir, args)
|
log_task_completion(example, result, example_result_dir, args)
|
||||||
|
|
||||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
if getattr(args, 'enable_recording', False):
|
||||||
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(example, example_result_dir):
|
def setup_logger(example, example_result_dir):
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
from desktop_env.desktop_env import DesktopEnv
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from desktop_env.desktop_env import DesktopEnv
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
)
|
||||||
|
|
||||||
example = {
|
example = {
|
||||||
"id": "94d95f96-9699-4208-98ba-3c3119edf9c2",
|
"id": "94d95f96-9699-4208-98ba-3c3119edf9c2",
|
||||||
|
|||||||
35
run.py
35
run.py
@@ -86,6 +86,7 @@ def config() -> argparse.Namespace:
|
|||||||
parser.add_argument("--screen_height", type=int, default=1080)
|
parser.add_argument("--screen_height", type=int, default=1080)
|
||||||
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||||
parser.add_argument("--max_steps", type=int, default=15)
|
parser.add_argument("--max_steps", type=int, default=15)
|
||||||
|
parser.add_argument("--enable_recording", action="store_true", help="Enable video recording (disabled by default)")
|
||||||
|
|
||||||
# agent config
|
# agent config
|
||||||
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
parser.add_argument("--max_trajectory_length", type=int, default=3)
|
||||||
@@ -94,10 +95,10 @@ def config() -> argparse.Namespace:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# lm config
|
# lm config
|
||||||
parser.add_argument("--model", type=str, default="gpt-4o")
|
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
parser.add_argument("--max_tokens", type=int, default=16384)
|
||||||
parser.add_argument("--stop_token", type=str, default=None)
|
parser.add_argument("--stop_token", type=str, default=None)
|
||||||
|
|
||||||
# example config
|
# example config
|
||||||
@@ -147,6 +148,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
observation_type=args.observation_type,
|
observation_type=args.observation_type,
|
||||||
max_trajectory_length=args.max_trajectory_length,
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
|
screen_width=args.screen_width,
|
||||||
|
screen_height=args.screen_height,
|
||||||
)
|
)
|
||||||
|
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
@@ -155,11 +158,31 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
action_space=agent.action_space,
|
action_space=agent.action_space,
|
||||||
screen_size=(args.screen_width, args.screen_height),
|
screen_size=(args.screen_width, args.screen_height),
|
||||||
headless=args.headless,
|
headless=args.headless,
|
||||||
os_type = "Ubuntu",
|
os_type = "Windows",
|
||||||
require_a11y_tree=args.observation_type
|
require_a11y_tree=args.observation_type
|
||||||
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
in ["a11y_tree", "screenshot_a11y_tree", "som"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# get actual VM screen size after environment initialization
|
||||||
|
try:
|
||||||
|
actual_screen_size = env.vm_screen_size
|
||||||
|
if actual_screen_size and 'width' in actual_screen_size and 'height' in actual_screen_size:
|
||||||
|
actual_width = actual_screen_size['width']
|
||||||
|
actual_height = actual_screen_size['height']
|
||||||
|
logger.info(f"Actual VM screen size: {actual_width}x{actual_height}")
|
||||||
|
|
||||||
|
# update agent's screen size if different
|
||||||
|
if actual_width != args.screen_width or actual_height != args.screen_height:
|
||||||
|
logger.warning(f"Screen size mismatch! Expected: {args.screen_width}x{args.screen_height}, Actual: {actual_width}x{actual_height}")
|
||||||
|
agent.screen_width = actual_width
|
||||||
|
agent.screen_height = actual_height
|
||||||
|
# replace in system message as well
|
||||||
|
agent.system_message = agent.system_message.replace(
|
||||||
|
f"({args.screen_width}, {args.screen_height})",
|
||||||
|
f"({actual_width}, {actual_height})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Unable to get actual VM screen size: {e}")
|
||||||
for domain in tqdm(test_all_meta, desc="Domain"):
|
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||||
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
|
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
|
||||||
config_file = os.path.join(
|
config_file = os.path.join(
|
||||||
@@ -204,8 +227,8 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||||
# Only attempt to end recording if controller exists (not Docker provider)
|
# Only attempt to end recording if controller exists (not Docker provider) and recording is enabled
|
||||||
if hasattr(env, 'controller') and env.controller is not None:
|
if args.enable_recording and hasattr(env, 'controller') and env.controller is not None:
|
||||||
env.controller.end_recording(
|
env.controller.end_recording(
|
||||||
os.path.join(example_result_dir, "recording.mp4")
|
os.path.join(example_result_dir, "recording.mp4")
|
||||||
)
|
)
|
||||||
@@ -217,7 +240,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
)
|
)
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
|
||||||
env.close()
|
# env.close()
|
||||||
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user