Add Llama3-70B Support (from Groq)
This commit adds a Groq-backed "llama3-70b" branch to PromptAgent.call_llm, switches the retry decorator from exponential to constant backoff, tightens parse_code_from_string, comments out "max_output_tokens" in the Gemini branches, and adds groq to the requirements.
@@ -14,6 +14,8 @@ import backoff
 import dashscope
 import google.generativeai as genai
 import openai
+from groq import Groq
+
 import requests
 import tiktoken
 from PIL import Image
@@ -27,6 +29,8 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S

 logger = logging.getLogger("desktopenv.agent")

+pure_text_settings = ['a11y_tree']
+

 # Function to encode the image
 def encode_image(image_content):
@@ -131,7 +135,7 @@ def parse_actions_from_string(input_string):


 def parse_code_from_string(input_string):
-    input_string = input_string.replace(";", "\n")
+    input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
     if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
         return [input_string.strip()]

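A quick illustration of what this change to parse_code_from_string does, reproduced outside the agent (the sample string is invented): the old replace(";", "\n") kept surrounding whitespace and empty segments, while the new line normalizes each ';'-separated statement.

raw = "import pyautogui; pyautogui.click(100, 200);  "

# old behaviour: semicolons become newlines, stray whitespace and empty pieces survive
old = raw.replace(";", "\n")
# -> 'import pyautogui\n pyautogui.click(100, 200)\n  '

# new behaviour: each ';'-separated statement is stripped, empty segments are dropped
new = "\n".join([line.strip() for line in raw.split(';') if line.strip()])
# -> 'import pyautogui\npyautogui.click(100, 200)'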
@@ -510,7 +514,7 @@ class PromptAgent:
         return response, actions

     @backoff.on_exception(
-        backoff.expo,
+        backoff.constant,
         # here you should add more model exceptions as you want,
         # but you are forbidden to add "Exception", that is, a common type of exception
         # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
@@ -525,8 +529,12 @@ class PromptAgent:
             ResourceExhausted,
             InternalServerError,
             BadRequest,
+
+            # Groq exceptions
+            # todo: check
         ),
-        max_tries=5
+        interval=30,
+        max_tries=10
     )
     def call_llm(self, payload):

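With this change, call_llm waits a flat 30 seconds between retries instead of backing off exponentially, and gives up after 10 attempts rather than 5; backoff.on_exception forwards the extra interval keyword to the backoff.constant wait generator. A minimal sketch of the same decorator pattern, using placeholder exception types rather than the agent's full tuple:

import backoff
import openai

# Sketch only: retry the decorated call every 30 seconds, at most 10 times,
# and only for the listed exception types (placeholders, not the agent's list).
@backoff.on_exception(
    backoff.constant,
    (openai.RateLimitError, openai.InternalServerError),
    interval=30,
    max_tries=10,
)
def flaky_call():
    ...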
@@ -632,6 +640,8 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]

+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             mistral_messages = []

             for i, message in enumerate(messages):
@@ -650,12 +660,13 @@ class PromptAgent:
             client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                             base_url='https://api.together.xyz',
                             )
-            logger.info("Generating content with Mistral model: %s", self.model)

             flag = 0
             while True:
                 try:
-                    if flag > 20: break
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
                     response = client.chat.completions.create(
                         messages=mistral_messages,
                         model=self.model,
@@ -733,6 +744,9 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]

+            if self.model == "gemini-pro":
+                assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             gemini_messages = []
             for i, message in enumerate(messages):
                 role_mapping = {
@@ -782,7 +796,7 @@ class PromptAgent:
                 gemini_messages,
                 generation_config={
                     "candidate_count": 1,
-                    "max_output_tokens": max_tokens,
+                    # "max_output_tokens": max_tokens,
                     "top_p": top_p,
                     "temperature": temperature
                 },
@@ -796,7 +810,6 @@ class PromptAgent:
                 )
             return response.text

-
         elif self.model == "gemini-1.5-pro-latest":
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
@@ -858,7 +871,7 @@ class PromptAgent:
                 gemini_messages,
                 generation_config={
                     "candidate_count": 1,
-                    "max_output_tokens": max_tokens,
+                    # "max_output_tokens": max_tokens,
                     "top_p": top_p,
                     "temperature": temperature
                 },
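Both Gemini branches now comment out "max_output_tokens", so the response length falls back to the model's default while candidate_count, top_p and temperature are still passed. A standalone sketch of the underlying google.generativeai call shape (API key, prompt and sampling values are placeholders):

import google.generativeai as genai

genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # placeholder key
model = genai.GenerativeModel("gemini-pro")
response = model.generate_content(
    "Summarize the accessibility tree in one sentence.",  # placeholder prompt
    generation_config={
        "candidate_count": 1,
        # "max_output_tokens" omitted: the model's default output limit applies
        "top_p": 0.9,
        "temperature": 1.0,
    },
)
print(response.text)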
@@ -873,6 +886,59 @@ class PromptAgent:

             return response.text

+        elif self.model == "llama3-70b":
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
+            groq_messages = []
+
+            for i, message in enumerate(messages):
+                groq_message = {
+                    "role": message["role"],
+                    "content": ""
+                }
+
+                for part in message["content"]:
+                    groq_message['content'] = part['text'] if part['type'] == "text" else ""
+
+                groq_messages.append(groq_message)
+
+            # The implementation based on Groq API
+            client = Groq(
+                api_key=os.environ.get("GROQ_API_KEY"),
+            )
+
+            flag = 0
+            while True:
+                try:
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
+                    response = client.chat.completions.create(
+                        messages=groq_messages,
+                        model="llama3-70b-8192",
+                        max_tokens=max_tokens,
+                        top_p=top_p,
+                        temperature=temperature
+                    )
+                    break
+                except:
+                    if flag == 0:
+                        groq_messages = [groq_messages[0]] + groq_messages[-1:]
+                    else:
+                        groq_messages[-1]["content"] = ' '.join(groq_messages[-1]["content"].split()[:-500])
+                    flag = flag + 1
+
+            try:
+                return response.choices[0].message.content
+            except Exception as e:
+                print("Failed to call LLM: " + str(e))
+                return ""
+
 elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
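The new llama3-70b branch flattens each multimodal message to its text part, sends the result through Groq's OpenAI-style chat completions API as model "llama3-70b-8192", and on failure first keeps only the first and last message, then trims 500 words off the final message per retry, giving up after 20 attempts. A minimal standalone sketch of the same Groq call (message contents and sampling values are illustrative):

import os
from groq import Groq

client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
response = client.chat.completions.create(
    model="llama3-70b-8192",
    messages=[
        {"role": "system", "content": "You are an agent that drives a desktop via pyautogui."},
        {"role": "user", "content": "Return code that opens a terminal window."},
    ],
    max_tokens=1500,   # illustrative values
    top_p=0.9,
    temperature=1.0,
)
print(response.choices[0].message.content)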
@@ -52,3 +52,4 @@ wandb
 wrapt_timeout_decorator
 gdown
 tiktoken
+groq