add azure_gpt_4o (#197)
This commit is contained in:
@@ -9,6 +9,7 @@ import xml.etree.ElementTree as ET
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import dashscope
|
import dashscope
|
||||||
@@ -565,7 +566,55 @@ class PromptAgent:
|
|||||||
)
|
)
|
||||||
def call_llm(self, payload):
|
def call_llm(self, payload):
|
||||||
|
|
||||||
if self.model.startswith("gpt"):
|
if payload['model'].startswith("azure-gpt-4o"):
|
||||||
|
|
||||||
|
|
||||||
|
#.env config example :
|
||||||
|
# AZURE_OPENAI_API_BASE=YOUR_API_BASE
|
||||||
|
# AZURE_OPENAI_DEPLOYMENT=YOUR_DEPLOYMENT
|
||||||
|
# AZURE_OPENAI_API_VERSION=YOUR_API_VERSION
|
||||||
|
# AZURE_OPENAI_MODEL=gpt-4o-mini
|
||||||
|
# AZURE_OPENAI_API_KEY={{YOUR_API_KEY}}
|
||||||
|
# AZURE_OPENAI_ENDPOINT=${AZURE_OPENAI_API_BASE}/openai/deployments/${AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version=${AZURE_OPENAI_API_VERSION}
|
||||||
|
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
api_key = os.getenv('AZURE_OPENAI_API_KEY')
|
||||||
|
openai_endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
|
||||||
|
#logger.info("Openai endpoint: %s", openai_endpoint)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"api-key": api_key
|
||||||
|
}
|
||||||
|
logger.info("Generating content with GPT model: %s", payload['model'])
|
||||||
|
response = requests.post(
|
||||||
|
openai_endpoint,
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
if response.json()['error']['code'] == "context_length_exceeded":
|
||||||
|
logger.error("Context length exceeded. Retrying with a smaller context.")
|
||||||
|
payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
|
||||||
|
retry_response = requests.post(
|
||||||
|
openai_endpoint,
|
||||||
|
headers=headers,
|
||||||
|
json=payload
|
||||||
|
)
|
||||||
|
if retry_response.status_code != 200:
|
||||||
|
logger.error(
|
||||||
|
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
logger.error("Failed to call LLM: " + response.text)
|
||||||
|
time.sleep(5)
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return response.json()['choices'][0]['message']['content']
|
||||||
|
elif self.model.startswith("gpt"):
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
||||||
@@ -1046,7 +1095,7 @@ class PromptAgent:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Failed to call LLM: " + str(e))
|
print("Failed to call LLM: " + str(e))
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid model: " + self.model)
|
raise ValueError("Invalid model: " + self.model)
|
||||||
|
|
||||||
|
|||||||
@@ -60,3 +60,4 @@ azure-mgmt-compute
|
|||||||
azure-mgmt-network
|
azure-mgmt-network
|
||||||
docker
|
docker
|
||||||
loguru
|
loguru
|
||||||
|
dotenv
|
||||||
|
|||||||
Reference in New Issue
Block a user