This commit is contained in:
Jason Lee
2024-03-15 23:26:04 +08:00
parent 1789a28657
commit afec1a3a23
2 changed files with 51 additions and 35 deletions

32
demo.py
View File

@@ -1,16 +1,24 @@
import signal import concurrent.futures
import time import time
def handler(signo, frame): # Define the function you want to run with a timeout
raise RuntimeError("Timeout") def my_task():
print("Task started")
# Simulate a long-running task
time.sleep(5)
print("Task completed")
return "Task result"
signal.signal(signal.SIGALRM, handler) # Main program
def main():
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(my_task)
try:
# Wait for 2 seconds for my_task to complete
result = future.result(timeout=2)
print(f"Task completed with result: {result}")
except concurrent.futures.TimeoutError:
print("Task did not complete in time")
while True: if __name__ == "__main__":
try: main()
signal.alarm(5) # seconds
time.sleep(10)
print("Working...")
except Exception as e :
print(e)
continue

54
run.py
View File

@@ -7,7 +7,9 @@ import json
import logging import logging
import os import os
import sys import sys
import signal # import signal
import time
import timeout_decorator
from desktop_env.envs.desktop_env import DesktopEnv from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent from mm_agents.agent import PromptAgent
@@ -47,9 +49,9 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment") logger = logging.getLogger("desktopenv.experiment")
# make sure each example won't exceed the time limit # make sure each example won't exceed the time limit
def handler(signo, frame): # def handler(signo, frame):
raise RuntimeError("Time limit exceeded!") # raise RuntimeError("Time limit exceeded!")
signal.signal(signal.SIGALRM, handler) # signal.signal(signal.SIGALRM, handler)
def config() -> argparse.Namespace: def config() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -148,9 +150,9 @@ def test(
) )
os.makedirs(example_result_dir, exist_ok=True) os.makedirs(example_result_dir, exist_ok=True)
# example start running
try: @timeout_decorator.timeout(seconds=time_limit, timeout_exception=RuntimeError, exception_message="Time limit exceeded.")
signal.alarm(time_limit) def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
agent.reset() agent.reset()
obs = env.reset(task_config=example) obs = env.reset(task_config=example)
done = False done = False
@@ -201,24 +203,20 @@ def test(
logger.info("Result: %.2f", result) logger.info("Result: %.2f", result)
scores.append(result) scores.append(result)
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
# example start running
try:
# signal.alarm(time_limit)
run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores)
except RuntimeError as e: except RuntimeError as e:
logger.error(f"Error in example {domain}/{example_id}: {e}") logger.error(f"Error in example {domain}/{example_id}: {e}")
# save info of this example and then continue # save info of this example and then continue
try: env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({
f.write(json.dumps({ "Error": f"Error in example {domain}/{example_id}: {e}"
"Error": f"Error in example {domain}/{example_id}: {e}", }))
"step": step_idx + 1, f.write("\n")
}))
f.write("\n")
except Exception as new_e:
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"Error": f"Error in example {domain}/{example_id}: {e} and {new_e}",
"step": "before start recording",
}))
f.write("\n")
continue continue
env.close() env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}") logger.info(f"Average score: {sum(scores) / len(scores)}")
@@ -232,9 +230,18 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
finished = {} finished = {}
for domain in os.listdir(target_dir): for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain) domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path): if os.path.isdir(domain_path):
finished[domain] = os.listdir(domain_path) for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished: if not finished:
return total_file_json return total_file_json
@@ -259,4 +266,5 @@ if __name__ == '__main__':
left_info += f"{domain}: {len(test_file_list[domain])}\n" left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}") logger.info(f"Left tasks:\n{left_info}")
os.environ['OPENAI_API_KEY'] = "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"
test(args, test_all_meta) test(args, test_all_meta)