This commit is contained in:
tsuky_chen
2024-03-26 17:10:04 +08:00
9 changed files with 155 additions and 134 deletions

View File

@@ -146,7 +146,13 @@ class DesktopEnv(gym.Env):
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
# Get the screenshot and save to the image_path
screenshot = self.controller.get_screenshot()
max_retries = 20
for _ in range(max_retries):
screenshot = self.controller.get_screenshot()
if screenshot is not None:
break
time.sleep(1)
with open(image_path, "wb") as f:
f.write(screenshot)

View File

@@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str
# Value {{{ #
if hasattr(node, "get_step"):
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
try:
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
except:
pass
if hasattr(node, "value"):
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
try:
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
except:
pass
if hasattr(node, "get_value"):
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
try:
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
except:
pass
elif hasattr(node, "get_position"):
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
try:
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
except:
pass
if hasattr(node, "min_value"):
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
try:
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
except:
pass
elif hasattr(node, "get_range_min"):
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
try:
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
except:
pass
if hasattr(node, "max_value"):
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
try:
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
except:
pass
elif hasattr(node, "get_range_max"):
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
try:
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
except:
pass
# }}} Value #
attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))

View File

@@ -6,3 +6,4 @@ requests
flask
numpy
lxml
pygame

View File

@@ -108,7 +108,7 @@
{
"type": "rule",
"rules": {
"expect": "project"
"expected": "project"
}
}
]

View File

@@ -9,7 +9,7 @@
"parameters": {
"files": [
{
"url": "https://docs.google.com/spreadsheets/d/13YL-KC__pav2qp3sFDs1BT2wZnpWGp7s/export?format=xlsx",
"url": "https://drive.google.com/uc?export=download&id=1B5GmhdVD07UeYj9Ox20DHsA_gaxEFQ6Z",
"path": "/home/user/Desktop/stock.xlsx"
}
]
@@ -36,7 +36,7 @@
},
"expected": {
"type": "cloud_file",
"path": "https://drive.google.com/uc?export=download&id=1oPPW_dozWGII5MRmdXdKKoEK5iBkd_8Q",
"path": "https://drive.google.com/uc?export=download&id=1wzlUL1gktA0d_j9W3WSSAAUcuKr5gw-n",
"dest": "result_gold.txt"
}
}

View File

