Filter out already-finished examples and add a timer to enforce an upper time limit on each example

This commit is contained in:
Jason Lee
2024-03-15 16:52:17 +08:00
parent f6b96165e2
commit 815c7ab67c
5 changed files with 166 additions and 59 deletions

19
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,19 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File with Arguments",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"args": [
"--path_to_vm", "/Users/lxc/Virtual Machines.localized/DesktopEnv-Ubuntu 64-bit Arm.vmwarevm/DesktopEnv-Ubuntu 64-bit Arm.vmx",
"--example_time_limit", "60"
]
}
]
}

16
demo.py Normal file
View File

@@ -0,0 +1,16 @@
import signal
import time


def handler(signo, frame):
    """SIGALRM handler: abort the current operation by raising RuntimeError."""
    raise RuntimeError("Timeout")


# Install the alarm handler once at module level.
signal.signal(signal.SIGALRM, handler)

# Guard the demo loop so importing this module does not hang.
if __name__ == "__main__":
    while True:
        try:
            signal.alarm(5)   # seconds until SIGALRM interrupts the work below
            time.sleep(10)    # simulated long task; always exceeds the 5s alarm
            print("Working...")
        except RuntimeError as e:
            # Only the handler's RuntimeError is expected here; keep looping.
            print(e)
            continue
        finally:
            signal.alarm(0)   # cancel any pending alarm before the next round

View File

@@ -0,0 +1,19 @@
import pandas as pd

file_path = "/Users/lxc/Downloads/Speedtest.csv"

# Read the CSV with its first row as the header so columns can be addressed
# by name. (The previous header=None labelled the columns 0, 1, ..., which
# made the 'TEST_DATE' lookup below raise a KeyError.) pd.read_csv accepts a
# path directly, so no manual open() is needed.
reader = pd.read_csv(file_path, sep=',')

# Print every value in the TEST_DATE column.
for data in reader['TEST_DATE']:
    print(data)

View File

