Clean Code; Refactor README

Timothyxxx
2024-03-27 16:21:49 +08:00
parent ee8e9451b4
commit 26ed70ef70
6 changed files with 128 additions and 91 deletions


@@ -180,6 +180,7 @@ def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
         linearized_accessibility_tree += "[...]\n"
     return linearized_accessibility_tree
+
 class PromptAgent:
     def __init__(
             self,
@@ -572,22 +573,10 @@ class PromptAgent:
             logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))

-            # headers = {
-            #     "x-api-key": os.environ["ANTHROPIC_API_KEY"],
-            #     "anthropic-version": "2023-06-01",
-            #     "content-type": "application/json"
-            # }
-            # headers = {
-            #     "Accept": "application / json",
-            #     "Authorization": "Bearer " + os.environ["ANTHROPIC_API_KEY"],
-            #     "User-Agent": "Apifox/1.0.0 (https://apifox.com)",
-            #     "Content-Type": "application/json"
-            # }
             headers = {
-                "Authorization": os.environ["ANTHROPIC_API_KEY"],
-                "Content-Type": "application/json"
+                "x-api-key": os.environ["ANTHROPIC_API_KEY"],
+                "anthropic-version": "2023-06-01",
+                "content-type": "application/json"
             }

             payload = {
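Note: the replacement headers above match Anthropic's documented authentication for api.anthropic.com (an x-api-key header plus a pinned anthropic-version), rather than the Bearer-style Authorization header the removed proxy variants used. A minimal sketch of a request built with these headers; the model id and message content are placeholders, not values from this commit:

    import os
    import requests

    headers = {
        "x-api-key": os.environ["ANTHROPIC_API_KEY"],
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
    }
    payload = {
        "model": "claude-3-opus-20240229",  # placeholder model id
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": "Hello"}],
    }
    response = requests.post("https://api.anthropic.com/v1/messages",
                             headers=headers, json=payload)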
@@ -598,28 +587,21 @@ class PromptAgent:
                 "top_p": top_p
             }

-            max_attempts = 20
-            attempt = 0
-            while attempt < max_attempts:
-                # response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
-                response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
-                                         json=payload)
-                if response.status_code == 200:
-                    result = response.json()['choices'][0]['message']['content']
-                    break
-                else:
-                    logger.error(f"Failed to call LLM: {response.text}")
-                    time.sleep(10)
-                    attempt += 1
-            else:
-                print("Exceeded maximum attempts to call LLM.")
-                result = ""
-            return result
+            response = requests.post(
+                "https://api.anthropic.com/v1/messages",
+                headers=headers,
+                json=payload
+            )
+            if response.status_code != 200:
+                logger.error("Failed to call LLM: " + response.text)
+                time.sleep(5)
+                return ""
+            return response.json()['content'][0]['text']

         elif self.model.startswith("mistral"):
             print("Call mistral")
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
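The parsing change follows from the endpoint change: the old proxy returned an OpenAI-style body (choices[0]['message']['content']), while Anthropic's native /v1/messages returns a list of content blocks, read as content[0]['text']. A sketch of the success path, with an illustrative response shape in the comments:

    # Illustrative shape of a successful /v1/messages response body:
    # {"id": "msg_...", "type": "message", "role": "assistant",
    #  "content": [{"type": "text", "text": "..."}], ...}
    data = response.json()
    text = data["content"][0]["text"]  # first content block holds the reply text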
@@ -652,7 +634,9 @@ class PromptAgent:
                     response = client.chat.completions.create(
                         messages=mistral_messages,
                         model=self.model,
-                        max_tokens=max_tokens
+                        max_tokens=max_tokens,
+                        top_p=top_p,
+                        temperature=temperature
                     )
                     break
                 except:
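The Mistral hunk above stops discarding the sampling settings: top_p and temperature are now forwarded alongside max_tokens. A self-contained sketch of the call, assuming an OpenAI-compatible client and endpoint; the base URL, model id, retry budget, and values are illustrative, not from this diff:

    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:8001/v1", api_key="EMPTY")  # hypothetical endpoint
    mistral_messages = [{"role": "user", "content": "Hello"}]

    response = None
    for _ in range(3):  # illustrative retry budget
        try:
            response = client.chat.completions.create(
                messages=mistral_messages,
                model="mistralai/Mixtral-8x7B-Instruct-v0.1",  # placeholder model id
                max_tokens=512,
                top_p=0.9,
                temperature=1.0,
            )
            break
        except Exception:
            continue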
@@ -670,7 +654,6 @@ class PromptAgent:
         elif self.model.startswith("THUDM"):
             # THUDM/cogagent-chat-hf
-            print("Call CogAgent")
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
@@ -703,7 +686,9 @@ class PromptAgent:
             payload = {
                 "model": self.model,
                 "max_tokens": max_tokens,
-                "messages": cog_messages
+                "messages": cog_messages,
+                "temperature": temperature,
+                "top_p": top_p
             }

             base_url = "http://127.0.0.1:8000"
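With temperature and top_p added, the locally served CogAgent now receives the same sampling parameters as the hosted models. A sketch of the request against the local server; the endpoint path and values are hypothetical, since the actual route sits outside this hunk:

    import requests

    base_url = "http://127.0.0.1:8000"
    payload = {
        "model": "THUDM/cogagent-chat-hf",
        "max_tokens": 512,        # illustrative value
        "messages": [],           # cog_messages in the real code
        "temperature": 1.0,
        "top_p": 0.9,
    }
    response = requests.post(base_url + "/v1/chat/completions", json=payload)  # path assumed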
@@ -717,7 +702,6 @@ class PromptAgent:
                 print("Failed to call LLM: ", response.status_code)
                 return ""

         elif self.model.startswith("gemini"):
             def encoded_img_to_pil_img(data_str):
                 base64_str = data_str.replace("data:image/png;base64,", "")
@@ -802,7 +786,8 @@ class PromptAgent:
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
-            temperature = payload["temperature"]
+            if payload["temperature"]:
+                logger.warning("Qwen model does not support temperature parameter, it will be ignored.")

             qwen_messages = []
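One nuance in the new guard: if payload["temperature"]: tests truthiness, so a caller passing temperature=0.0 gets no warning even though the value is equally ignored. A stricter variant (an alternative, not what this commit ships) would be:

    if payload.get("temperature") is not None:
        logger.warning("Qwen model does not support temperature parameter, it will be ignored.")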
@@ -821,7 +806,9 @@ class PromptAgent:
                 response = dashscope.MultiModalConversation.call(
                     model='qwen-vl-plus',
-                    messages=messages,  # todo: add the hyperparameters
+                    messages=messages,
+                    max_length=max_tokens,
+                    top_p=top_p,
                 )
                 # The response status_code is HTTPStatus.OK indicate success,
                 # otherwise indicate request is failed, you can get error code
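The truncated comment above paraphrases DashScope's status convention. A sketch of the handling it describes, assuming the response fields shown in DashScope's own examples (status_code, code, message, output); logger here is the module-level logger used elsewhere in this file:

    from http import HTTPStatus

    if response.status_code == HTTPStatus.OK:
        result = response.output  # the model reply lives under output on success
    else:
        logger.error("Qwen call failed: %s (%s)", response.message, response.code)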