From a845824f06be024fbed6a1432c0a43bdf8e8a87a Mon Sep 17 00:00:00 2001 From: uvheart <37827236+uvheart@users.noreply.github.com> Date: Fri, 23 May 2025 03:57:42 +0800 Subject: [PATCH] add azure_gpt_4o (#197) --- mm_agents/agent.py | 53 ++++++++++++++++++++++++++++++++++++++++++++-- requirements.txt | 1 + 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 42aa53e..1eefbba 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -9,6 +9,7 @@ import xml.etree.ElementTree as ET from http import HTTPStatus from io import BytesIO from typing import Dict, List +from dotenv import load_dotenv import backoff import dashscope @@ -565,7 +566,55 @@ class PromptAgent: ) def call_llm(self, payload): - if self.model.startswith("gpt"): + if payload['model'].startswith("azure-gpt-4o"): + + + #.env config example : + # AZURE_OPENAI_API_BASE=YOUR_API_BASE + # AZURE_OPENAI_DEPLOYMENT=YOUR_DEPLOYMENT + # AZURE_OPENAI_API_VERSION=YOUR_API_VERSION + # AZURE_OPENAI_MODEL=gpt-4o-mini + # AZURE_OPENAI_API_KEY={{YOUR_API_KEY}} + # AZURE_OPENAI_ENDPOINT=${AZURE_OPENAI_API_BASE}/openai/deployments/${AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version=${AZURE_OPENAI_API_VERSION} + + + # Load environment variables + load_dotenv() + api_key = os.getenv('AZURE_OPENAI_API_KEY') + openai_endpoint = os.getenv('AZURE_OPENAI_ENDPOINT') + #logger.info("Openai endpoint: %s", openai_endpoint) + + headers = { + "Content-Type": "application/json", + "api-key": api_key + } + logger.info("Generating content with GPT model: %s", payload['model']) + response = requests.post( + openai_endpoint, + headers=headers, + json=payload + ) + + if response.status_code != 200: + if response.json()['error']['code'] == "context_length_exceeded": + logger.error("Context length exceeded. Retrying with a smaller context.") + payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:] + retry_response = requests.post( + openai_endpoint, + headers=headers, + json=payload + ) + if retry_response.status_code != 200: + logger.error( + "Failed to call LLM even after attempt on shortening the history: " + retry_response.text) + return "" + + logger.error("Failed to call LLM: " + response.text) + time.sleep(5) + return "" + else: + return response.json()['choices'][0]['message']['content'] + elif self.model.startswith("gpt"): headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" @@ -1046,7 +1095,7 @@ class PromptAgent: except Exception as e: print("Failed to call LLM: " + str(e)) return "" - + else: raise ValueError("Invalid model: " + self.model) diff --git a/requirements.txt b/requirements.txt index 6a19d54..48e354f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -60,3 +60,4 @@ azure-mgmt-compute azure-mgmt-network docker loguru +python-dotenv