Update new func
This commit is contained in:
@@ -5,19 +5,20 @@ import os
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import openai
|
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from google.api_core.exceptions import InvalidArgument
|
|
||||||
import backoff
|
import backoff
|
||||||
import dashscope
|
import dashscope
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
|
import openai
|
||||||
import requests
|
import requests
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from google.api_core.exceptions import InvalidArgument
|
||||||
|
|
||||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
|
||||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
|
||||||
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
|
||||||
@@ -422,7 +423,6 @@ class PromptAgent:
|
|||||||
# with open("messages.json", "w") as f:
|
# with open("messages.json", "w") as f:
|
||||||
# f.write(json.dumps(messages, indent=4))
|
# f.write(json.dumps(messages, indent=4))
|
||||||
|
|
||||||
logger.info("Generating content with GPT model: %s", self.model)
|
|
||||||
response = self.call_llm({
|
response = self.call_llm({
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
@@ -461,7 +461,7 @@ class PromptAgent:
|
|||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||||
}
|
}
|
||||||
# logger.info("Generating content with GPT model: %s", self.model)
|
logger.info("Generating content with GPT model: %s", self.model)
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"https://api.openai.com/v1/chat/completions",
|
"https://api.openai.com/v1/chat/completions",
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -495,7 +495,7 @@ class PromptAgent:
|
|||||||
temperature = payload["temperature"]
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
claude_messages = []
|
claude_messages = []
|
||||||
|
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
claude_message = {
|
claude_message = {
|
||||||
"role": message["role"],
|
"role": message["role"],
|
||||||
@@ -503,17 +503,17 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||||
for part in message["content"]:
|
for part in message["content"]:
|
||||||
|
|
||||||
if part['type'] == "image_url":
|
if part['type'] == "image_url":
|
||||||
image_source = {}
|
image_source = {}
|
||||||
image_source["type"] = "base64"
|
image_source["type"] = "base64"
|
||||||
image_source["media_type"] = "image/png"
|
image_source["media_type"] = "image/png"
|
||||||
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
|
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
|
||||||
claude_message['content'].append({"type": "image", "source": image_source})
|
claude_message['content'].append({"type": "image", "source": image_source})
|
||||||
|
|
||||||
if part['type'] == "text":
|
if part['type'] == "text":
|
||||||
claude_message['content'].append({"type": "text", "text": part['text']})
|
claude_message['content'].append({"type": "text", "text": part['text']})
|
||||||
|
|
||||||
claude_messages.append(claude_message)
|
claude_messages.append(claude_message)
|
||||||
|
|
||||||
# the claude not support system message in our endpoint, so we concatenate it at the first user message
|
# the claude not support system message in our endpoint, so we concatenate it at the first user message
|
||||||
@@ -522,7 +522,6 @@ class PromptAgent:
|
|||||||
claude_messages[1]['content'].insert(0, claude_system_message_item)
|
claude_messages[1]['content'].insert(0, claude_system_message_item)
|
||||||
claude_messages.pop(0)
|
claude_messages.pop(0)
|
||||||
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
||||||
"anthropic-version": "2023-06-01",
|
"anthropic-version": "2023-06-01",
|
||||||
@@ -540,7 +539,7 @@ class PromptAgent:
|
|||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload
|
json=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
|
|
||||||
logger.error("Failed to call LLM: " + response.text)
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
@@ -674,6 +673,7 @@ class PromptAgent:
|
|||||||
try:
|
try:
|
||||||
return response.text
|
return response.text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error("Meet exception when calling Gemini API, " + str(e))
|
||||||
return ""
|
return ""
|
||||||
elif self.model.startswith("qwen"):
|
elif self.model.startswith("qwen"):
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
|
|||||||
42
run.py
42
run.py
@@ -6,6 +6,7 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -69,7 +70,7 @@ def config() -> argparse.Namespace:
|
|||||||
"screenshot_a11y_tree",
|
"screenshot_a11y_tree",
|
||||||
"som"
|
"som"
|
||||||
],
|
],
|
||||||
default="som",
|
default="a11y_tree",
|
||||||
help="Observation type",
|
help="Observation type",
|
||||||
)
|
)
|
||||||
parser.add_argument("--screen_width", type=int, default=1920)
|
parser.add_argument("--screen_width", type=int, default=1920)
|
||||||
@@ -82,7 +83,7 @@ def config() -> argparse.Namespace:
|
|||||||
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
|
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
|
||||||
|
|
||||||
# lm config
|
# lm config
|
||||||
parser.add_argument("--model", type=str, default="gpt-4-vision-preview")
|
parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
parser.add_argument("--max_tokens", type=int, default=1500)
|
||||||
@@ -147,7 +148,7 @@ def test(
|
|||||||
try:
|
try:
|
||||||
lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
|
lib_run_single.run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir,
|
||||||
scores)
|
scores)
|
||||||
except Exception as e:
|
except TimeoutError as e:
|
||||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
logger.error(f"Time limit exceeded in {domain}/{example_id}")
|
logger.error(f"Time limit exceeded in {domain}/{example_id}")
|
||||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
@@ -155,6 +156,14 @@ def test(
|
|||||||
"Error": f"Time limit exceeded in {domain}/{example_id}"
|
"Error": f"Time limit exceeded in {domain}/{example_id}"
|
||||||
}))
|
}))
|
||||||
f.write("\n")
|
f.write("\n")
|
||||||
|
except Exception as e:
|
||||||
|
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||||
|
logger.error(f"Exception in {domain}/{example_id}" + str(e))
|
||||||
|
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||||
|
f.write(json.dumps({
|
||||||
|
"Error": f"Exception in {domain}/{example_id}" + str(e)
|
||||||
|
}))
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||||
@@ -193,15 +202,13 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
|
|||||||
|
|
||||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
|
||||||
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
|
||||||
|
if not os.path.exists(target_dir):
|
||||||
|
print("New experiment, no result yet.")
|
||||||
|
return None
|
||||||
|
|
||||||
all_result = []
|
all_result = []
|
||||||
|
|
||||||
if not os.path.exists(target_dir):
|
|
||||||
return total_file_json
|
|
||||||
|
|
||||||
finished = {}
|
|
||||||
for domain in os.listdir(target_dir):
|
for domain in os.listdir(target_dir):
|
||||||
finished[domain] = []
|
|
||||||
domain_path = os.path.join(target_dir, domain)
|
domain_path = os.path.join(target_dir, domain)
|
||||||
if os.path.isdir(domain_path):
|
if os.path.isdir(domain_path):
|
||||||
for example_id in os.listdir(domain_path):
|
for example_id in os.listdir(domain_path):
|
||||||
@@ -209,10 +216,17 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
|
|||||||
if os.path.isdir(example_path):
|
if os.path.isdir(example_path):
|
||||||
if "result.txt" in os.listdir(example_path):
|
if "result.txt" in os.listdir(example_path):
|
||||||
# empty all files under example_id
|
# empty all files under example_id
|
||||||
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
|
try:
|
||||||
|
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
|
||||||
|
except:
|
||||||
|
all_result.append(0.0)
|
||||||
|
|
||||||
print("Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
if not all_result:
|
||||||
return all_result
|
print("New experiment, no result yet.")
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
|
||||||
|
return all_result
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -242,4 +256,8 @@ if __name__ == '__main__':
|
|||||||
test_all_meta
|
test_all_meta
|
||||||
)
|
)
|
||||||
|
|
||||||
# test(args, test_all_meta)
|
# make the order of key random in test_all_meta
|
||||||
|
for domain in test_all_meta:
|
||||||
|
random.shuffle(test_all_meta[domain])
|
||||||
|
|
||||||
|
test(args, test_all_meta)
|
||||||
|
|||||||
Reference in New Issue
Block a user