122324154
This commit is contained in:
@@ -15,12 +15,10 @@ import google.generativeai as genai
|
||||
import openai
|
||||
import requests
|
||||
from PIL import Image
|
||||
from openai.error import (
|
||||
from openai import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
InvalidRequestError
|
||||
RateLimitError
|
||||
)
|
||||
|
||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
|
||||
@@ -115,6 +113,7 @@ def parse_actions_from_string(input_string):
|
||||
|
||||
|
||||
def parse_code_from_string(input_string):
|
||||
input_string = input_string.replace(";", "\n")
|
||||
if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
|
||||
return [input_string.strip()]
|
||||
|
||||
@@ -475,14 +474,12 @@ class GPT4v_Agent:
|
||||
with open("messages.json", "w") as f:
|
||||
f.write(json.dumps(messages, indent=4))
|
||||
|
||||
try:
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens
|
||||
})
|
||||
except:
|
||||
response = ""
|
||||
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens
|
||||
})
|
||||
|
||||
logger.debug("RESPONSE: %s", response)
|
||||
|
||||
@@ -527,7 +524,7 @@ class GPT4v_Agent:
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
|
||||
(APIError, RateLimitError, APIConnectionError),
|
||||
max_tries=10
|
||||
)
|
||||
def call_llm(self, payload):
|
||||
@@ -580,23 +577,34 @@ class GPT4v_Agent:
|
||||
misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content']
|
||||
misrtal_messages.pop(0)
|
||||
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
openai.api_key = "test"
|
||||
response = openai.ChatCompletion.create(
|
||||
# openai.api_base = "http://localhost:8000/v1"
|
||||
# openai.api_key = "test"
|
||||
# response = openai.ChatCompletion.create(
|
||||
# messages=misrtal_messages,
|
||||
# model="Mixtral-8x7B-Instruct-v0.1"
|
||||
# )
|
||||
|
||||
from openai import OpenAI
|
||||
TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2"
|
||||
|
||||
client = OpenAI(api_key=TOGETHER_API_KEY,
|
||||
base_url='https://api.together.xyz',
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
messages=misrtal_messages,
|
||||
model="Mixtral-8x7B-Instruct-v0.1"
|
||||
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
max_tokens=1024
|
||||
)
|
||||
|
||||
try:
|
||||
return response['choices'][0]['message']['content']
|
||||
# return response['choices'][0]['message']['content']
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
print("Failed to call LLM: " + str(e))
|
||||
return ""
|
||||
|
||||
elif self.model.startswith("gemini"):
|
||||
|
||||
api_key = os.environ.get("GENAI_API_KEY")
|
||||
genai.api_key = api_key
|
||||
def encoded_img_to_pil_img(data_str):
|
||||
base64_str = data_str.replace("data:image/png;base64,", "")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
@@ -609,8 +617,13 @@ class GPT4v_Agent:
|
||||
|
||||
gemini_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
role_mapping = {
|
||||
"assistant": "model",
|
||||
"user": "user",
|
||||
"system": "system"
|
||||
}
|
||||
gemini_message = {
|
||||
"role": message["role"],
|
||||
"role": role_mapping[message["role"]],
|
||||
"parts": []
|
||||
}
|
||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||
@@ -626,6 +639,15 @@ class GPT4v_Agent:
|
||||
|
||||
gemini_messages.append(gemini_message)
|
||||
|
||||
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
||||
if gemini_messages[0]['role'] == "system":
|
||||
gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
|
||||
gemini_messages.pop(0)
|
||||
|
||||
print(gemini_messages)
|
||||
api_key = os.environ.get("GENAI_API_KEY")
|
||||
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
||||
genai.configure(api_key=api_key)
|
||||
response = genai.GenerativeModel(self.model).generate_content(
|
||||
gemini_messages,
|
||||
generation_config={
|
||||
|
||||
Reference in New Issue
Block a user