@@ -304,15 +304,12 @@
"os": [
"94d95f96-9699-4208-98ba-3c3119edf9c2",
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
"a462a795-fdc7-4b23-b689-e8b6df786b78",
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
"e0df059f-28a6-4169-924f-b9623e7184cc",
"ddc75b62-7311-4af8-bfb3-859558542b36",
"b6781586-6346-41cd-935a-a6b1487918fc",
"b3d4a89c-53f2-4d6b-8b6a-541fb5d205fa",
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
@@ -322,8 +319,6 @@
"23393935-50c7-4a86-aeea-2b78fd089c5c",
"5812b315-e7bd-4265-b51f-863c02174c28",
"c288e301-e626-4b98-a1ab-159dcb162af5",
"cc9d4f34-1ca0-4a1b-8ff2-09302696acb9",
"c56de254-a3ec-414e-81a6-83d2ce8c41fa",
"4783cc41-c03c-4e1b-89b4-50658f642bd5",
"5c1075ca-bb34-46a3-a7a0-029bd7463e79",
"5ced85fc-fa1a-4217-95fd-0fb530545ce2",
@@ -376,7 +371,6 @@
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
"9439a27b-18ae-42d8-9778-5f68f891805e",
"ae506c68-352c-4094-9caa-ee9d42052317",
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
"930fdb3b-11a8-46fe-9bac-577332e2640e",
"276cc624-87ea-4f08-ab93-f770e3790175",

View File

@@ -1,7 +1,9 @@
{
"chrome": [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
"35253b65-1c19-4304-8aa4-6884b8218fc0",
"a96b564e-dbe9-42c3-9ccf-b4498073438a"
],
"gimp": [
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
@@ -9,7 +11,8 @@
],
"libreoffice_calc": [
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
"42e0a640-4f19-4b28-973d-729602b5a4a7"
"42e0a640-4f19-4b28-973d-729602b5a4a7",
"abed40dc-063f-4598-8ba5-9fe749c0615d"
],
"libreoffice_impress": [
"5d901039-a89c-4bfb-967b-bf66f4df075e",
@@ -20,17 +23,14 @@
"0a0faba3-5580-44df-965d-f562a99b291c"
],
"multi_apps": [
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
"5990457f-2adb-467b-a4af-5c857c92d762",
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
"510f64c8-9bcc-4be1-8d30-638705850618",
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
"f7dfbef3-7697-431c-883a-db8583a4e4f9",
"6d72aad6-187a-4392-a4c4-ed87269c51cf",
"f918266a-b3e0-4914-865d-4faa564f1aef",
"da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
@@ -38,53 +38,12 @@
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
"8e116af7-7db7-4e35-a68b-b0939c066c78",
"185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
"2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
"3a93cae4-ad3e-403e-8c12-65303b271818",
"1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
"26150609-0da3-4a7d-8868-0faf9c5f01bb",
"7e287123-70ca-47b9-8521-47db09b69b14",
"e2392362-125e-4f76-a2ee-524b183a3412",
"26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
"a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
"36037439-2044-4b50-b9d1-875b5a332143",
"716a6079-22da-47f1-ba73-c9d58f986a38",
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
"6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
"da922383-bfa4-4cd3-bbad-6bebab3d7742",
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
"81c425f5-78f3-4771-afd6-3d2973825947",
"227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
"20236825-b5df-46e7-89bf-62e1d640a897",
"02ce9a50-7af2-47ed-8596-af0c230501f8",
"4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
"09a37c51-e625-49f4-a514-20a773797a8a",
"3e3fc409-bff3-4905-bf16-c968eee3f807",
"415ef462-bed3-493a-ac36-ca8c6d23bf1b",
"9f3bb592-209d-43bc-bb47-d77d9df56504",
"dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
"3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
"f8369178-fafe-40c2-adc4-b9b08a125456",
"778efd0a-153f-4842-9214-f05fc176b877",
"47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
"c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
"48c46dc7-fe04-4505-ade7-723cba1aa6f6",
"42d25c08-fb87-4927-8b65-93631280a26f",
"3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
"d68204bf-11c1-4b13-b48b-d303c73d4bf6",
"91190194-f406-4cd6-b3f9-c43fac942b22",
"7f35355e-02a6-45b5-b140-f0be698bcf85",
"98e8e339-5f91-4ed2-b2b2-12647cb134f4",
"df67aebb-fb3a-44fd-b75b-51b6012df509",
"5df7b33a-9f77-4101-823e-02f863e1c1ae",
"22a4636f-8179-4357-8e87-d1743ece1f81",
"236833a3-5704-47fc-888c-4f298f09f799"
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d"
],
"os": [
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
"5812b315-e7bd-4265-b51f-863c02174c28",
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
"5812b315-e7bd-4265-b51f-863c02174c28"
],
"thunderbird": [
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
@@ -96,6 +55,7 @@
],
"vs_code": [
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
"53ad5833-3455-407b-bbc6-45b4c79ab8fb"
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
"276cc624-87ea-4f08-ab93-f770e3790175"
]
}

View File

@@ -80,9 +80,11 @@ def filter_nodes(root: ET, platform="ubuntu", check_image=False):
return filtered_nodes
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sampling_ratio=1.0):
# Load the screenshot image
image = Image.open(image_file_path)
if float(down_sampling_ratio) != 1.0:
image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
draw = ImageDraw.Draw(image)
marks = []
drew_nodes = []
@@ -108,6 +110,15 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
coords = tuple(map(int, coords_str.strip('()').split(', ')))
size = tuple(map(int, size_str.strip('()').split(', ')))
import copy
original_coords = copy.deepcopy(coords)
original_size = copy.deepcopy(size)
if float(down_sampling_ratio) != 1.0:
# Downsample the coordinates and size
coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
size = tuple(int(s * down_sampling_ratio) for s in size)
# Check for negative sizes
if size[0] <= 0 or size[1] <= 0:
raise ValueError(f"Size must be positive, got: {size}")
@@ -138,7 +149,7 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
# each mark is an x, y, w, h tuple
marks.append([coords[0], coords[1], size[0], size[1]])
marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
drew_nodes.append(_node)
if _node.text:

