Merge branch 'main' of https://github.com/xlang-ai/DesktopEnv
This commit is contained in:
@@ -146,7 +146,13 @@ class DesktopEnv(gym.Env):
|
||||
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
|
||||
|
||||
# Get the screenshot and save to the image_path
|
||||
screenshot = self.controller.get_screenshot()
|
||||
max_retries = 20
|
||||
for _ in range(max_retries):
|
||||
screenshot = self.controller.get_screenshot()
|
||||
if screenshot is not None:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(screenshot)
|
||||
|
||||
|
||||
@@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str
|
||||
|
||||
# Value {{{ #
|
||||
if hasattr(node, "get_step"):
|
||||
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
|
||||
try:
|
||||
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "value"):
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "get_value"):
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
|
||||
except:
|
||||
pass
|
||||
elif hasattr(node, "get_position"):
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
|
||||
try:
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "min_value"):
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
|
||||
except:
|
||||
pass
|
||||
elif hasattr(node, "get_range_min"):
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
|
||||
try:
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "max_value"):
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
|
||||
except:
|
||||
pass
|
||||
elif hasattr(node, "get_range_max"):
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
|
||||
try:
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
|
||||
except:
|
||||
pass
|
||||
# }}} Value #
|
||||
|
||||
attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))
|
||||
|
||||
@@ -6,3 +6,4 @@ requests
|
||||
flask
|
||||
numpy
|
||||
lxml
|
||||
pygame
|
||||
|
||||
@@ -108,7 +108,7 @@
|
||||
{
|
||||
"type": "rule",
|
||||
"rules": {
|
||||
"expect": "project"
|
||||
"expected": "project"
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"parameters": {
|
||||
"files": [
|
||||
{
|
||||
"url": "https://docs.google.com/spreadsheets/d/13YL-KC__pav2qp3sFDs1BT2wZnpWGp7s/export?format=xlsx",
|
||||
"url": "https://drive.google.com/uc?export=download&id=1B5GmhdVD07UeYj9Ox20DHsA_gaxEFQ6Z",
|
||||
"path": "/home/user/Desktop/stock.xlsx"
|
||||
}
|
||||
]
|
||||
@@ -36,7 +36,7 @@
|
||||
},
|
||||
"expected": {
|
||||
"type": "cloud_file",
|
||||
"path": "https://drive.google.com/uc?export=download&id=1oPPW_dozWGII5MRmdXdKKoEK5iBkd_8Q",
|
||||
"path": "https://drive.google.com/uc?export=download&id=1wzlUL1gktA0d_j9W3WSSAAUcuKr5gw-n",
|
||||
"dest": "result_gold.txt"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,15 +304,12 @@
|
||||
"os": [
|
||||
"94d95f96-9699-4208-98ba-3c3119edf9c2",
|
||||
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
|
||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
|
||||
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
|
||||
"a462a795-fdc7-4b23-b689-e8b6df786b78",
|
||||
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
|
||||
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
|
||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
||||
"e0df059f-28a6-4169-924f-b9623e7184cc",
|
||||
"ddc75b62-7311-4af8-bfb3-859558542b36",
|
||||
"b6781586-6346-41cd-935a-a6b1487918fc",
|
||||
"b3d4a89c-53f2-4d6b-8b6a-541fb5d205fa",
|
||||
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
|
||||
@@ -322,8 +319,6 @@
|
||||
"23393935-50c7-4a86-aeea-2b78fd089c5c",
|
||||
"5812b315-e7bd-4265-b51f-863c02174c28",
|
||||
"c288e301-e626-4b98-a1ab-159dcb162af5",
|
||||
"cc9d4f34-1ca0-4a1b-8ff2-09302696acb9",
|
||||
"c56de254-a3ec-414e-81a6-83d2ce8c41fa",
|
||||
"4783cc41-c03c-4e1b-89b4-50658f642bd5",
|
||||
"5c1075ca-bb34-46a3-a7a0-029bd7463e79",
|
||||
"5ced85fc-fa1a-4217-95fd-0fb530545ce2",
|
||||
@@ -376,7 +371,6 @@
|
||||
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
||||
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
||||
"ae506c68-352c-4094-9caa-ee9d42052317",
|
||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
{
|
||||
"chrome": [
|
||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
||||
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
||||
"a96b564e-dbe9-42c3-9ccf-b4498073438a"
|
||||
],
|
||||
"gimp": [
|
||||
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
||||
@@ -9,7 +11,8 @@
|
||||
],
|
||||
"libreoffice_calc": [
|
||||
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
||||
"42e0a640-4f19-4b28-973d-729602b5a4a7"
|
||||
"42e0a640-4f19-4b28-973d-729602b5a4a7",
|
||||
"abed40dc-063f-4598-8ba5-9fe749c0615d"
|
||||
],
|
||||
"libreoffice_impress": [
|
||||
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
||||
@@ -20,17 +23,14 @@
|
||||
"0a0faba3-5580-44df-965d-f562a99b291c"
|
||||
],
|
||||
"multi_apps": [
|
||||
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
|
||||
"5990457f-2adb-467b-a4af-5c857c92d762",
|
||||
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
||||
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
||||
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
||||
"510f64c8-9bcc-4be1-8d30-638705850618",
|
||||
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
||||
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
|
||||
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
||||
"f7dfbef3-7697-431c-883a-db8583a4e4f9",
|
||||
"6d72aad6-187a-4392-a4c4-ed87269c51cf",
|
||||
"f918266a-b3e0-4914-865d-4faa564f1aef",
|
||||
"da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
|
||||
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
|
||||
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
|
||||
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
|
||||
@@ -38,53 +38,12 @@
|
||||
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
|
||||
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
|
||||
"8e116af7-7db7-4e35-a68b-b0939c066c78",
|
||||
"185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
|
||||
"2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
|
||||
"3a93cae4-ad3e-403e-8c12-65303b271818",
|
||||
"1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
|
||||
"26150609-0da3-4a7d-8868-0faf9c5f01bb",
|
||||
"7e287123-70ca-47b9-8521-47db09b69b14",
|
||||
"e2392362-125e-4f76-a2ee-524b183a3412",
|
||||
"26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
|
||||
"a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
|
||||
"36037439-2044-4b50-b9d1-875b5a332143",
|
||||
"716a6079-22da-47f1-ba73-c9d58f986a38",
|
||||
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
|
||||
"6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
|
||||
"da922383-bfa4-4cd3-bbad-6bebab3d7742",
|
||||
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
|
||||
"81c425f5-78f3-4771-afd6-3d2973825947",
|
||||
"227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
|
||||
"20236825-b5df-46e7-89bf-62e1d640a897",
|
||||
"02ce9a50-7af2-47ed-8596-af0c230501f8",
|
||||
"4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
|
||||
"09a37c51-e625-49f4-a514-20a773797a8a",
|
||||
"3e3fc409-bff3-4905-bf16-c968eee3f807",
|
||||
"415ef462-bed3-493a-ac36-ca8c6d23bf1b",
|
||||
"9f3bb592-209d-43bc-bb47-d77d9df56504",
|
||||
"dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
|
||||
"3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
|
||||
"f8369178-fafe-40c2-adc4-b9b08a125456",
|
||||
"778efd0a-153f-4842-9214-f05fc176b877",
|
||||
"47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
|
||||
"c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
|
||||
"48c46dc7-fe04-4505-ade7-723cba1aa6f6",
|
||||
"42d25c08-fb87-4927-8b65-93631280a26f",
|
||||
"3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
|
||||
"d68204bf-11c1-4b13-b48b-d303c73d4bf6",
|
||||
"91190194-f406-4cd6-b3f9-c43fac942b22",
|
||||
"7f35355e-02a6-45b5-b140-f0be698bcf85",
|
||||
"98e8e339-5f91-4ed2-b2b2-12647cb134f4",
|
||||
"df67aebb-fb3a-44fd-b75b-51b6012df509",
|
||||
"5df7b33a-9f77-4101-823e-02f863e1c1ae",
|
||||
"22a4636f-8179-4357-8e87-d1743ece1f81",
|
||||
"236833a3-5704-47fc-888c-4f298f09f799"
|
||||
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d"
|
||||
],
|
||||
"os": [
|
||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
||||
"5812b315-e7bd-4265-b51f-863c02174c28",
|
||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
|
||||
"5812b315-e7bd-4265-b51f-863c02174c28"
|
||||
],
|
||||
"thunderbird": [
|
||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||
@@ -96,6 +55,7 @@
|
||||
],
|
||||
"vs_code": [
|
||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb"
|
||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
||||
"276cc624-87ea-4f08-ab93-f770e3790175"
|
||||
]
|
||||
}
|
||||
@@ -80,9 +80,11 @@ def filter_nodes(root: ET, platform="ubuntu", check_image=False):
|
||||
return filtered_nodes
|
||||
|
||||
|
||||
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
|
||||
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sampling_ratio=1.0):
|
||||
# Load the screenshot image
|
||||
image = Image.open(image_file_path)
|
||||
if float(down_sampling_ratio) != 1.0:
|
||||
image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
|
||||
draw = ImageDraw.Draw(image)
|
||||
marks = []
|
||||
drew_nodes = []
|
||||
@@ -108,6 +110,15 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
|
||||
coords = tuple(map(int, coords_str.strip('()').split(', ')))
|
||||
size = tuple(map(int, size_str.strip('()').split(', ')))
|
||||
|
||||
import copy
|
||||
original_coords = copy.deepcopy(coords)
|
||||
original_size = copy.deepcopy(size)
|
||||
|
||||
if float(down_sampling_ratio) != 1.0:
|
||||
# Downsample the coordinates and size
|
||||
coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
|
||||
size = tuple(int(s * down_sampling_ratio) for s in size)
|
||||
|
||||
# Check for negative sizes
|
||||
if size[0] <= 0 or size[1] <= 0:
|
||||
raise ValueError(f"Size must be positive, got: {size}")
|
||||
@@ -138,7 +149,7 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
|
||||
draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
|
||||
|
||||
# each mark is an x, y, w, h tuple
|
||||
marks.append([coords[0], coords[1], size[0], size[1]])
|
||||
marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
|
||||
drew_nodes.append(_node)
|
||||
|
||||
if _node.text:
|
||||
|
||||
@@ -6,17 +6,16 @@ import re
|
||||
import time
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
import numpy as np
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, List
|
||||
|
||||
import backoff
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import openai
|
||||
import requests
|
||||
import cv2
|
||||
import tiktoken
|
||||
from PIL import Image
|
||||
from google.api_core.exceptions import InvalidArgument
|
||||
|
||||
@@ -28,14 +27,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
def downsample_image(img: Union[str, np.ndarray], ratio: Tuple[float, float]):
|
||||
fx, fy = ratio
|
||||
if isinstance(img, str):
|
||||
img = cv2.imread(img)
|
||||
|
||||
resized = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
||||
return resized
|
||||
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image_path):
|
||||
@@ -43,49 +34,49 @@ def encode_image(image_path):
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree):
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
||||
|
||||
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
|
||||
# Linearize the accessibility tree nodes into a table format
|
||||
|
||||
for node in filtered_nodes:
|
||||
#linearized_accessibility_tree += node.tag + "\t"
|
||||
#linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
# linearized_accessibility_tree += node.tag + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
if node.text:
|
||||
text = ( node.text if '"' not in node.text\
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
text = (node.text if '"' not in node.text \
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
||||
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
||||
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
||||
text = (text if '"' not in text\
|
||||
else '"{:}"'.format(text.replace('"', '""'))
|
||||
)
|
||||
text = (text if '"' not in text \
|
||||
else '"{:}"'.format(text.replace('"', '""'))
|
||||
)
|
||||
else:
|
||||
text = '""'
|
||||
#linearized_accessibility_tree += node.attrib.get(
|
||||
#, "") + "\t"
|
||||
#linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||
# linearized_accessibility_tree += node.attrib.get(
|
||||
# , "") + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||
linearized_accessibility_tree.append(
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""), text
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||
)
|
||||
)
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""), text
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(linearized_accessibility_tree)
|
||||
|
||||
|
||||
def tag_screenshot(screenshot, accessibility_tree):
|
||||
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
|
||||
# Creat a tmp file to store the screenshot in random name
|
||||
uuid_str = str(uuid.uuid4())
|
||||
os.makedirs("tmp/images", exist_ok=True)
|
||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
|
||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
|
||||
# Make tag screenshot
|
||||
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||
|
||||
@@ -181,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
|
||||
return actions
|
||||
|
||||
|
||||
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
|
||||
enc = tiktoken.encoding_for_model("gpt-4")
|
||||
tokens = enc.encode(linearized_accessibility_tree)
|
||||
if len(tokens) > max_tokens:
|
||||
linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
|
||||
linearized_accessibility_tree += "[...]\n"
|
||||
return linearized_accessibility_tree
|
||||
|
||||
class PromptAgent:
|
||||
def __init__(
|
||||
self,
|
||||
platform="ubuntu",
|
||||
model="gpt-4-vision-preview",
|
||||
max_tokens=1500,
|
||||
top_p=0.9,
|
||||
@@ -191,8 +191,10 @@ class PromptAgent:
|
||||
action_space="computer_13",
|
||||
observation_type="screenshot_a11y_tree",
|
||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
max_trajectory_length=3
|
||||
max_trajectory_length=3,
|
||||
a11y_tree_max_tokens=10000
|
||||
):
|
||||
self.platform = platform
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
@@ -200,6 +202,7 @@ class PromptAgent:
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
@@ -261,9 +264,14 @@ class PromptAgent:
|
||||
, "The number of observations and actions should be the same."
|
||||
|
||||
if len(self.observations) > self.max_trajectory_length:
|
||||
_observations = self.observations[-self.max_trajectory_length:]
|
||||
_actions = self.actions[-self.max_trajectory_length:]
|
||||
_thoughts = self.thoughts[-self.max_trajectory_length:]
|
||||
if self.max_trajectory_length == 0:
|
||||
_observations = []
|
||||
_actions = []
|
||||
_thoughts = []
|
||||
else:
|
||||
_observations = self.observations[-self.max_trajectory_length:]
|
||||
_actions = self.actions[-self.max_trajectory_length:]
|
||||
_thoughts = self.thoughts[-self.max_trajectory_length:]
|
||||
else:
|
||||
_observations = self.observations
|
||||
_actions = self.actions
|
||||
@@ -360,9 +368,14 @@ class PromptAgent:
|
||||
# {{{1
|
||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||
base64_image = encode_image(obs["screenshot"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
@@ -394,9 +407,14 @@ class PromptAgent:
|
||||
]
|
||||
})
|
||||
elif self.observation_type == "a11y_tree":
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform)
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": None,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
@@ -414,10 +432,15 @@ class PromptAgent:
|
||||
})
|
||||
elif self.observation_type == "som":
|
||||
# Add som to the screenshot
|
||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
|
||||
"accessibility_tree"], self.platform)
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
@@ -446,7 +469,7 @@ class PromptAgent:
|
||||
# with open("messages.json", "w") as f:
|
||||
# f.write(json.dumps(messages, indent=4))
|
||||
|
||||
#logger.info("PROMPT: %s", messages)
|
||||
# logger.info("PROMPT: %s", messages)
|
||||
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
@@ -567,8 +590,6 @@ class PromptAgent:
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -581,7 +602,8 @@ class PromptAgent:
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
# response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
|
||||
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
|
||||
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
|
||||
json=payload)
|
||||
if response.status_code == 200:
|
||||
result = response.json()['choices'][0]['message']['content']
|
||||
break
|
||||
@@ -592,7 +614,7 @@ class PromptAgent:
|
||||
else:
|
||||
print("Exceeded maximum attempts to call LLM.")
|
||||
result = ""
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -616,14 +638,13 @@ class PromptAgent:
|
||||
|
||||
mistral_messages.append(mistral_message)
|
||||
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
|
||||
base_url='https://api.together.xyz',
|
||||
)
|
||||
logger.info("Generating content with Mistral model: %s", self.model)
|
||||
|
||||
|
||||
flag = 0
|
||||
while True:
|
||||
try:
|
||||
@@ -752,26 +773,30 @@ class PromptAgent:
|
||||
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
||||
genai.configure(api_key=api_key)
|
||||
logger.info("Generating content with Gemini model: %s", self.model)
|
||||
response = genai.GenerativeModel(self.model).generate_content(
|
||||
gemini_messages,
|
||||
generation_config={
|
||||
"candidate_count": 1,
|
||||
"max_output_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature
|
||||
},
|
||||
safety_settings={
|
||||
"harassment": "block_none",
|
||||
"hate": "block_none",
|
||||
"sex": "block_none",
|
||||
"danger": "block_none"
|
||||
}
|
||||
)
|
||||
|
||||
request_options = {"timeout": 120}
|
||||
gemini_model = genai.GenerativeModel(self.model)
|
||||
try:
|
||||
response = gemini_model.generate_content(
|
||||
gemini_messages,
|
||||
generation_config={
|
||||
"candidate_count": 1,
|
||||
"max_output_tokens": max_tokens,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature
|
||||
},
|
||||
safety_settings={
|
||||
"harassment": "block_none",
|
||||
"hate": "block_none",
|
||||
"sex": "block_none",
|
||||
"danger": "block_none"
|
||||
},
|
||||
request_options=request_options
|
||||
)
|
||||
return response.text
|
||||
except Exception as e:
|
||||
logger.error("Meet exception when calling Gemini API, " + str(e))
|
||||
logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
|
||||
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
|
||||
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
|
||||
return ""
|
||||
elif self.model.startswith("qwen"):
|
||||
messages = payload["messages"]
|
||||
|
||||
Reference in New Issue
Block a user