diff --git a/mm_agents/agent.py b/mm_agents/agent.py
index ff92673..c769827 100644
--- a/mm_agents/agent.py
+++ b/mm_agents/agent.py
@@ -568,7 +568,7 @@ class PromptAgent:
             top_p = payload["top_p"]
             temperature = payload["temperature"]
 
-            misrtal_messages = []
+            mistral_messages = []
 
             for i, message in enumerate(messages):
                 mistral_message = {
@@ -579,13 +579,8 @@ class PromptAgent:
                 for part in message["content"]:
                     mistral_message['content'] = part['text'] if part['type'] == "text" else ""
 
-                misrtal_messages.append(mistral_message)
+                mistral_messages.append(mistral_message)
 
-            # openai.api_base = "http://localhost:8000/v1"
-            # response = openai.ChatCompletion.create(
-            #     messages=misrtal_messages,
-            #     model="Mixtral-8x7B-Instruct-v0.1"
-            # )
 
             from openai import OpenAI
 
@@ -593,12 +588,23 @@ class PromptAgent:
                 base_url='https://api.together.xyz',
             )
             logger.info("Generating content with Mistral model: %s", self.model)
-
-            response = client.chat.completions.create(
-                messages=misrtal_messages,
-                model=self.model,
-                max_tokens=max_tokens
-            )
+
+            flag = 0
+            while True:
+                try:
+                    if flag > 20: break
+                    response = client.chat.completions.create(
+                        messages=mistral_messages,
+                        model=self.model,
+                        max_tokens=max_tokens
+                    )
+                    break
+                except:
+                    if flag == 0:
+                        mistral_messages = [mistral_messages[0]] + mistral_messages[-1:]
+                    else:
+                        mistral_messages[-1]["content"] = ' '.join(mistral_messages[-1]["content"].split()[:-500])
+                    flag = flag + 1
 
             try:
                 return response.choices[0].message.content
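
The added "+" lines replace the single completion call with a retry loop: on the first failure the prompt is cut down to the first and last messages, on each later failure the trailing 500 words are dropped from the final message, and the loop gives up after 20 attempts. Below is a minimal sketch of that retry-and-truncate pattern factored into a helper, assuming the openai>=1.0 client pointed at Together's OpenAI-compatible endpoint; the helper name call_with_truncation, the OpenAIError catch, the TOGETHER_API_KEY environment variable, and the example model id are illustrative choices, not part of the patch.

import os

from openai import OpenAI, OpenAIError


def call_with_truncation(client, model, messages, max_tokens, max_retries=20):
    # Illustrative helper, not part of the patch: retry the chat completion,
    # shrinking the prompt whenever a request fails (e.g. context overflow).
    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                messages=messages,
                model=model,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        except OpenAIError:
            if attempt == 0:
                # First failure: keep only the first and the last message.
                messages = [messages[0]] + messages[-1:]
            else:
                # Later failures: drop the trailing 500 words of the last message.
                words = messages[-1]["content"].split()
                messages[-1]["content"] = ' '.join(words[:-500])
    return None


client = OpenAI(
    api_key=os.environ["TOGETHER_API_KEY"],  # assumed env var, not in the patch
    base_url="https://api.together.xyz",
)
messages = [
    {"role": "system", "content": "You are a helpful computer-use agent."},
    {"role": "user", "content": "Open the terminal."},
]
answer = call_with_truncation(
    client, "mistralai/Mixtral-8x7B-Instruct-v0.1", messages, max_tokens=1024
)

Compared with the patch's while True / flag counter and bare except:, the bounded for loop and the explicit OpenAIError catch make the give-up condition visible and avoid swallowing unrelated errors such as KeyboardInterrupt.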