From 017dde896642ab7bba14522b1c6ff36d27430907 Mon Sep 17 00:00:00 2001 From: lfy79001 <843265183@qq.com> Date: Sat, 16 Mar 2024 01:37:42 +0800 Subject: [PATCH 1/4] add claude3 agent code --- mm_agents/agent.py | 65 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 039eda8..dcf0919 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -488,6 +488,71 @@ class PromptAgent: else: return response.json()['choices'][0]['message']['content'] + elif self.model.startswith("claude"): + messages = payload["messages"] + max_tokens = payload["max_tokens"] + + claude_messages = [] + + for i, message in enumerate(messages): + claude_message = { + "role": message["role"], + "content": [] + } + assert len(message["content"]) in [1, 2], "One text, or one text with one image" + for part in message["content"]: + + if part['type'] == "image_url": + image_source = {} + image_source["type"] = "base64" + image_source["media_type"] = "image/jpeg" + image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "") + claude_message['content'].append({"type": "image", "source": image_source}) + + if part['type'] == "text": + claude_message['content'].append({"type": "text", "text": part['text']}) + + claude_messages.append(claude_message) + + + headers = { + "x-api-key": os.environ["ANTHROPIC_API_KEY"], + "anthropic-version": "2023-06-01", + "content-type": "application/json" + } + + payload = { + "model": "claude-3-opus-20240229", + "max_tokens": max_tokens, + "messages": claude_messages + } + + response = requests.post( + "https://api.anthropic.com/v1/messages", + headers=headers, + json=payload + ) + + if response.status_code != 200: + if response.json()['error']['code'] == "context_length_exceeded": + logger.error("Context length exceeded. 
Retrying with a smaller context.") + payload["messages"] = payload["messages"][-1:] + retry_response = requests.post( + "https://api.anthropic.com/v1/messages", + headers=headers, + json=payload + ) + if retry_response.status_code != 200: + logger.error("Failed to call LLM: " + retry_response.text) + return "" + + logger.error("Failed to call LLM: " + response.text) + time.sleep(5) + return "" + else: + return response.json()['content'][0]['text'] + + # elif self.model.startswith("mistral"): # print("Call mistral") # messages = payload["messages"] From 3b13046745438f12a472c37b7c69459897a5d5bb Mon Sep 17 00:00:00 2001 From: lfy79001 <843265183@qq.com> Date: Sat, 16 Mar 2024 01:40:41 +0800 Subject: [PATCH 2/4] add claude3 agent code --- mm_agents/agent.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index dcf0919..7dd45b3 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -534,17 +534,17 @@ class PromptAgent: ) if response.status_code != 200: - if response.json()['error']['code'] == "context_length_exceeded": - logger.error("Context length exceeded. Retrying with a smaller context.") - payload["messages"] = payload["messages"][-1:] - retry_response = requests.post( - "https://api.anthropic.com/v1/messages", - headers=headers, - json=payload - ) - if retry_response.status_code != 200: - logger.error("Failed to call LLM: " + retry_response.text) - return "" + # if response.json()['error']['code'] == "context_length_exceeded": + # logger.error("Context length exceeded. 
Retrying with a smaller context.") + # payload["messages"] = payload["messages"][-1:] + # retry_response = requests.post( + # "https://api.anthropic.com/v1/messages", + # headers=headers, + # json=payload + # ) + # if retry_response.status_code != 200: + # logger.error("Failed to call LLM: " + retry_response.text) + # return "" logger.error("Failed to call LLM: " + response.text) time.sleep(5) From 684b4a1b7bf72c693f7edaa722e9a2a653d8d986 Mon Sep 17 00:00:00 2001 From: lfy79001 <843265183@qq.com> Date: Sat, 16 Mar 2024 11:27:09 +0800 Subject: [PATCH 3/4] claude3_agent_code --- mm_agents/agent.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 7dd45b3..28b5f11 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -505,7 +505,7 @@ class PromptAgent: if part['type'] == "image_url": image_source = {} image_source["type"] = "base64" - image_source["media_type"] = "image/jpeg" + image_source["media_type"] = "image/png" image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "") claude_message['content'].append({"type": "image", "source": image_source}) @@ -514,6 +514,12 @@ class PromptAgent: claude_messages.append(claude_message) + # Claude does not support a system message in our endpoint, so we prepend its content to the first user message + if claude_messages[0]['role'] == "system": + claude_system_message_item = claude_messages[0]['content'][0] + claude_messages[1]['content'].insert(0, claude_system_message_item) + claude_messages.pop(0) + headers = { "x-api-key": os.environ["ANTHROPIC_API_KEY"], @@ -534,17 +540,6 @@ class PromptAgent: ) if response.status_code != 200: - # if response.json()['error']['code'] == "context_length_exceeded": - # logger.error("Context length exceeded. 
Retrying with a smaller context.") - # payload["messages"] = payload["messages"][-1:] - # retry_response = requests.post( - # "https://api.anthropic.com/v1/messages", - # headers=headers, - # json=payload - # ) - # if retry_response.status_code != 200: - # logger.error("Failed to call LLM: " + retry_response.text) - # return "" logger.error("Failed to call LLM: " + response.text) time.sleep(5) From 505e772463699cdb45ec98f488f66e6d6b3d477c Mon Sep 17 00:00:00 2001 From: lfy79001 <843265183@qq.com> Date: Sat, 16 Mar 2024 11:57:49 +0800 Subject: [PATCH 4/4] claude3_agent_code --- mm_agents/agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 28b5f11..7599b02 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -491,6 +491,8 @@ class PromptAgent: elif self.model.startswith("claude"): messages = payload["messages"] max_tokens = payload["max_tokens"] + top_p = payload["top_p"] + temperature = payload["temperature"] claude_messages = [] @@ -528,7 +530,7 @@ class PromptAgent: } payload = { - "model": "claude-3-opus-20240229", + "model": self.model, "max_tokens": max_tokens, "messages": claude_messages }