Add Llama3-70B Support (from Groq)
@@ -14,6 +14,8 @@ import backoff
 import dashscope
 import google.generativeai as genai
 import openai
+from groq import Groq
+
 import requests
 import tiktoken
 from PIL import Image
@@ -27,6 +29,8 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S

 logger = logging.getLogger("desktopenv.agent")

+pure_text_settings = ['a11y_tree']
+

 # Function to encode the image
 def encode_image(image_content):
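The new pure_text_settings constant is what the text-only branches below check against: before building a prompt for a text-only model, the agent asserts that its observation type is one of these settings. A minimal sketch of that guard on its own (the values are illustrative):

    pure_text_settings = ['a11y_tree']
    observation_type = 'a11y_tree'   # e.g. how the PromptAgent instance was configured
    assert observation_type in pure_text_settings, \
        "text-only models need a text-based observation such as the a11y tree"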
@@ -131,7 +135,7 @@ def parse_actions_from_string(input_string):


 def parse_code_from_string(input_string):
-    input_string = input_string.replace(";", "\n")
+    input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
     if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
         return [input_string.strip()]

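The replaced line changes how ;-separated model output is normalised: instead of a bare substitution, the string is split on ';', each fragment is stripped, and empty fragments are dropped. A quick before/after sketch (the sample string is illustrative):

    s = "import pyautogui; pyautogui.click(100, 200);"
    old = s.replace(";", "\n")
    # 'import pyautogui\n pyautogui.click(100, 200)\n'  -- keeps the leading space and a trailing empty line
    new = "\n".join([line.strip() for line in s.split(';') if line.strip()])
    # 'import pyautogui\npyautogui.click(100, 200)'     -- one clean statement per line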
@@ -510,7 +514,7 @@ class PromptAgent:
         return response, actions

     @backoff.on_exception(
-        backoff.expo,
+        backoff.constant,
         # here you should add more model exceptions as you want,
         # but you are forbidden to add "Exception", that is, a common type of exception
         # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
@@ -525,8 +529,12 @@ class PromptAgent:
             ResourceExhausted,
             InternalServerError,
             BadRequest,
+
+            # Groq exceptions
+            # todo: check
         ),
-        max_tries=5
+        interval=30,
+        max_tries=10
     )
     def call_llm(self, payload):

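Together with the backoff.expo to backoff.constant change above, this switches call_llm's retry policy from exponential backoff with at most 5 tries to a fixed 30-second wait with up to 10 tries; the commit's own "todo: check" note marks that Groq SDK exception types still need to be added to the retried tuple. A self-contained sketch of the same decorator configuration (the decorated function and the single exception type are stand-ins, not part of the commit):

    import backoff
    from openai import RateLimitError   # stand-in for the transient exception types listed above

    @backoff.on_exception(
        backoff.constant,      # wait a fixed interval between attempts (previously backoff.expo)
        (RateLimitError,),     # retry only known transient errors, never a bare Exception
        interval=30,           # seconds between attempts, consumed by backoff.constant
        max_tries=10,          # give up after 10 attempts (previously 5)
    )
    def call_model(payload):
        ...                    # hypothetical stand-in for PromptAgent.call_llm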
@@ -632,6 +640,8 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]

+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             mistral_messages = []

             for i, message in enumerate(messages):
@@ -650,12 +660,13 @@ class PromptAgent:
             client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                             base_url='https://api.together.xyz',
                             )
-            logger.info("Generating content with Mistral model: %s", self.model)

             flag = 0
             while True:
                 try:
-                    if flag > 20: break
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
                     response = client.chat.completions.create(
                         messages=mistral_messages,
                         model=self.model,
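For reference, this branch reaches Mistral models through Together's OpenAI-compatible endpoint, which is why the stock OpenAI client is reused with an overridden base_url. A minimal sketch of the same setup outside the agent (the model id is illustrative and TOGETHER_API_KEY is assumed to be set):

    import os
    from openai import OpenAI

    # Together exposes an OpenAI-compatible API, so only base_url changes.
    client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                    base_url="https://api.together.xyz")
    resp = client.chat.completions.create(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",   # illustrative model id
        messages=[{"role": "user", "content": "Say hi"}],
    )
    print(resp.choices[0].message.content)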
@@ -733,6 +744,9 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]

+            if self.model == "gemini-pro":
+                assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             gemini_messages = []
             for i, message in enumerate(messages):
                 role_mapping = {
@@ -782,7 +796,7 @@ class PromptAgent:
                 gemini_messages,
                 generation_config={
                     "candidate_count": 1,
-                    "max_output_tokens": max_tokens,
+                    # "max_output_tokens": max_tokens,
                     "top_p": top_p,
                     "temperature": temperature
                 },
@@ -796,7 +810,6 @@ class PromptAgent:
             )
             return response.text

-
         elif self.model == "gemini-1.5-pro-latest":
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
@@ -858,7 +871,7 @@ class PromptAgent:
                 gemini_messages,
                 generation_config={
                     "candidate_count": 1,
-                    "max_output_tokens": max_tokens,
+                    # "max_output_tokens": max_tokens,
                     "top_p": top_p,
                     "temperature": temperature
                 },
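In both Gemini branches the "max_output_tokens" entry is commented out, so responses are no longer capped at the payload's max_tokens and instead use the model's default output limit. A minimal sketch of the resulting call shape with google.generativeai (model name, prompt, and sampling values are illustrative):

    import google.generativeai as genai

    model = genai.GenerativeModel("gemini-pro")
    response = model.generate_content(
        "Describe the next action.",
        generation_config={
            "candidate_count": 1,
            # "max_output_tokens" intentionally omitted: fall back to the model default
            "top_p": 0.9,
            "temperature": 1.0,
        },
    )
    print(response.text)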
@@ -873,6 +886,59 @@ class PromptAgent:

             return response.text

+        elif self.model == "llama3-70b":
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
+            groq_messages = []
+
+            for i, message in enumerate(messages):
+                groq_message = {
+                    "role": message["role"],
+                    "content": ""
+                }
+
+                for part in message["content"]:
+                    groq_message['content'] = part['text'] if part['type'] == "text" else ""
+
+                groq_messages.append(groq_message)
+
+            # The implementation based on Groq API
+            client = Groq(
+                api_key=os.environ.get("GROQ_API_KEY"),
+            )
+
+            flag = 0
+            while True:
+                try:
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
+                    response = client.chat.completions.create(
+                        messages=groq_messages,
+                        model="llama3-70b-8192",
+                        max_tokens=max_tokens,
+                        top_p=top_p,
+                        temperature=temperature
+                    )
+                    break
+                except:
+                    if flag == 0:
+                        groq_messages = [groq_messages[0]] + groq_messages[-1:]
+                    else:
+                        groq_messages[-1]["content"] = ' '.join(groq_messages[-1]["content"].split()[:-500])
+                    flag = flag + 1
+
+            try:
+                return response.choices[0].message.content
+            except Exception as e:
+                print("Failed to call LLM: " + str(e))
+                return ""
+
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
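Two details of the new branch are worth calling out. The conversion loop flattens the agent's OpenAI-style multi-part messages into the plain strings the Groq chat API expects; it overwrites groq_message['content'] on every part, so only the last part of a message determines its content, which is harmless here because the text-only settings enforced by the assert produce single-part messages. On failure, the bare except first shrinks the history to the system prompt plus the most recent message, then keeps cutting 500 words off that message, and gives up after 20 attempts. A sketch of that fallback as a standalone helper (the helper name is illustrative, not part of the commit):

    def shrink_for_retry(groq_messages, attempt):
        # Mirrors the except-branch above: the first retry keeps only the system
        # prompt plus the latest message; later retries drop the last 500
        # whitespace-separated tokens from that message.
        if attempt == 0:
            return [groq_messages[0]] + groq_messages[-1:]
        groq_messages[-1]["content"] = ' '.join(groq_messages[-1]["content"].split()[:-500])
        return groq_messages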
@@ -52,3 +52,4 @@ wandb
 wrapt_timeout_decorator
 gdown
 tiktoken
+groq
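With the groq package from the requirements hunk installed and GROQ_API_KEY exported, the new code path can be smoke-tested outside the agent with a minimal call (the prompt and sampling values are illustrative):

    import os
    from groq import Groq

    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    response = client.chat.completions.create(
        model="llama3-70b-8192",   # the hosted id the agent uses for "llama3-70b"
        messages=[{"role": "user", "content": "Reply with the single word READY."}],
        max_tokens=16,
        temperature=0.0,
    )
    print(response.choices[0].message.content)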