diff --git a/demo.py b/demo.py index 736adfe..9b0bb06 100644 --- a/demo.py +++ b/demo.py @@ -1,16 +1,24 @@ -import signal +import concurrent.futures import time -def handler(signo, frame): - raise RuntimeError("Timeout") +# Define the function you want to run with a timeout +def my_task(): + print("Task started") + # Simulate a long-running task + time.sleep(5) + print("Task completed") + return "Task result" -signal.signal(signal.SIGALRM, handler) +# Main program +def main(): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(my_task) + try: + # Wait for 2 seconds for my_task to complete + result = future.result(timeout=2) + print(f"Task completed with result: {result}") + except concurrent.futures.TimeoutError: + print("Task did not complete in time") -while True: - try: - signal.alarm(5) # seconds - time.sleep(10) - print("Working...") - except Exception as e : - print(e) - continue \ No newline at end of file +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/run.py b/run.py index 16c2bba..14cbe00 100644 --- a/run.py +++ b/run.py @@ -7,7 +7,9 @@ import json import logging import os import sys -import signal +# import signal +import time +import timeout_decorator from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent @@ -47,9 +49,9 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") # make sure each example won't exceed the time limit -def handler(signo, frame): - raise RuntimeError("Time limit exceeded!") -signal.signal(signal.SIGALRM, handler) +# def handler(signo, frame): +# raise RuntimeError("Time limit exceeded!") +# signal.signal(signal.SIGALRM, handler) def config() -> argparse.Namespace: parser = argparse.ArgumentParser( @@ -148,9 +150,9 @@ def test( ) os.makedirs(example_result_dir, exist_ok=True) - # example start running - try: - signal.alarm(time_limit) + + @timeout_decorator.timeout(seconds=time_limit, timeout_exception=RuntimeError, exception_message="Time limit exceeded.") + def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): agent.reset() obs = env.reset(task_config=example) done = False @@ -201,24 +203,20 @@ def test( logger.info("Result: %.2f", result) scores.append(result) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + + # example start running + try: + # signal.alarm(time_limit) + run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores) except RuntimeError as e: logger.error(f"Error in example {domain}/{example_id}: {e}") # save info of this example and then continue - try: - env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write(json.dumps({ - "Error": f"Error in example {domain}/{example_id}: {e}", - "step": step_idx + 1, - })) - f.write("\n") - except Exception as new_e: - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write(json.dumps({ - "Error": f"Error in example {domain}/{example_id}: {e} and {new_e}", - "step": "before start recording", - })) - f.write("\n") + env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write(json.dumps({ + "Error": f"Error in example {domain}/{example_id}: {e}" + })) + f.write("\n") continue env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") @@ -232,9 +230,18 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_ finished = {} for domain in os.listdir(target_dir): + finished[domain] = [] domain_path = os.path.join(target_dir, domain) if os.path.isdir(domain_path): - finished[domain] = os.listdir(domain_path) + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) if not finished: return total_file_json @@ -259,4 +266,5 @@ if __name__ == '__main__': left_info += f"{domain}: {len(test_file_list[domain])}\n" logger.info(f"Left tasks:\n{left_info}") + os.environ['OPENAI_API_KEY'] = "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" test(args, test_all_meta) \ No newline at end of file