From 54905380e6edebf345e89bf46d7f2df636d89aba Mon Sep 17 00:00:00 2001
From: Timothyxxx <384084775@qq.com>
Date: Thu, 9 May 2024 02:04:02 +0800
Subject: [PATCH] Add Llama3-70B Support (from Groq)

---
 mm_agents/agent.py | 82 +++++++++++++++++++++++++++++++++++++++++-----
 requirements.txt   |  1 +
 2 files changed, 75 insertions(+), 8 deletions(-)

diff --git a/mm_agents/agent.py b/mm_agents/agent.py
index da28ea8..893d45f 100644
--- a/mm_agents/agent.py
+++ b/mm_agents/agent.py
@@ -14,6 +14,8 @@ import backoff
 import dashscope
 import google.generativeai as genai
 import openai
+from groq import Groq
+
 import requests
 import tiktoken
 from PIL import Image
@@ -27,6 +29,8 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
 
 logger = logging.getLogger("desktopenv.agent")
 
+pure_text_settings = ['a11y_tree']
+
 
 # Function to encode the image
 def encode_image(image_content):
@@ -131,7 +135,7 @@ def parse_actions_from_string(input_string):
 
 
 def parse_code_from_string(input_string):
-    input_string = input_string.replace(";", "\n")
+    input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
     if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
         return [input_string.strip()]
 
@@ -510,7 +514,7 @@ class PromptAgent:
         return response, actions
 
     @backoff.on_exception(
-        backoff.expo,
+        backoff.constant,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
@@ -525,8 +529,12 @@ class PromptAgent:
             ResourceExhausted,
             InternalServerError,
             BadRequest,
+
+            # Groq exceptions
+            # todo: check
         ),
-        max_tries=5
+        interval=30,
+        max_tries=10
     )
     def call_llm(self, payload):
 
@@ -632,6 +640,8 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]
 
+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             mistral_messages = []
 
             for i, message in enumerate(messages):
@@ -650,12 +660,13 @@ class PromptAgent:
             client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                             base_url='https://api.together.xyz',
                             )
-            logger.info("Generating content with Mistral model: %s", self.model)
 
             flag = 0
             while True:
                 try:
-                    if flag > 20: break
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
                     response = client.chat.completions.create(
                         messages=mistral_messages,
                         model=self.model,
@@ -733,6 +744,9 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]
 
+            if self.model == "gemini-pro":
+                assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
             gemini_messages = []
             for i, message in enumerate(messages):
                 role_mapping = {
@@ -782,7 +796,7 @@ class PromptAgent:
                 gemini_messages,
                 generation_config={
                     "candidate_count": 1,
-                    "max_output_tokens": max_tokens,
+                    # "max_output_tokens": max_tokens,
                     "top_p": top_p,
                     "temperature": temperature
                 },
@@ -796,7 +810,6 @@ class PromptAgent:
             )
 
             return response.text
-
         elif self.model == "gemini-1.5-pro-latest":
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
@@ -858,7 +871,7 @@ class PromptAgent:
                 gemini_messages,
                 generation_config={
                     "candidate_count": 1,
-                    "max_output_tokens": max_tokens,
+                    # "max_output_tokens": max_tokens,
                     "top_p": top_p,
                     "temperature": temperature
                 },
@@ -873,6 +886,59 @@ class PromptAgent:
 
             return response.text
 
+        elif self.model == "llama3-70b":
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            assert self.observation_type in pure_text_settings, f"The model {self.model} can only support text-based input, please consider change based model or settings"
+
+            groq_messages = []
+
+            for i, message in enumerate(messages):
+                groq_message = {
+                    "role": message["role"],
+                    "content": ""
+                }
+
+                for part in message["content"]:
+                    groq_message['content'] = part['text'] if part['type'] == "text" else ""
+
+                groq_messages.append(groq_message)
+
+            # The implementation based on Groq API
+            client = Groq(
+                api_key=os.environ.get("GROQ_API_KEY"),
+            )
+
+            flag = 0
+            while True:
+                try:
+                    if flag > 20:
+                        break
+                    logger.info("Generating content with model: %s", self.model)
+                    response = client.chat.completions.create(
+                        messages=groq_messages,
+                        model="llama3-70b-8192",
+                        max_tokens=max_tokens,
+                        top_p=top_p,
+                        temperature=temperature
+                    )
+                    break
+                except:
+                    if flag == 0:
+                        groq_messages = [groq_messages[0]] + groq_messages[-1:]
+                    else:
+                        groq_messages[-1]["content"] = ' '.join(groq_messages[-1]["content"].split()[:-500])
+                    flag = flag + 1
+
+            try:
+                return response.choices[0].message.content
+            except Exception as e:
+                print("Failed to call LLM: " + str(e))
+                return ""
+
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
diff --git a/requirements.txt b/requirements.txt
index 0229b7d..2cc96ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -52,3 +52,4 @@ wandb
 wrapt_timeout_decorator
 gdown
 tiktoken
+groq
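
Note (not part of the patch): below is a minimal standalone sketch of the Groq call path this patch adds, useful for checking that the API key and model name resolve before running the agent. The model name "llama3-70b-8192" and the GROQ_API_KEY environment variable come from the diff; the prompt, decoding values, and script scaffolding are illustrative assumptions only.

    import os

    from groq import Groq

    # Assumes GROQ_API_KEY is exported, as expected by the patched call_llm branch.
    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    # Groq's chat endpoint takes plain-text messages only, which is why the patch
    # asserts the pure-text (a11y_tree) observation setting for llama3-70b.
    response = client.chat.completions.create(
        messages=[{"role": "user", "content": "Reply with the single word DONE."}],
        model="llama3-70b-8192",
        max_tokens=32,        # illustrative; the agent passes payload["max_tokens"], etc.
        top_p=0.9,
        temperature=0.5,
    )
    print(response.choices[0].message.content)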