Add Llama3-70B Support (from Groq)

Timothyxxx
2024-05-09 02:04:02 +08:00
parent 97b567a287
commit 54905380e6
2 changed files with 75 additions and 8 deletions


@@ -14,6 +14,8 @@ import backoff
import dashscope
import google.generativeai as genai
import openai
from groq import Groq
import requests
import tiktoken
from PIL import Image
@@ -27,6 +29,8 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
logger = logging.getLogger("desktopenv.agent")
pure_text_settings = ['a11y_tree']
# Function to encode the image
def encode_image(image_content):
@@ -131,7 +135,7 @@ def parse_actions_from_string(input_string):
def parse_code_from_string(input_string):
input_string = input_string.replace(";", "\n")
input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
return [input_string.strip()]
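The changed line above normalizes a semicolon-separated action string into one action per line before any further parsing, while the special tokens WAIT / DONE / FAIL are returned unchanged as a one-element list. A minimal sketch of the normalization step in isolation (hypothetical variable names, not part of the commit):

    s = "import pyautogui; pyautogui.click(100, 200)"
    normalized = "\n".join(line.strip() for line in s.split(';') if line.strip())
    # normalized == "import pyautogui\npyautogui.click(100, 200)"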
@@ -510,7 +514,7 @@ class PromptAgent:
return response, actions
@backoff.on_exception(
backoff.expo,
backoff.constant,
# Add more model-specific exceptions here as needed,
# but do not add the generic "Exception" type;
# we let generic exceptions propagate so the outer loop can enforce each example's time limit
@@ -525,8 +529,12 @@ class PromptAgent:
ResourceExhausted,
InternalServerError,
BadRequest,
# Groq exceptions
# TODO: add the appropriate Groq exception types here
),
max_tries=5
interval=30,
max_tries=10
)
def call_llm(self, payload):
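The decorator now retries at a constant 30-second interval for up to 10 tries instead of backing off exponentially, but the Groq exception types remain a TODO above. A minimal sketch of how they might be filled in, assuming the Groq SDK exposes OpenAI-style error classes (an unverified assumption, not part of the commit):

    import backoff
    import groq

    @backoff.on_exception(
        backoff.constant,
        (
            groq.RateLimitError,      # assumed name: raised on 429 responses
            groq.APIConnectionError,  # assumed name: raised on transport failures
        ),
        interval=30,
        max_tries=10,
    )
    def call_groq(client, **kwargs):
        return client.chat.completions.create(**kwargs)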
@@ -632,6 +640,8 @@ class PromptAgent:
top_p = payload["top_p"]
temperature = payload["temperature"]
assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input; please consider changing the base model or settings"
mistral_messages = []
for i, message in enumerate(messages):
@@ -650,12 +660,13 @@ class PromptAgent:
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
base_url='https://api.together.xyz',
)
logger.info("Generating content with Mistral model: %s", self.model)
flag = 0
while True:
try:
if flag > 20: break
if flag > 20:
break
logger.info("Generating content with model: %s", self.model)
response = client.chat.completions.create(
messages=mistral_messages,
model=self.model,
@@ -733,6 +744,9 @@ class PromptAgent:
top_p = payload["top_p"]
temperature = payload["temperature"]
if self.model == "gemini-pro":
assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input; please consider changing the base model or settings"
gemini_messages = []
for i, message in enumerate(messages):
role_mapping = {
@@ -782,7 +796,7 @@ class PromptAgent:
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
# "max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
@@ -796,7 +810,6 @@ class PromptAgent:
)
return response.text
elif self.model == "gemini-1.5-pro-latest":
messages = payload["messages"]
max_tokens = payload["max_tokens"]
@@ -858,7 +871,7 @@ class PromptAgent:
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
# "max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
@@ -873,6 +886,59 @@ class PromptAgent:
return response.text
elif self.model == "llama3-70b":
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input; please consider changing the base model or settings"
groq_messages = []
for i, message in enumerate(messages):
groq_message = {
"role": message["role"],
"content": ""
}
for part in message["content"]:
groq_message['content'] += part['text'] if part['type'] == "text" else ""
groq_messages.append(groq_message)
# The implementation is based on the Groq API
client = Groq(
api_key=os.environ.get("GROQ_API_KEY"),
)
flag = 0
while True:
try:
if flag > 20:
break
logger.info("Generating content with model: %s", self.model)
response = client.chat.completions.create(
messages=groq_messages,
model="llama3-70b-8192",
max_tokens=max_tokens,
top_p=top_p,
temperature=temperature
)
break
except Exception:
if flag == 0:
groq_messages = [groq_messages[0]] + groq_messages[-1:]
else:
groq_messages[-1]["content"] = ' '.join(groq_messages[-1]["content"].split()[:-500])
flag += 1
try:
return response.choices[0].message.content
except Exception as e:
print("Failed to call LLM: " + str(e))
return ""
elif self.model.startswith("qwen"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]
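The retry loop in the llama3-70b branch above shrinks the prompt on each failure: the first failure keeps only the system message and the most recent message, and every later failure trims 500 words from the tail of the last message. A minimal standalone sketch of that strategy (hypothetical helper, not part of the commit):

    def shrink_messages(messages, attempt):
        # attempt 0: drop everything between the system message and the latest message
        if attempt == 0:
            return [messages[0]] + messages[-1:]
        # later attempts: cut 500 words off the end of the last message
        trimmed = dict(messages[-1])
        trimmed["content"] = " ".join(trimmed["content"].split()[:-500])
        return messages[:-1] + [trimmed]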


@@ -52,3 +52,4 @@ wandb
wrapt_timeout_decorator
gdown
tiktoken
groq