122324154

This commit is contained in:
Timothyxxx
2024-02-02 14:36:53 +08:00
parent 32bcdd0937
commit 068c6f5769
7 changed files with 436 additions and 146 deletions

View File

@@ -15,12 +15,10 @@ import google.generativeai as genai
import openai
import requests
from PIL import Image
from openai.error import (
from openai import (
APIConnectionError,
APIError,
RateLimitError,
ServiceUnavailableError,
InvalidRequestError
RateLimitError
)
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
@@ -115,6 +113,7 @@ def parse_actions_from_string(input_string):
def parse_code_from_string(input_string):
input_string = input_string.replace(";", "\n")
if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
return [input_string.strip()]
@@ -475,14 +474,12 @@ class GPT4v_Agent:
with open("messages.json", "w") as f:
f.write(json.dumps(messages, indent=4))
try:
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
})
except:
response = ""
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
})
logger.debug("RESPONSE: %s", response)
@@ -527,7 +524,7 @@ class GPT4v_Agent:
@backoff.on_exception(
backoff.expo,
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
(APIError, RateLimitError, APIConnectionError),
max_tries=10
)
def call_llm(self, payload):
@@ -580,23 +577,34 @@ class GPT4v_Agent:
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
misrtal_messages.pop(0)
openai.api_base = "http://localhost:8000/v1"
openai.api_key = "test"
response = openai.ChatCompletion.create(
# openai.api_base = "http://localhost:8000/v1"
# openai.api_key = "test"
# response = openai.ChatCompletion.create(
# messages=misrtal_messages,
# model="Mixtral-8x7B-Instruct-v0.1"
# )
from openai import OpenAI
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
client = OpenAI(api_key=TOGETHER_API_KEY,
base_url='https://api.together.xyz',
)
response = client.chat.completions.create(
messages=misrtal_messages,
model="Mixtral-8x7B-Instruct-v0.1"
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
max_tokens=1024
)
try:
return response['choices'][0]['message']['content']
# return response['choices'][0]['message']['content']
return response.choices[0].message.content
except Exception as e:
print("Failed to call LLM: " + str(e))
return ""
elif self.model.startswith("gemini"):
api_key = os.environ.get("GENAI_API_KEY")
genai.api_key = api_key
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
@@ -609,8 +617,13 @@ class GPT4v_Agent:
gemini_messages = []
for i, message in enumerate(messages):
role_mapping = {
"assistant": "model",
"user": "user",
"system": "system"
}
gemini_message = {
"role": message["role"],
"role": role_mapping[message["role"]],
"parts": []
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
@@ -626,6 +639,15 @@ class GPT4v_Agent:
gemini_messages.append(gemini_message)
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
if gemini_messages[0]['role'] == "system":
gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
gemini_messages.pop(0)
print(gemini_messages)
api_key = os.environ.get("GENAI_API_KEY")
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
genai.configure(api_key=api_key)
response = genai.GenerativeModel(self.model).generate_content(
gemini_messages,
generation_config={