View File

@@ -6,17 +6,16 @@ import re
import time
import uuid
import xml.etree.ElementTree as ET
import numpy as np
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List, Tuple, Union
from typing import Dict, List
import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
import cv2
import tiktoken
from PIL import Image
from google.api_core.exceptions import InvalidArgument
@@ -28,14 +27,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
logger = logging.getLogger("desktopenv.agent")
def downsample_image(img: Union[str, np.ndarray], ratio: Tuple[float, float]):
    """Rescale an image by independent horizontal and vertical factors.

    Args:
        img: Either a path string, which is loaded through ``cv2.imread``,
            or an image array that is used as-is.
        ratio: ``(fx, fy)`` pair forwarded to ``cv2.resize`` as the scale
            factors along the x and y axes.

    Returns:
        The array produced by ``cv2.resize`` using ``cv2.INTER_AREA``
        interpolation.
    """
    # Unpack first so a malformed ratio fails before any file is read.
    scale_x, scale_y = ratio
    source = cv2.imread(img) if isinstance(img, str) else img
    return cv2.resize(source, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_AREA)
# Function to encode the image
def encode_image(image_path):
@@ -43,49 +34,49 @@ def encode_image(image_path):
return base64.b64encode(image_file.read()).decode('utf-8')
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
    """Render the filtered accessibility tree as a tab-separated table.

    Args:
        accessibility_tree: XML text of the accessibility tree; parsed with
            ``ET.fromstring`` before filtering.
        platform: Passed through to ``filter_nodes``. Defaults to
            ``"ubuntu"``, so existing single-argument callers keep working.

    Returns:
        A single string with one line per row: a header line followed by one
        line per node kept by ``filter_nodes``. Each row holds the node tag,
        its ``name`` attribute, its text, its screen coordinate attribute and
        its size attribute, separated by tabs.
    """
    class_attr = "{uri:deskat:uia.windows.microsoft.org}class"
    value_attr = "{uri:deskat:value.at-spi.gnome.org}value"
    coord_attr = '{uri:deskat:component.at-spi.gnome.org}screencoord'
    size_attr = '{uri:deskat:component.at-spi.gnome.org}size'

    def _quote(raw):
        # CSV-style quoting: a cell is wrapped only when it contains a double
        # quote, and embedded quotes are doubled.
        if '"' not in raw:
            return raw
        return '"{:}"'.format(raw.replace('"', '""'))

    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
    linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]

    for node in filtered_nodes:
        if node.text:
            text = _quote(node.text)
        elif node.get(class_attr, "").endswith("EditWrapper") and node.get(value_attr):
            # Nodes whose class ends in "EditWrapper" have no element text
            # here; fall back to their value attribute.
            text = _quote(node.get(value_attr))
        else:
            text = '""'

        linearized_accessibility_tree.append(
            "{:}\t{:}\t{:}\t{:}\t{:}".format(
                node.tag,
                node.get("name", ""),
                text,
                node.get(coord_attr, ""),
                node.get(size_attr, ""),
            )
        )

    return "\n".join(linearized_accessibility_tree)
def tag_screenshot(screenshot, accessibility_tree):
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
# Create a tmp file with a random name to store the screenshot
uuid_str = str(uuid.uuid4())
os.makedirs("tmp/images", exist_ok=True)
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
# Make tag screenshot
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
@@ -181,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
return actions
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
    """Cap the linearized accessibility tree at ``max_tokens`` tokens.

    The text is tokenized with the tiktoken encoding for ``"gpt-4"``. When
    the token count does not exceed ``max_tokens`` the input is returned
    unchanged; otherwise the first ``max_tokens`` tokens are decoded back to
    text and ``"[...]\\n"`` is appended to mark the cut.

    Args:
        linearized_accessibility_tree: The text to bound.
        max_tokens: Maximum number of tokens to keep.

    Returns:
        The original text, or its truncated form with the marker appended.
    """
    encoder = tiktoken.encoding_for_model("gpt-4")
    token_ids = encoder.encode(linearized_accessibility_tree)
    if len(token_ids) <= max_tokens:
        return linearized_accessibility_tree
    return encoder.decode(token_ids[:max_tokens]) + "[...]\n"
