Add Llama3-70B Support (from Groq)

Timothyxxx
2024-05-09 02:04:02 +08:00
parent 97b567a287
commit 54905380e6
2 changed files with 75 additions and 8 deletions

View File

@@ -14,6 +14,8 @@ import backoff
 import dashscope
 import google.generativeai as genai
 import openai
+from groq import Groq
 import requests
 import tiktoken
 from PIL import Image
@@ -27,6 +29,8 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
 logger = logging.getLogger("desktopenv.agent")
+
+pure_text_settings = ['a11y_tree']
 
 # Function to encode the image
 def encode_image(image_content):
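The new pure_text_settings list enumerates observation types that contain no screenshots; the call_llm branches changed below use it to guard backends that only accept text. A minimal self-contained sketch of that guard (assert_text_only is an illustrative helper, not in the repository):

    pure_text_settings = ['a11y_tree']

    def assert_text_only(model: str, observation_type: str) -> None:
        # Illustrative helper mirroring the asserts added below in call_llm:
        # text-only backends must not receive screenshot observations.
        assert observation_type in pure_text_settings, (
            f"The model {model} can only support text-based input, "
            f"please consider change based model or settings"
        )

    assert_text_only("llama3-70b", "a11y_tree")  # passes; "screenshot" would raise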
@@ -131,7 +135,7 @@ def parse_actions_from_string(input_string):
 def parse_code_from_string(input_string):
-    input_string = input_string.replace(";", "\n")
+    input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
     if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
         return [input_string.strip()]
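The new splitting logic strips whitespace around each semicolon-separated statement and drops empty segments, which the plain replace() did not. A small standalone illustration (the example string is made up):

    s = "pyautogui.click(10, 20);  pyautogui.press('enter') ; "

    # Old behaviour: padded and empty segments survive the substitution.
    old = s.replace(";", "\n")
    # -> "pyautogui.click(10, 20)\n  pyautogui.press('enter') \n "

    # New behaviour: each statement is stripped and blank segments are dropped.
    new = "\n".join([line.strip() for line in s.split(';') if line.strip()])
    # -> "pyautogui.click(10, 20)\npyautogui.press('enter')"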
@@ -510,7 +514,7 @@ class PromptAgent:
         return response, actions
 
     @backoff.on_exception(
-        backoff.expo,
+        backoff.constant,
         # here you should add more model exceptions as you want,
         # but you are forbidden to add "Exception", that is, a common type of exception
         # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
@@ -525,8 +529,12 @@ class PromptAgent:
ResourceExhausted, ResourceExhausted,
InternalServerError, InternalServerError,
BadRequest, BadRequest,
# Groq exceptions
# todo: check
), ),
max_tries=5 interval=30,
max_tries=10
) )
def call_llm(self, payload): def call_llm(self, payload):
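The retry policy changes from exponential to fixed-interval backoff: up to 10 attempts spaced 30 seconds apart, with Groq-specific exception types still marked as a todo. A minimal standalone sketch of the same backoff API, with a made-up flaky_call target:

    import backoff
    import requests

    @backoff.on_exception(
        backoff.constant,           # constant wait between retries (was backoff.expo)
        requests.RequestException,  # retry only on the listed exception types
        interval=30,                # seconds between attempts
        max_tries=10,               # give up after ten attempts in total
    )
    def flaky_call():
        # Illustrative target: any callable that may raise the listed exception.
        return requests.get("https://example.com", timeout=10)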
@@ -632,6 +640,8 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]
 
+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             mistral_messages = []
             for i, message in enumerate(messages):
@@ -650,12 +660,13 @@ class PromptAgent:
             client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                             base_url='https://api.together.xyz',
                             )
-            logger.info("Generating content with Mistral model: %s", self.model)
             flag = 0
             while True:
                 try:
-                    if flag > 20: break
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
                     response = client.chat.completions.create(
                         messages=mistral_messages,
                         model=self.model,
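Both this Together-hosted branch and the new Groq branch further down share the same bounded retry loop: log, call the API, give up after roughly 20 failures, and shrink the conversation after each failure. A compressed, runnable sketch of the pattern, where call_model is a hypothetical stand-in for the provider's chat-completion call:

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("desktopenv.agent")

    def call_model(msgs):
        # Hypothetical stand-in for client.chat.completions.create(...).
        return {"content": "DONE"}

    messages = [
        {"role": "system", "content": "system prompt"},
        {"role": "user", "content": "latest observation"},
    ]

    flag = 0
    response = None
    while True:
        try:
            if flag > 20:
                break
            logger.info("Generating content with model: %s", "llama3-70b")
            response = call_model(messages)
            break
        except Exception:
            if flag == 0:
                # First failure: keep only the system prompt and the latest turn.
                messages = [messages[0]] + messages[-1:]
            else:
                # Later failures: drop the last 500 words of the newest message.
                messages[-1]["content"] = ' '.join(messages[-1]["content"].split()[:-500])
            flag = flag + 1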
@@ -733,6 +744,9 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]
 
+            if self.model == "gemini-pro":
+                assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             gemini_messages = []
             for i, message in enumerate(messages):
                 role_mapping = {
@@ -782,7 +796,7 @@ class PromptAgent:
                     gemini_messages,
                     generation_config={
                         "candidate_count": 1,
-                        "max_output_tokens": max_tokens,
+                        # "max_output_tokens": max_tokens,
                         "top_p": top_p,
                         "temperature": temperature
                     },
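With max_output_tokens commented out, Gemini's completion length is left to the API default while the sampling parameters stay pinned. A minimal sketch of the resulting call, assuming the google.generativeai client (the API-key variable name and prompt are illustrative):

    import os
    import google.generativeai as genai

    genai.configure(api_key=os.environ["GENAI_API_KEY"])  # env var name illustrative
    model = genai.GenerativeModel("gemini-pro")

    response = model.generate_content(
        "Describe the next GUI action.",
        generation_config={
            "candidate_count": 1,
            # "max_output_tokens": 1500,  # no explicit cap after this change
            "top_p": 0.9,
            "temperature": 0.5,
        },
    )
    print(response.text)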
@@ -796,7 +810,6 @@ class PromptAgent:
                 )
             return response.text
-
         elif self.model == "gemini-1.5-pro-latest":
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
@@ -858,7 +871,7 @@ class PromptAgent:
                     gemini_messages,
                     generation_config={
                         "candidate_count": 1,
-                        "max_output_tokens": max_tokens,
+                        # "max_output_tokens": max_tokens,
                         "top_p": top_p,
                         "temperature": temperature
                     },
@@ -873,6 +886,59 @@ class PromptAgent:
             return response.text
 
+        elif self.model == "llama3-70b":
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
+            groq_messages = []
+
+            for i, message in enumerate(messages):
+                groq_message = {
+                    "role": message["role"],
+                    "content": ""
+                }
+
+                for part in message["content"]:
+                    groq_message['content'] = part['text'] if part['type'] == "text" else ""
+
+                groq_messages.append(groq_message)
+
+            # The implementation based on Groq API
+            client = Groq(
+                api_key=os.environ.get("GROQ_API_KEY"),
+            )
+
+            flag = 0
+            while True:
+                try:
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
+                    response = client.chat.completions.create(
+                        messages=groq_messages,
+                        model="llama3-70b-8192",
+                        max_tokens=max_tokens,
+                        top_p=top_p,
+                        temperature=temperature
+                    )
+                    break
+                except:
+                    if flag == 0:
+                        groq_messages = [groq_messages[0]] + groq_messages[-1:]
+                    else:
+                        groq_messages[-1]["content"] = ' '.join(groq_messages[-1]["content"].split()[:-500])
+                    flag = flag + 1
+
+            try:
+                return response.choices[0].message.content
+            except Exception as e:
+                print("Failed to call LLM: " + str(e))
+                return ""
+
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
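For reference, a minimal standalone sketch of the same Groq chat-completions call outside the agent (the prompt text and sampling values are made up; GROQ_API_KEY must be set in the environment):

    import os
    from groq import Groq

    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    response = client.chat.completions.create(
        model="llama3-70b-8192",  # Groq-hosted Llama 3 70B, as used in the new branch
        messages=[
            {"role": "system", "content": "You are a desktop automation agent."},
            {"role": "user", "content": "Reply with DONE."},
        ],
        max_tokens=256,
        top_p=0.9,
        temperature=0.5,
    )
    print(response.choices[0].message.content)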

View File

@@ -52,3 +52,4 @@ wandb
 wrapt_timeout_decorator
 gdown
 tiktoken
+groq