Update new func

This commit is contained in:
Timothyxxx
2024-03-17 22:25:13 +08:00
parent 7feeab8f6b
commit e156a20e3d
2 changed files with 41 additions and 23 deletions

View File

@@ -5,19 +5,20 @@ import os
import re import re
import time import time
import uuid import uuid
import openai
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Dict, List from typing import Dict, List
from google.api_core.exceptions import InvalidArgument
import backoff import backoff
import dashscope import dashscope
import google.generativeai as genai import google.generativeai as genai
import openai
import requests import requests
from PIL import Image from PIL import Image
from google.api_core.exceptions import InvalidArgument
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
@@ -422,7 +423,6 @@ class PromptAgent:
# with open("messages.json", "w") as f: # with open("messages.json", "w") as f:
# f.write(json.dumps(messages, indent=4)) # f.write(json.dumps(messages, indent=4))
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({ response = self.call_llm({
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
@@ -461,7 +461,7 @@ class PromptAgent:
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
} }
# logger.info("Generating content with GPT model: %s", self.model) logger.info("Generating content with GPT model: %s", self.model)
response = requests.post( response = requests.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers=headers, headers=headers,
@@ -495,7 +495,7 @@ class PromptAgent:
temperature = payload["temperature"] temperature = payload["temperature"]
claude_messages = [] claude_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
claude_message = { claude_message = {
"role": message["role"], "role": message["role"],
@@ -503,17 +503,17 @@ class PromptAgent:
} }
assert len(message["content"]) in [1, 2], "One text, or one text with one image" assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]: for part in message["content"]:
if part['type'] == "image_url": if part['type'] == "image_url":
image_source = {} image_source = {}
image_source["type"] = "base64" image_source["type"] = "base64"
image_source["media_type"] = "image/png" image_source["media_type"] = "image/png"
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "") image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
claude_message['content'].append({"type": "image", "source": image_source}) claude_message['content'].append({"type": "image", "source": image_source})
if part['type'] == "text": if part['type'] == "text":
claude_message['content'].append({"type": "text", "text": part['text']}) claude_message['content'].append({"type": "text", "text": part['text']})
claude_messages.append(claude_message) claude_messages.append(claude_message)
# the claude not support system message in our endpoint, so we concatenate it at the first user message # the claude not support system message in our endpoint, so we concatenate it at the first user message
@@ -522,7 +522,6 @@ class PromptAgent:
claude_messages[1]['content'].insert(0, claude_system_message_item) claude_messages[1]['content'].insert(0, claude_system_message_item)
claude_messages.pop(0) claude_messages.pop(0)
headers = { headers = {
"x-api-key": os.environ["ANTHROPIC_API_KEY"], "x-api-key": os.environ["ANTHROPIC_API_KEY"],
"anthropic-version": "2023-06-01", "anthropic-version": "2023-06-01",
@@ -540,7 +539,7 @@ class PromptAgent:
headers=headers, headers=headers,
json=payload json=payload
) )
if response.status_code != 200: if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text) logger.error("Failed to call LLM: " + response.text)
@@ -674,6 +673,7 @@ class PromptAgent:
try: try:
return response.text return response.text
except Exception as e: except Exception as e:
logger.error("Meet exception when calling Gemini API, " + str(e))
return "" return ""
elif self.model.startswith("qwen"): elif self.model.startswith("qwen"):
messages = payload["messages"] messages = payload["messages"]

42
run.py
View File

@@ -6,6 +6,7 @@ import datetime
import json import json
import logging import logging
import os import os
import random
import sys import sys
from tqdm import tqdm from tqdm import tqdm
@@ -69,7 +70,7 @@ def config() -> argparse.Namespace:
"screenshot_a11y_tree", "screenshot_a11y_tree",
"som" "som"
], ],
default="som", default="a11y_tree",
help="Observation type", help="Observation type",
) )
parser.add_argument("--screen_width", type=int, default=1920) parser.add_argument("--screen_width", type=int, default=1920)
@@ -82,7 +83,7 @@ def config() -> argparse.Namespace:
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples") parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
# lm config # lm config
parser.add_argument("--model", type=str, default="gpt-4-vision-preview") parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--max_tokens", type=int, default=1500)
@@ -147,7 +148,7 @@ def test(
try: try:
lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
scores) scores)
except Exception as e: except TimeoutError as e:
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
logger.error(f"Time limit exceeded in {domain}/{example_id}") logger.error(f"Time limit exceeded in {domain}/{example_id}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
@@ -155,6 +156,14 @@ def test(
"Error": f"Time limit exceeded in {domain}/{example_id}" "Error": f"Time limit exceeded in {domain}/{example_id}"
})) }))
f.write("\n") f.write("\n")
except Exception as e:
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
logger.error(f"Exception in {domain}/{example_id}" + str(e))
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"Error": f"Exception in {domain}/{example_id}" + str(e)
}))
f.write("\n")
env.close() env.close()
logger.info(f"Average score: {sum(scores) / len(scores)}") logger.info(f"Average score: {sum(scores) / len(scores)}")
@@ -193,15 +202,13 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
def get_result(action_space, use_model, observation_type, result_dir, total_file_json): def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model) target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = [] all_result = []
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir): for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain) domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path): if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path): for example_id in os.listdir(domain_path):
@@ -209,10 +216,17 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
if os.path.isdir(example_path): if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path): if "result.txt" in os.listdir(example_path):
# empty all files under example_id # empty all files under example_id
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read())) try:
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
except:
all_result.append(0.0)
print("Success Rate:", sum(all_result) / len(all_result) * 100, "%") if not all_result:
return all_result print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == '__main__': if __name__ == '__main__':
@@ -242,4 +256,8 @@ if __name__ == '__main__':
test_all_meta test_all_meta
) )
# test(args, test_all_meta) # make the order of key random in test_all_meta
for domain in test_all_meta:
random.shuffle(test_all_meta[domain])
test(args, test_all_meta)