add azure_gpt_4o (#197)
This commit is contained in:
@@ -9,6 +9,7 @@ import xml.etree.ElementTree as ET
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import backoff
|
||||
import dashscope
|
||||
@@ -565,7 +566,55 @@ class PromptAgent:
|
||||
)
|
||||
def call_llm(self, payload):
|
||||
|
||||
if self.model.startswith("gpt"):
|
||||
if payload['model'].startswith("azure-gpt-4o"):
|
||||
|
||||
|
||||
#.env config example :
|
||||
# AZURE_OPENAI_API_BASE=YOUR_API_BASE
|
||||
# AZURE_OPENAI_DEPLOYMENT=YOUR_DEPLOYMENT
|
||||
# AZURE_OPENAI_API_VERSION=YOUR_API_VERSION
|
||||
# AZURE_OPENAI_MODEL=gpt-4o-mini
|
||||
# AZURE_OPENAI_API_KEY={{YOUR_API_KEY}}
|
||||
# AZURE_OPENAI_ENDPOINT=${AZURE_OPENAI_API_BASE}/openai/deployments/${AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version=${AZURE_OPENAI_API_VERSION}
|
||||
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
api_key = os.getenv('AZURE_OPENAI_API_KEY')
|
||||
openai_endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
|
||||
#logger.info("Openai endpoint: %s", openai_endpoint)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": api_key
|
||||
}
|
||||
logger.info("Generating content with GPT model: %s", payload['model'])
|
||||
response = requests.post(
|
||||
openai_endpoint,
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
if response.json()['error']['code'] == "context_length_exceeded":
|
||||
logger.error("Context length exceeded. Retrying with a smaller context.")
|
||||
payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
|
||||
retry_response = requests.post(
|
||||
openai_endpoint,
|
||||
headers=headers,
|
||||
json=payload
|
||||
)
|
||||
if retry_response.status_code != 200:
|
||||
logger.error(
|
||||
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
||||
return ""
|
||||
|
||||
logger.error("Failed to call LLM: " + response.text)
|
||||
time.sleep(5)
|
||||
return ""
|
||||
else:
|
||||
return response.json()['choices'][0]['message']['content']
|
||||
elif self.model.startswith("gpt"):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||
@@ -1046,7 +1095,7 @@ class PromptAgent:
|
||||
except Exception as e:
|
||||
print("Failed to call LLM: " + str(e))
|
||||
return ""
|
||||
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid model: " + self.model)
|
||||
|
||||
|
||||
@@ -60,3 +60,4 @@ azure-mgmt-compute
|
||||
azure-mgmt-network
|
||||
docker
|
||||
loguru
|
||||
dotenv
|
||||
|
||||
Reference in New Issue
Block a user