diff --git a/.vscode/launch.json b/.vscode/launch.json index bc0f472..cf0e7fc 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,8 +11,8 @@ "program": "${file}", "console": "integratedTerminal", "args": [ - "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx", - "--example_time_limit", "60" + "--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx" + // "--example_time_limit", "60" ] } ] diff --git a/demo.py b/demo.py deleted file mode 100644 index 9b0bb06..0000000 --- a/demo.py +++ /dev/null @@ -1,24 +0,0 @@ -import concurrent.futures -import time - -# Define the function you want to run with a timeout -def my_task(): - print("Task started") - # Simulate a long-running task - time.sleep(5) - print("Task completed") - return "Task result" - -# Main program -def main(): - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(my_task) - try: - # Wait for 2 seconds for my_task to complete - result = future.result(timeout=2) - print(f"Task completed with result: {result}") - except concurrent.futures.TimeoutError: - print("Task did not complete in time") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/lib_run_single.py b/lib_run_single.py new file mode 100644 index 0000000..102a1bc --- /dev/null +++ b/lib_run_single.py @@ -0,0 +1,66 @@ +import os +import datetime +import json +import logging +from wrapt_timeout_decorator import * +logger = logging.getLogger("desktopenv.experiment") + +# Open the JSON file +with open("./settings.json", "r") as file: + # Load the JSON data from the file + data = json.load(file) +time_limit = data["time_limit"] + +@timeout(time_limit, use_signals=False) +def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): + agent.reset() + obs = env.reset(task_config=example) + done = False + step_idx = 0 + env.controller.start_recording() + + while not done and step_idx < max_steps: + actions = agent.predict( + instruction, + obs + ) + for action in actions: + # Capture the timestamp before executing the action + action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + logger.info("Step %d: %s", step_idx + 1, action) + + obs, reward, done, info = env.step(action, args.sleep_after_execution) + + logger.info("Reward: %.2f", reward) + logger.info("Done: %s", done) + logger.info("Info: %s", info) + + # Save screenshot and trajectory information + with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), + "wb") as _f: + with open(obs['screenshot'], "rb") as __f: + screenshot = __f.read() + _f.write(screenshot) + + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write(json.dumps({ + "step_num": step_idx + 1, + "action_timestamp": action_timestamp, + "action": action, + "reward": reward, + "done": done, + "info": info, + "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" + })) + f.write("\n") + + if done: + logger.info("The episode is done.") + break + step_idx += 1 + result = env.evaluate() + logger.info("Result: %.2f", result) + scores.append(result) + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") + env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) diff --git a/run.py b/run.py index 0eb5116..719222a 100644 --- a/run.py +++ b/run.py @@ -8,13 +8,12 @@ import logging import os import sys -from tqdm # import tqdm +from tqdm import tqdm import time -import timeout_decorator from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent - +import lib_run_single # Logger Configs {{{ # logger = logging.getLogger() logger.setLevel(logging.DEBUG) @@ -49,12 +48,6 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") - -# make sure each example won't exceed the time limit -# def handler(signo, frame): -# raise RuntimeError("Time limit exceeded!") -# signal.signal(signal.SIGALRM, handler) - def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation on the benchmark" @@ -151,80 +144,17 @@ def test( example_id ) os.makedirs(example_result_dir, exist_ok=True) - - - @timeout_decorator.timeout(seconds=time_limit, timeout_exception=RuntimeError, exception_message="Time limit exceeded.") - def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): - agent.reset() - obs = env.reset(task_config=example) - done = False - step_idx = 0 - env.controller.start_recording() - - while not done and step_idx < max_steps: - actions = agent.predict( - instruction, - obs - ) - for action in actions: - # Capture the timestamp before executing the action - action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") - logger.info("Step %d: %s", step_idx + 1, action) - - obs, reward, done, info = env.step(action, args.sleep_after_execution) - - logger.info("Reward: %.2f", reward) - logger.info("Done: %s", done) - logger.info("Info: %s", info) - - # Save screenshot and trajectory information - with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), - "wb") as _f: - with open(obs['screenshot'], "rb") as __f: - screenshot = __f.read() - _f.write(screenshot) - - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write(json.dumps({ - "step_num": step_idx + 1, - "action_timestamp": action_timestamp, - "action": action, - "reward": reward, - "done": done, - "info": info, - "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" - })) - f.write("\n") - - if done: - logger.info("The episode is done.") - break - step_idx += 1 - - result = env.evaluate() - logger.info("Result: %.2f", result) - scores.append(result) - with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: - f.write(f"{result}\n") - env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) - # example start running try: - # signal.alarm(time_limit) - run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores) - except RuntimeError as e: - logger.error(f"Error in example {domain}/{example_id}: {e}") - # save info of this example and then continue + lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores) + except Exception as e: env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + logger.error(f"Time limit exceeded in {domain}/{example_id}") with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({ - "Error": f"Error in example {domain}/{example_id}: {e}" + "Error": f"Time limit exceeded in {domain}/{example_id}" })) - f.write("\n") - continue - except Exception as e: - logger.error(f"Error in example {domain}/{example_id}: {e}") - continue + f.write("\n") env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") @@ -281,5 +211,5 @@ if __name__ == '__main__': for domain in test_file_list: left_info += f"{domain}: {len(test_file_list[domain])}\n" logger.info(f"Left tasks:\n{left_info}") - + os.environ["OPENAI_API_KEY"] = "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" test(args, test_all_meta) diff --git a/settings.json b/settings.json new file mode 100644 index 0000000..9b87bac --- /dev/null +++ b/settings.json @@ -0,0 +1,3 @@ +{ + "time_limit": "60" +} \ No newline at end of file