Update agent.py mixtral
This commit is contained in:
@@ -568,7 +568,7 @@ class PromptAgent:
|
|||||||
top_p = payload["top_p"]
|
top_p = payload["top_p"]
|
||||||
temperature = payload["temperature"]
|
temperature = payload["temperature"]
|
||||||
|
|
||||||
misrtal_messages = []
|
mistral_messages = []
|
||||||
|
|
||||||
for i, message in enumerate(messages):
|
for i, message in enumerate(messages):
|
||||||
mistral_message = {
|
mistral_message = {
|
||||||
@@ -579,13 +579,8 @@ class PromptAgent:
|
|||||||
for part in message["content"]:
|
for part in message["content"]:
|
||||||
mistral_message['content'] = part['text'] if part['type'] == "text" else ""
|
mistral_message['content'] = part['text'] if part['type'] == "text" else ""
|
||||||
|
|
||||||
misrtal_messages.append(mistral_message)
|
mistral_messages.append(mistral_message)
|
||||||
|
|
||||||
# openai.api_base = "http://localhost:8000/v1"
|
|
||||||
# response = openai.ChatCompletion.create(
|
|
||||||
# messages=misrtal_messages,
|
|
||||||
# model="Mixtral-8x7B-Instruct-v0.1"
|
|
||||||
# )
|
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
@@ -593,12 +588,23 @@ class PromptAgent:
|
|||||||
base_url='https://api.together.xyz',
|
base_url='https://api.together.xyz',
|
||||||
)
|
)
|
||||||
logger.info("Generating content with Mistral model: %s", self.model)
|
logger.info("Generating content with Mistral model: %s", self.model)
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
flag = 0
|
||||||
messages=misrtal_messages,
|
while True:
|
||||||
model=self.model,
|
try:
|
||||||
max_tokens=max_tokens
|
if flag > 20: break
|
||||||
)
|
response = client.chat.completions.create(
|
||||||
|
messages=mistral_messages,
|
||||||
|
model=self.model,
|
||||||
|
max_tokens=max_tokens
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
if flag == 0:
|
||||||
|
mistral_messages = [mistral_messages[0]] + mistral_messages[-1:]
|
||||||
|
else:
|
||||||
|
mistral_messages[-1]["content"] = ' '.join(mistral_messages[-1]["content"].split()[:-500])
|
||||||
|
flag = flag + 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return response.choices[0].message.content
|
return response.choices[0].message.content
|
||||||
|
|||||||
Reference in New Issue
Block a user