add wandb settings

This commit is contained in:
Jason Lee
2024-03-17 22:31:43 +08:00
2 changed files with 122 additions and 70 deletions

View File

@@ -5,20 +5,21 @@ import os
import re import re
import time import time
import uuid import uuid
import openai
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from typing import Dict, List from typing import Dict, List
from google.api_core.exceptions import InvalidArgument
import backoff import backoff
import dashscope import dashscope
import google.generativeai as genai import google.generativeai as genai
import openai
import requests import requests
import wandb import wandb
from PIL import Image from PIL import Image
from google.api_core.exceptions import InvalidArgument
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
@@ -423,7 +424,6 @@ class PromptAgent:
# with open("messages.json", "w") as f: # with open("messages.json", "w") as f:
# f.write(json.dumps(messages, indent=4)) # f.write(json.dumps(messages, indent=4))
logger.info("Generating content with GPT model: %s", self.model)
response = self.call_llm({ response = self.call_llm({
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
@@ -462,7 +462,7 @@ class PromptAgent:
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
} }
# logger.info("Generating content with GPT model: %s", self.model) logger.info("Generating content with GPT model: %s", self.model)
response = requests.post( response = requests.post(
"https://api.openai.com/v1/chat/completions", "https://api.openai.com/v1/chat/completions",
headers=headers, headers=headers,
@@ -496,7 +496,7 @@ class PromptAgent:
temperature = payload["temperature"] temperature = payload["temperature"]
claude_messages = [] claude_messages = []
for i, message in enumerate(messages): for i, message in enumerate(messages):
claude_message = { claude_message = {
"role": message["role"], "role": message["role"],
@@ -504,17 +504,17 @@ class PromptAgent:
} }
assert len(message["content"]) in [1, 2], "One text, or one text with one image" assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]: for part in message["content"]:
if part['type'] == "image_url": if part['type'] == "image_url":
image_source = {} image_source = {}
image_source["type"] = "base64" image_source["type"] = "base64"
image_source["media_type"] = "image/png" image_source["media_type"] = "image/png"
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "") image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
claude_message['content'].append({"type": "image", "source": image_source}) claude_message['content'].append({"type": "image", "source": image_source})
if part['type'] == "text": if part['type'] == "text":
claude_message['content'].append({"type": "text", "text": part['text']}) claude_message['content'].append({"type": "text", "text": part['text']})
claude_messages.append(claude_message) claude_messages.append(claude_message)
# the claude not support system message in our endpoint, so we concatenate it at the first user message # the claude not support system message in our endpoint, so we concatenate it at the first user message
@@ -523,7 +523,6 @@ class PromptAgent:
claude_messages[1]['content'].insert(0, claude_system_message_item) claude_messages[1]['content'].insert(0, claude_system_message_item)
claude_messages.pop(0) claude_messages.pop(0)
headers = { headers = {
"x-api-key": os.environ["ANTHROPIC_API_KEY"], "x-api-key": os.environ["ANTHROPIC_API_KEY"],
"anthropic-version": "2023-06-01", "anthropic-version": "2023-06-01",
@@ -541,7 +540,7 @@ class PromptAgent:
headers=headers, headers=headers,
json=payload json=payload
) )
if response.status_code != 200: if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text) logger.error("Failed to call LLM: " + response.text)
@@ -551,55 +550,101 @@ class PromptAgent:
return response.json()['content'][0]['text'] return response.json()['content'][0]['text']
# elif self.model.startswith("mistral"): elif self.model.startswith("mistral"):
# print("Call mistral") print("Call mistral")
# messages = payload["messages"] messages = payload["messages"]
# max_tokens = payload["max_tokens"] max_tokens = payload["max_tokens"]
# top_p = payload["top_p"]
# misrtal_messages = [] temperature = payload["temperature"]
#
# for i, message in enumerate(messages): misrtal_messages = []
# mistral_message = {
# "role": message["role"], for i, message in enumerate(messages):
# "content": [] mistral_message = {
# } "role": message["role"],
# "content": ""
# for part in message["content"]: }
# mistral_message['content'] = part['text'] if part['type'] == "text" else None
# for part in message["content"]:
# misrtal_messages.append(mistral_message) mistral_message['content'] = part['text'] if part['type'] == "text" else ""
#
# # the mistral not support system message in our endpoint, so we concatenate it at the first user message
# if misrtal_messages[0]['role'] == "system": misrtal_messages.append(mistral_message)
# misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
# misrtal_messages.pop(0)
# # openai.api_base = "http://localhost:8000/v1"
# # openai.api_base = "http://localhost:8000/v1" # response = openai.ChatCompletion.create(
# # openai.api_key = "test" # messages=misrtal_messages,
# # response = openai.ChatCompletion.create( # model="Mixtral-8x7B-Instruct-v0.1"
# # messages=misrtal_messages, # )
# # model="Mixtral-8x7B-Instruct-v0.1"
# # ) from openai import OpenAI
#
# from openai import OpenAI client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
# TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2" base_url='https://api.together.xyz',
# )
# client = OpenAI(api_key=TOGETHER_API_KEY, logger.info("Generating content with Mistral model: %s", self.model)
# base_url='https://api.together.xyz',
# ) response = client.chat.completions.create(
# logger.info("Generating content with Mistral model: %s", self.model) messages=misrtal_messages,
# response = client.chat.completions.create( model=self.model,
# messages=misrtal_messages, max_tokens=max_tokens
# model="mistralai/Mixtral-8x7B-Instruct-v0.1", )
# max_tokens=1024
# ) try:
# return response.choices[0].message.content
# try: except Exception as e:
# # return response['choices'][0]['message']['content'] print("Failed to call LLM: " + str(e))
# return response.choices[0].message.content return ""
# except Exception as e:
# print("Failed to call LLM: " + str(e)) elif self.model.startswith("THUDM"):
# return "" # THUDM/cogagent-chat-hf
print("Call CogAgent")
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
cog_messages = []
for i, message in enumerate(messages):
cog_message = {
"role": message["role"],
"content": []
}
for part in message["content"]:
if part['type'] == "image_url":
cog_message['content'].append({"type": "image_url", "image_url": {"url": part['image_url']['url'] } })
if part['type'] == "text":
cog_message['content'].append({"type": "text", "text": part['text']})
cog_messages.append(cog_message)
# the cogagent not support system message in our endpoint, so we concatenate it at the first user message
if cog_messages[0]['role'] == "system":
cog_system_message_item = cog_messages[0]['content'][0]
cog_messages[1]['content'].insert(0, cog_system_message_item)
cog_messages.pop(0)
payload = {
"model": self.model,
"max_tokens": max_tokens,
"messages": cog_messages
}
base_url = "http://127.0.0.1:8000"
response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=False)
if response.status_code == 200:
decoded_line = response.json()
content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
return content
else:
print("Failed to call LLM: ", response.status_code)
return ""
elif self.model.startswith("gemini"): elif self.model.startswith("gemini"):
def encoded_img_to_pil_img(data_str): def encoded_img_to_pil_img(data_str):
@@ -675,6 +720,7 @@ class PromptAgent:
try: try:
return response.text return response.text
except Exception as e: except Exception as e:
logger.error("Meet exception when calling Gemini API, " + str(e))
return "" return ""
elif self.model.startswith("qwen"): elif self.model.startswith("qwen"):
messages = payload["messages"] messages = payload["messages"]

