Update README and ROADMAP; fix typos; optimize the LLM-calling code in agent.py

This commit is contained in:
Timothyxxx
2024-04-26 13:32:41 +08:00
parent 974c1a1387
commit 97b567a287
4 changed files with 90 additions and 89 deletions

View File

@@ -1,5 +1,4 @@
import base64
import hashlib
import json
import logging
import os
@@ -9,7 +8,6 @@ import time
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from pathlib import Path
from typing import Dict, List
import backoff
@@ -19,7 +17,7 @@ import openai
import requests
import tiktoken
from PIL import Image
from google.api_core.exceptions import InvalidArgument
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
@@ -487,13 +485,17 @@ class PromptAgent:
# logger.info("PROMPT: %s", messages)
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
try:
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
})
except Exception as e:
logger.error("Failed to call" + self.model + ", Error: " + str(e))
response = ""
logger.info("RESPONSE: %s", response)
@@ -512,10 +514,18 @@ class PromptAgent:
# Add more model-specific exceptions here as needed,
# but never add the generic "Exception" type here:
# we want generic exceptions to propagate so they can be caught by the outer
# loop, ensuring each example does not exceed the time limit.
(openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
InvalidArgument),
(
# OpenAI exceptions
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
# Google exceptions
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
),
max_tries=5
)
def call_llm(self, payload):
@@ -767,29 +777,25 @@ class PromptAgent:
logger.info("Generating content with Gemini model: %s", self.model)
request_options = {"timeout": 120}
gemini_model = genai.GenerativeModel(self.model)
try:
response = gemini_model.generate_content(
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
},
request_options=request_options
)
return response.text
except Exception as e:
logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
return ""
response = gemini_model.generate_content(
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
},
request_options=request_options
)
return response.text
elif self.model == "gemini-1.5-pro-latest":
messages = payload["messages"]
@@ -797,19 +803,6 @@ class PromptAgent:
top_p = payload["top_p"]
temperature = payload["temperature"]
uploaded_files = []
# def upload_if_needed(pathname: str) -> list[str]:
# path = Path(pathname)
# hash_id = hashlib.sha256(path.read_bytes()).hexdigest()
# try:
# existing_file = genai.get_file(name=hash_id)
# return [existing_file.uri]
# except:
# pass
# uploaded_files.append(genai.upload_file(path=path, display_name=hash_id))
# return [uploaded_files[-1].uri]
gemini_messages = []
for i, message in enumerate(messages):
role_mapping = {
@@ -818,21 +811,23 @@ class PromptAgent:
"system": "system"
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
gemini_message = {
"role": role_mapping[message["role"]],
"parts": []
}
# Gemini only supports the last image as a single image input
for part in message["content"]:
gemini_message = {
"role": role_mapping[message["role"]],
"parts": []
}
if part['type'] == "image_url":
gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
# Put the image at the beginning of the message
gemini_message['parts'].insert(0, encoded_img_to_pil_img(part['image_url']['url']))
elif part['type'] == "text":
gemini_message['parts'].append(part['text'])
else:
raise ValueError("Invalid content type: " + part['type'])
gemini_messages.append(gemini_message)
gemini_messages.append(gemini_message)
# The system message for gemini-1.5-pro-latest must be passed via the model initialization parameter
system_instruction = None
@@ -849,33 +844,34 @@ class PromptAgent:
self.model,
system_instruction=system_instruction
)
try:
response = gemini_model.generate_content(
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
},
request_options=request_options
)
for uploaded_file in uploaded_files:
genai.delete_file(name=uploaded_file.name)
return response.text
except Exception as e:
logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
for uploaded_file in uploaded_files:
genai.delete_file(name=uploaded_file.name)
return ""
with open("response.json", "w") as f:
messages_to_save = []
for message in gemini_messages:
messages_to_save.append({
"role": message["role"],
"content": [part if isinstance(part, str) else "image" for part in message["parts"]]
})
json.dump(messages_to_save, f, indent=4)
response = gemini_model.generate_content(
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
},
request_options=request_options
)
return response.text
elif self.model.startswith("qwen"):
messages = payload["messages"]