This commit is contained in:
Jason Lee
2024-03-15 23:26:04 +08:00
parent 1789a28657
commit afec1a3a23
2 changed files with 51 additions and 35 deletions

32
demo.py
View File

@@ -1,16 +1,24 @@
import signal
import concurrent.futures
import time
def handler(signo, frame):
raise RuntimeError("Timeout")
# Define the function you want to run with a timeout
def my_task():
print("Task started")
# Simulate a long-running task
time.sleep(5)
print("Task completed")
return "Task result"
signal.signal(signal.SIGALRM, handler)
# Main program
def main():
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(my_task)
try:
# Wait for 2 seconds for my_task to complete
result = future.result(timeout=2)
print(f"Task completed with result: {result}")
except concurrent.futures.TimeoutError:
print("Task did not complete in time")
while True:
try:
signal.alarm(5) # seconds
time.sleep(10)
print("Working...")
except Exception as e :
print(e)
continue
if __name__ == "__main__":
main()

54
run.py
View File

@@ -7,7 +7,9 @@ import json
import logging
import os
import sys
import signal
# import signal
import time
import timeout_decorator
from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent
@@ -47,9 +49,9 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment")
# make sure each example won't exceed the time limit
def handler(signo, frame):
raise RuntimeError("Time limit exceeded!")
signal.signal(signal.SIGALRM, handler)
# def handler(signo, frame):
# raise RuntimeError("Time limit exceeded!")
# signal.signal(signal.SIGALRM, handler)
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
@@ -148,9 +150,9 @@ def test(
)
os.makedirs(example_result_dir, exist_ok=True)
# example start running
try:
signal.alarm(time_limit)
@timeout_decorator.timeout(seconds=time_limit, timeout_exception=RuntimeError, exception_message="Time limit exceeded.")
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
agent.reset()
obs = env.reset(task_config=example)
done = False
@@ -201,24 +203,20 @@ def test(
logger.info("Result: %.2f", result)
scores.append(result)
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
# example start running
try:
# signal.alarm(time_limit)
run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores)
except RuntimeError as e:
logger.error(f"Error in example {domain}/{example_id}: {e}")
# save info of this example and then continue
try:
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"Error": f"Error in example {domain}/{example_id}: {e}",
"step": step_idx + 1,
}))
f.write("\n")
except Exception as new_e:
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"Error": f"Error in example {domain}/{example_id}: {e} and {new_e}",
"step": "before start recording",
}))
f.write("\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"Error": f"Error in example {domain}/{example_id}: {e}"
}))
f.write("\n")
continue
env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}")
@@ -232,9 +230,18 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
finished[domain] = os.listdir(domain_path)
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
@@ -259,4 +266,5 @@ if __name__ == '__main__':
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
os.environ['OPENAI_API_KEY'] = "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"
test(args, test_all_meta)