26
run.py
View File

@@ -6,6 +6,7 @@ import datetime
import json import json
import logging import logging
import os import os
import random
import sys import sys
import wandb import wandb
@@ -75,7 +76,7 @@ def config() -> argparse.Namespace:
"screenshot_a11y_tree", "screenshot_a11y_tree",
"som" "som"
], ],
default="som", default="a11y_tree",
help="Observation type", help="Observation type",
) )
parser.add_argument("--screen_width", type=int, default=1920) parser.add_argument("--screen_width", type=int, default=1920)
@@ -88,7 +89,7 @@ def config() -> argparse.Namespace:
parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples") parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
# lm config # lm config
parser.add_argument("--model", type=str, default="gpt-4-vision-preview") parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--max_tokens", type=int, default=1500)
@@ -231,15 +232,13 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_
def get_result(action_space, use_model, observation_type, result_dir, total_file_json): def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model) target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = [] all_result = []
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir): for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain) domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path): if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path): for example_id in os.listdir(domain_path):
@@ -247,10 +246,17 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
if os.path.isdir(example_path): if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path): if "result.txt" in os.listdir(example_path):
# empty all files under example_id # empty all files under example_id
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read())) try:
all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
except:
all_result.append(0.0)
print("Success Rate:", sum(all_result) / len(all_result) * 100, "%") if not all_result:
return all_result print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == '__main__': if __name__ == '__main__':