diff --git a/mm_agents/agent.py b/mm_agents/agent.py
index 039eda8..dcf0919 100644
--- a/mm_agents/agent.py
+++ b/mm_agents/agent.py
@@ -488,6 +488,99 @@ class PromptAgent:
             else:
                 return response.json()['choices'][0]['message']['content']
 
+        elif self.model.startswith("claude"):
+            # Convert the OpenAI-style chat payload into an Anthropic Messages API
+            # request and return the text of the first content block of the reply.
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+
+            # The Messages API has no "system" role inside `messages`; system text is
+            # sent through the top-level `system` field instead.
+            system_texts = []
+            claude_messages = []
+
+            for message in messages:
+                assert len(message["content"]) in [1, 2], "One text, or one text with one image"
+
+                if message["role"] == "system":
+                    system_texts.extend(
+                        part["text"] for part in message["content"] if part["type"] == "text"
+                    )
+                    continue
+
+                claude_message = {"role": message["role"], "content": []}
+                for part in message["content"]:
+                    if part["type"] == "image_url":
+                        url = part["image_url"]["url"]
+                        # Take the media type from the data-URL header so it always
+                        # matches the encoded bytes; bare base64 without such a header
+                        # falls back to "image/jpeg".
+                        media_type, data = "image/jpeg", url
+                        if url.startswith("data:") and "," in url:
+                            header, data = url.split(",", 1)
+                            media_type = header[len("data:"):].split(";", 1)[0] or media_type
+                        claude_message["content"].append({
+                            "type": "image",
+                            "source": {"type": "base64", "media_type": media_type, "data": data},
+                        })
+                    elif part["type"] == "text":
+                        claude_message["content"].append({"type": "text", "text": part["text"]})
+
+                claude_messages.append(claude_message)
+
+            api_url = "https://api.anthropic.com/v1/messages"
+            headers = {
+                "x-api-key": os.environ["ANTHROPIC_API_KEY"],
+                "anthropic-version": "2023-06-01",
+                "content-type": "application/json"
+            }
+
+            payload = {
+                # Use the configured model rather than a hard-coded one.
+                "model": self.model,
+                "max_tokens": max_tokens,
+                "messages": claude_messages
+            }
+            if system_texts:
+                payload["system"] = "\n".join(system_texts)
+
+            response = requests.post(api_url, headers=headers, json=payload)
+
+            if response.status_code != 200:
+                # Error bodies may be non-JSON or lack the expected keys, so read
+                # them defensively instead of indexing straight into the dict.
+                try:
+                    body = response.json()
+                except ValueError:
+                    body = None
+                error = body.get("error") if isinstance(body, dict) else None
+                if not isinstance(error, dict):
+                    error = {}
+
+                # NOTE(review): the "code" value is the OpenAI-style identifier; the
+                # "too long" message match is an assumption about Anthropic's error
+                # wording -- confirm against the API reference.
+                context_exceeded = (
+                    error.get("code") == "context_length_exceeded"
+                    or "too long" in str(error.get("message", ""))
+                )
+                if context_exceeded:
+                    logger.error("Context length exceeded. Retrying with a smaller context.")
+                    payload["messages"] = payload["messages"][-1:]
+                    retry_response = requests.post(api_url, headers=headers, json=payload)
+                    if retry_response.status_code == 200:
+                        # A successful retry is the answer; do not fall through.
+                        return retry_response.json()['content'][0]['text']
+                    logger.error("Failed to call LLM: " + retry_response.text)
+                    return ""
+
+                logger.error("Failed to call LLM: " + response.text)
+                time.sleep(5)
+                return ""
+            else:
+                return response.json()['content'][0]['text']
+
+
         # elif self.model.startswith("mistral"):
         #     print("Call mistral")
         #     messages = payload["messages"]