class PromptAgent:
def __init__(
self,
platform="ubuntu",
model="gpt-4-vision-preview",
max_tokens=1500,
top_p=0.9,
@@ -191,8 +191,10 @@ class PromptAgent:
action_space="computer_13",
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=3
max_trajectory_length=3,
a11y_tree_max_tokens=10000
):
self.platform = platform
self.model = model
self.max_tokens = max_tokens
self.top_p = top_p
@@ -200,6 +202,7 @@ class PromptAgent:
self.action_space = action_space
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.a11y_tree_max_tokens = a11y_tree_max_tokens
self.thoughts = []
self.actions = []
@@ -261,9 +264,14 @@ class PromptAgent:
, "The number of observations and actions should be the same."
if len(self.observations) > self.max_trajectory_length:
_observations = self.observations[-self.max_trajectory_length:]
_actions = self.actions[-self.max_trajectory_length:]
_thoughts = self.thoughts[-self.max_trajectory_length:]
if self.max_trajectory_length == 0:
_observations = []
_actions = []
_thoughts = []
else:
_observations = self.observations[-self.max_trajectory_length:]
_actions = self.actions[-self.max_trajectory_length:]
_thoughts = self.thoughts[-self.max_trajectory_length:]
else:
_observations = self.observations
_actions = self.actions
@@ -360,9 +368,14 @@ class PromptAgent:
# {{{1
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
if linearized_accessibility_tree:
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
self.a11y_tree_max_tokens)
if self.observation_type == "screenshot_a11y_tree":
self.observations.append({
"screenshot": base64_image,
@@ -394,9 +407,14 @@ class PromptAgent:
]
})
elif self.observation_type == "a11y_tree":
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
platform=self.platform)
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
if linearized_accessibility_tree:
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
self.a11y_tree_max_tokens)
self.observations.append({
"screenshot": None,
"accessibility_tree": linearized_accessibility_tree
@@ -414,10 +432,15 @@ class PromptAgent:
})
elif self.observation_type == "som":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
"accessibility_tree"], self.platform)
base64_image = encode_image(tagged_screenshot)
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
if linearized_accessibility_tree:
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
self.a11y_tree_max_tokens)
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
@@ -446,7 +469,7 @@ class PromptAgent:
# with open("messages.json", "w") as f:
# f.write(json.dumps(messages, indent=4))
#logger.info("PROMPT: %s", messages)
# logger.info("PROMPT: %s", messages)
response = self.call_llm({
"model": self.model,
@@ -567,8 +590,6 @@ class PromptAgent:
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"max_tokens": max_tokens,
@@ -581,7 +602,8 @@ class PromptAgent:
attempt = 0
while attempt < max_attempts:
# response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
json=payload)
if response.status_code == 200:
result = response.json()['choices'][0]['message']['content']
break
@@ -592,7 +614,7 @@ class PromptAgent:
else:
print("Exceeded maximum attempts to call LLM.")
result = ""
return result
@@ -616,14 +638,13 @@ class PromptAgent:
mistral_messages.append(mistral_message)
from openai import OpenAI
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
base_url='https://api.together.xyz',
)
logger.info("Generating content with Mistral model: %s", self.model)
flag = 0
while True:
try:
@@ -752,26 +773,30 @@ class PromptAgent:
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
genai.configure(api_key=api_key)
logger.info("Generating content with Gemini model: %s", self.model)
response = genai.GenerativeModel(self.model).generate_content(
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
}
)
request_options = {"timeout": 120}
gemini_model = genai.GenerativeModel(self.model)
try:
response = gemini_model.generate_content(
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
},
request_options=request_options
)
return response.text
except Exception as e:
logger.error("Meet exception when calling Gemini API, " + str(e))
logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
return ""
elif self.model.startswith("qwen"):
messages = payload["messages"]