@@ -5,10 +5,12 @@ import os
import re import re
import time import time
import uuid import uuid
import openai
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Dict, List from typing import Dict, List
from google.api_core.exceptions import InvalidArgument
import backoff import backoff
import dashscope import dashscope
@@ -513,7 +515,7 @@ class PromptAgent:
try: try:
actions = self.parse_actions(response, masks) actions = self.parse_actions(response, masks)
self.thoughts.append(response) self.thoughts.append(response)
except Exception as e: except ValueError as e:
print("Failed to parse action from response", e) print("Failed to parse action from response", e)
actions = None actions = None
self.thoughts.append("") self.thoughts.append("")
@@ -522,9 +524,16 @@ class PromptAgent:
@backoff.on_exception( @backoff.on_exception(
backoff.expo, backoff.expo,
(Exception), # here you should add more model exceptions as you want,
# but never add the generic `Exception` type itself: generic exceptions must
# propagate to the outer handler so each example stays within its time limit
(openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
InvalidArgument),
max_tries=5 max_tries=5
) )
def call_llm(self, payload): def call_llm(self, payload):
if self.model.startswith("gpt"): if self.model.startswith("gpt"):
@@ -532,7 +541,7 @@ class PromptAgent:
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
} }
logger.info("Generating content with GPT model: %s", self.model) # logger.info("Generating content with GPT model: %s", self.model)
response = requests.post( response = requests.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers=headers, headers=headers,

156
run.py
View File

@@ -7,6 +7,7 @@ import json
import logging import logging
import os import os
import sys import sys
import signal
from desktop_env.envs.desktop_env import DesktopEnv from desktop_env.envs.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent from mm_agents.agent import PromptAgent
@@ -45,6 +46,10 @@ logger.addHandler(sdebug_handler)
logger = logging.getLogger("desktopenv.experiment") logger = logging.getLogger("desktopenv.experiment")
# make sure each example won't exceed the time limit
def handler(signo, frame):
raise RuntimeError("Time limit exceeded!")
signal.signal(signal.SIGALRM, handler)
def config() -> argparse.Namespace: def config() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@@ -77,6 +82,7 @@ def config() -> argparse.Namespace:
# agent config # agent config
parser.add_argument("--max_trajectory_length", type=int, default=3) parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples") parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
parser.add_argument("--example_time_limit", type=int, default=600)
# lm config # lm config
parser.add_argument("--model", type=str, default="gpt-4-vision-preview") parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
@@ -98,6 +104,7 @@ def test(
) -> None: ) -> None:
scores = [] scores = []
max_steps = args.max_steps max_steps = args.max_steps
time_limit = args.example_time_limit
# log args # log args
logger.info("Args: %s", args) logger.info("Args: %s", args)
@@ -119,6 +126,7 @@ def test(
for domain in test_all_meta: for domain in test_all_meta:
for example_id in test_all_meta[domain]: for example_id in test_all_meta[domain]:
# example setting
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
with open(config_file, "r", encoding="utf-8") as f: with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f) example = json.load(f)
@@ -140,79 +148,115 @@ def test(
) )
os.makedirs(example_result_dir, exist_ok=True) os.makedirs(example_result_dir, exist_ok=True)
agent.reset() # example start running
obs = env.reset(task_config=example) try:
done = False signal.alarm(time_limit)
step_idx = 0 agent.reset()
env.controller.start_recording() obs = env.reset(task_config=example)
done = False
step_idx = 0
env.controller.start_recording()
while not done and step_idx < max_steps: while not done and step_idx < max_steps:
actions = agent.predict( actions = agent.predict(
instruction, instruction,
obs obs
) )
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
for action in actions: observation, reward, done, info = env.step(action, args.sleep_after_execution)
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
logger.info("Info: %s", info)
# Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
step_idx += 1 step_idx += 1
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") result = env.evaluate()
logger.info("Step %d: %s", step_idx + 1, action) logger.info("Result: %.2f", result)
scores.append(result)
observation, reward, done, info = env.step(action, args.sleep_after_execution) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
except RuntimeError as e:
logger.info("Reward: %.2f", reward) logger.error(f"Error in example {domain}/{example_id}: {e}")
logger.info("Done: %s", done) # save info of this example and then continue
logger.info("Info: %s", info) try:
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
# Save screenshot and trajectory information with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f:
with open(observation['screenshot'], "rb") as __f:
screenshot = __f.read()
_f.write(screenshot)
with open(os.path.join(example_result_dir, "traj.json"), "a") as f:
f.write(json.dumps({ f.write(json.dumps({
"step_num": step_idx + 1, "Error": f"Error in example {domain}/{example_id}: {e}",
"action_timestamp": action_timestamp, "step": step_idx + 1,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
})) }))
f.write("\n") f.write("\n")
except Exception as new_e:
if done: with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
logger.info("The episode is done.") f.write(json.dumps({
break "Error": f"Error in example {domain}/{example_id}: {e} and {new_e}",
"step": "before start recording",
result = env.evaluate() }))
logger.info("Result: %.2f", result) f.write("\n")
scores.append(result) continue
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
env.close() env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}") logger.info(f"Average score: {sum(scores) / len(scores)}")
def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
    """Drop examples that already have a result directory from the task meta.

    Looks under ``result_dir/action_space/observation_type/use_model`` for
    per-domain directories of finished example ids and removes those ids from
    ``total_file_json`` (a mapping of domain -> list of example ids).

    NOTE: mutates ``total_file_json`` in place and also returns it, so callers
    may use either the argument or the return value.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    # No results recorded yet: everything is still unfinished.
    if not os.path.exists(target_dir):
        return total_file_json

    # Collect finished example ids per domain; use sets for O(1) membership.
    finished = {}
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            finished[domain] = set(os.listdir(domain_path))

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]
    return total_file_json
if __name__ == '__main__': if __name__ == '__main__':
####### The complete version of the list of examples ####### ####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config() args = config()
# test_file_list = get_unfinished(args.test, args.result_dir)
# logger.info(f"Total {len(test_file_list)} tasks left")
with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f:
test_all_meta = json.load(f) test_all_meta = json.load(f)
test(args, test_all_meta) test_file_list = get_unfinished(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
test(args, test_all_meta)