Merge branch 'main' of https://github.com/xlang-ai/DesktopEnv
This commit is contained in:
@@ -146,7 +146,13 @@ class DesktopEnv(gym.Env):
|
|||||||
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
|
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
|
||||||
|
|
||||||
# Get the screenshot and save to the image_path
|
# Get the screenshot and save to the image_path
|
||||||
screenshot = self.controller.get_screenshot()
|
max_retries = 20
|
||||||
|
for _ in range(max_retries):
|
||||||
|
screenshot = self.controller.get_screenshot()
|
||||||
|
if screenshot is not None:
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
with open(image_path, "wb") as f:
|
with open(image_path, "wb") as f:
|
||||||
f.write(screenshot)
|
f.write(screenshot)
|
||||||
|
|
||||||
|
|||||||
@@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str
|
|||||||
|
|
||||||
# Value {{{ #
|
# Value {{{ #
|
||||||
if hasattr(node, "get_step"):
|
if hasattr(node, "get_step"):
|
||||||
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
|
try:
|
||||||
|
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
if hasattr(node, "value"):
|
if hasattr(node, "value"):
|
||||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
|
try:
|
||||||
|
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
if hasattr(node, "get_value"):
|
if hasattr(node, "get_value"):
|
||||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
|
try:
|
||||||
|
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
elif hasattr(node, "get_position"):
|
elif hasattr(node, "get_position"):
|
||||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
|
try:
|
||||||
|
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
if hasattr(node, "min_value"):
|
if hasattr(node, "min_value"):
|
||||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
|
try:
|
||||||
|
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
elif hasattr(node, "get_range_min"):
|
elif hasattr(node, "get_range_min"):
|
||||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
|
try:
|
||||||
|
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
if hasattr(node, "max_value"):
|
if hasattr(node, "max_value"):
|
||||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
|
try:
|
||||||
|
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
elif hasattr(node, "get_range_max"):
|
elif hasattr(node, "get_range_max"):
|
||||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
|
try:
|
||||||
|
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
|
||||||
|
except:
|
||||||
|
pass
|
||||||
# }}} Value #
|
# }}} Value #
|
||||||
|
|
||||||
attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))
|
attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))
|
||||||
|
|||||||
@@ -6,3 +6,4 @@ requests
|
|||||||
flask
|
flask
|
||||||
numpy
|
numpy
|
||||||
lxml
|
lxml
|
||||||
|
pygame
|
||||||
|
|||||||
@@ -108,7 +108,7 @@
|
|||||||
{
|
{
|
||||||
"type": "rule",
|
"type": "rule",
|
||||||
"rules": {
|
"rules": {
|
||||||
"expect": "project"
|
"expected": "project"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
"url": "https://docs.google.com/spreadsheets/d/13YL-KC__pav2qp3sFDs1BT2wZnpWGp7s/export?format=xlsx",
|
"url": "https://drive.google.com/uc?export=download&id=1B5GmhdVD07UeYj9Ox20DHsA_gaxEFQ6Z",
|
||||||
"path": "/home/user/Desktop/stock.xlsx"
|
"path": "/home/user/Desktop/stock.xlsx"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
@@ -36,7 +36,7 @@
|
|||||||
},
|
},
|
||||||
"expected": {
|
"expected": {
|
||||||
"type": "cloud_file",
|
"type": "cloud_file",
|
||||||
"path": "https://drive.google.com/uc?export=download&id=1oPPW_dozWGII5MRmdXdKKoEK5iBkd_8Q",
|
"path": "https://drive.google.com/uc?export=download&id=1wzlUL1gktA0d_j9W3WSSAAUcuKr5gw-n",
|
||||||
"dest": "result_gold.txt"
|
"dest": "result_gold.txt"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -304,15 +304,12 @@
|
|||||||
"os": [
|
"os": [
|
||||||
"94d95f96-9699-4208-98ba-3c3119edf9c2",
|
"94d95f96-9699-4208-98ba-3c3119edf9c2",
|
||||||
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
|
"bedcedc4-4d72-425e-ad62-21960b11fe0d",
|
||||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
|
||||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82",
|
|
||||||
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
|
"ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3",
|
||||||
"a462a795-fdc7-4b23-b689-e8b6df786b78",
|
"a462a795-fdc7-4b23-b689-e8b6df786b78",
|
||||||
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
|
"f9be0997-4b7c-45c5-b05c-4612b44a6118",
|
||||||
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
|
"28cc3b7e-b194-4bc9-8353-d04c0f4d56d2",
|
||||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
||||||
"e0df059f-28a6-4169-924f-b9623e7184cc",
|
"e0df059f-28a6-4169-924f-b9623e7184cc",
|
||||||
"ddc75b62-7311-4af8-bfb3-859558542b36",
|
|
||||||
"b6781586-6346-41cd-935a-a6b1487918fc",
|
"b6781586-6346-41cd-935a-a6b1487918fc",
|
||||||
"b3d4a89c-53f2-4d6b-8b6a-541fb5d205fa",
|
"b3d4a89c-53f2-4d6b-8b6a-541fb5d205fa",
|
||||||
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
|
"3ce045a0-877b-42aa-8d2c-b4a863336ab8",
|
||||||
@@ -322,8 +319,6 @@
|
|||||||
"23393935-50c7-4a86-aeea-2b78fd089c5c",
|
"23393935-50c7-4a86-aeea-2b78fd089c5c",
|
||||||
"5812b315-e7bd-4265-b51f-863c02174c28",
|
"5812b315-e7bd-4265-b51f-863c02174c28",
|
||||||
"c288e301-e626-4b98-a1ab-159dcb162af5",
|
"c288e301-e626-4b98-a1ab-159dcb162af5",
|
||||||
"cc9d4f34-1ca0-4a1b-8ff2-09302696acb9",
|
|
||||||
"c56de254-a3ec-414e-81a6-83d2ce8c41fa",
|
|
||||||
"4783cc41-c03c-4e1b-89b4-50658f642bd5",
|
"4783cc41-c03c-4e1b-89b4-50658f642bd5",
|
||||||
"5c1075ca-bb34-46a3-a7a0-029bd7463e79",
|
"5c1075ca-bb34-46a3-a7a0-029bd7463e79",
|
||||||
"5ced85fc-fa1a-4217-95fd-0fb530545ce2",
|
"5ced85fc-fa1a-4217-95fd-0fb530545ce2",
|
||||||
@@ -376,7 +371,6 @@
|
|||||||
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
"4e60007a-f5be-4bfc-9723-c39affa0a6d3",
|
||||||
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
"e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2",
|
||||||
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
"9439a27b-18ae-42d8-9778-5f68f891805e",
|
||||||
"ae506c68-352c-4094-9caa-ee9d42052317",
|
|
||||||
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
"ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae",
|
||||||
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
"930fdb3b-11a8-46fe-9bac-577332e2640e",
|
||||||
"276cc624-87ea-4f08-ab93-f770e3790175",
|
"276cc624-87ea-4f08-ab93-f770e3790175",
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
{
|
{
|
||||||
"chrome": [
|
"chrome": [
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||||
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3"
|
"7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3",
|
||||||
|
"35253b65-1c19-4304-8aa4-6884b8218fc0",
|
||||||
|
"a96b564e-dbe9-42c3-9ccf-b4498073438a"
|
||||||
],
|
],
|
||||||
"gimp": [
|
"gimp": [
|
||||||
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
"7a4deb26-d57d-4ea9-9a73-630f66a7b568",
|
||||||
@@ -9,7 +11,8 @@
|
|||||||
],
|
],
|
||||||
"libreoffice_calc": [
|
"libreoffice_calc": [
|
||||||
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
"357ef137-7eeb-4c80-a3bb-0951f26a8aff",
|
||||||
"42e0a640-4f19-4b28-973d-729602b5a4a7"
|
"42e0a640-4f19-4b28-973d-729602b5a4a7",
|
||||||
|
"abed40dc-063f-4598-8ba5-9fe749c0615d"
|
||||||
],
|
],
|
||||||
"libreoffice_impress": [
|
"libreoffice_impress": [
|
||||||
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
"5d901039-a89c-4bfb-967b-bf66f4df075e",
|
||||||
@@ -20,17 +23,14 @@
|
|||||||
"0a0faba3-5580-44df-965d-f562a99b291c"
|
"0a0faba3-5580-44df-965d-f562a99b291c"
|
||||||
],
|
],
|
||||||
"multi_apps": [
|
"multi_apps": [
|
||||||
|
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
|
||||||
|
"5990457f-2adb-467b-a4af-5c857c92d762",
|
||||||
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
"2b9493d7-49b8-493a-a71b-56cd1f4d6908",
|
||||||
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
"46407397-a7d5-4c6b-92c6-dbe038b1457b",
|
||||||
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
"4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc",
|
||||||
"510f64c8-9bcc-4be1-8d30-638705850618",
|
"510f64c8-9bcc-4be1-8d30-638705850618",
|
||||||
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
"897e3b53-5d4d-444b-85cb-2cdc8a97d903",
|
||||||
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
|
"c867c42d-a52d-4a24-8ae3-f75d256b5618",
|
||||||
"e135df7c-7687-4ac0-a5f0-76b74438b53e",
|
|
||||||
"f7dfbef3-7697-431c-883a-db8583a4e4f9",
|
|
||||||
"6d72aad6-187a-4392-a4c4-ed87269c51cf",
|
|
||||||
"f918266a-b3e0-4914-865d-4faa564f1aef",
|
|
||||||
"da52d699-e8d2-4dc5-9191-a2199e0b6a9b",
|
|
||||||
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
|
"74d5859f-ed66-4d3e-aa0e-93d7a592ce41",
|
||||||
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
|
"b5062e3e-641c-4e3a-907b-ac864d2e7652",
|
||||||
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
|
"48d05431-6cd5-4e76-82eb-12b60d823f7d",
|
||||||
@@ -38,53 +38,12 @@
|
|||||||
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
|
"d1acdb87-bb67-4f30-84aa-990e56a09c92",
|
||||||
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
|
"deec51c9-3b1e-4b9e-993c-4776f20e8bb2",
|
||||||
"8e116af7-7db7-4e35-a68b-b0939c066c78",
|
"8e116af7-7db7-4e35-a68b-b0939c066c78",
|
||||||
"185f29bd-5da0-40a6-b69c-ba7f4e0324ef",
|
|
||||||
"2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e",
|
|
||||||
"3a93cae4-ad3e-403e-8c12-65303b271818",
|
|
||||||
"1f18aa87-af6f-41ef-9853-cdb8f32ebdea",
|
|
||||||
"26150609-0da3-4a7d-8868-0faf9c5f01bb",
|
|
||||||
"7e287123-70ca-47b9-8521-47db09b69b14",
|
|
||||||
"e2392362-125e-4f76-a2ee-524b183a3412",
|
|
||||||
"26660ad1-6ebb-4f59-8cba-a8432dfe8d38",
|
|
||||||
"a82b78bb-7fde-4cb3-94a4-035baf10bcf0",
|
|
||||||
"36037439-2044-4b50-b9d1-875b5a332143",
|
|
||||||
"716a6079-22da-47f1-ba73-c9d58f986a38",
|
"716a6079-22da-47f1-ba73-c9d58f986a38",
|
||||||
"a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a",
|
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d"
|
||||||
"6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a",
|
|
||||||
"da922383-bfa4-4cd3-bbad-6bebab3d7742",
|
|
||||||
"2373b66a-092d-44cb-bfd7-82e86e7a3b4d",
|
|
||||||
"81c425f5-78f3-4771-afd6-3d2973825947",
|
|
||||||
"227d2f97-562b-4ccb-ae47-a5ec9e142fbb",
|
|
||||||
"20236825-b5df-46e7-89bf-62e1d640a897",
|
|
||||||
"02ce9a50-7af2-47ed-8596-af0c230501f8",
|
|
||||||
"4c26e3f3-3a14-4d86-b44a-d3cedebbb487",
|
|
||||||
"09a37c51-e625-49f4-a514-20a773797a8a",
|
|
||||||
"3e3fc409-bff3-4905-bf16-c968eee3f807",
|
|
||||||
"415ef462-bed3-493a-ac36-ca8c6d23bf1b",
|
|
||||||
"9f3bb592-209d-43bc-bb47-d77d9df56504",
|
|
||||||
"dd60633f-2c72-42ba-8547-6f2c8cb0fdb0",
|
|
||||||
"3f05f3b9-29ba-4b6b-95aa-2204697ffc06",
|
|
||||||
"f8369178-fafe-40c2-adc4-b9b08a125456",
|
|
||||||
"778efd0a-153f-4842-9214-f05fc176b877",
|
|
||||||
"47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5",
|
|
||||||
"c2751594-0cd5-4088-be1b-b5f2f9ec97c4",
|
|
||||||
"48c46dc7-fe04-4505-ade7-723cba1aa6f6",
|
|
||||||
"42d25c08-fb87-4927-8b65-93631280a26f",
|
|
||||||
"3c8f201a-009d-4bbe-8b65-a6f8b35bb57f",
|
|
||||||
"d68204bf-11c1-4b13-b48b-d303c73d4bf6",
|
|
||||||
"91190194-f406-4cd6-b3f9-c43fac942b22",
|
|
||||||
"7f35355e-02a6-45b5-b140-f0be698bcf85",
|
|
||||||
"98e8e339-5f91-4ed2-b2b2-12647cb134f4",
|
|
||||||
"df67aebb-fb3a-44fd-b75b-51b6012df509",
|
|
||||||
"5df7b33a-9f77-4101-823e-02f863e1c1ae",
|
|
||||||
"22a4636f-8179-4357-8e87-d1743ece1f81",
|
|
||||||
"236833a3-5704-47fc-888c-4f298f09f799"
|
|
||||||
],
|
],
|
||||||
"os": [
|
"os": [
|
||||||
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
"5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57",
|
||||||
"5812b315-e7bd-4265-b51f-863c02174c28",
|
"5812b315-e7bd-4265-b51f-863c02174c28"
|
||||||
"43c2d64c-bab5-4dcb-a30c-b888321c319a",
|
|
||||||
"7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82"
|
|
||||||
],
|
],
|
||||||
"thunderbird": [
|
"thunderbird": [
|
||||||
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
"bb5e4c0d-f964-439c-97b6-bdb9747de3f4",
|
||||||
@@ -96,6 +55,7 @@
|
|||||||
],
|
],
|
||||||
"vs_code": [
|
"vs_code": [
|
||||||
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
"0ed39f63-6049-43d4-ba4d-5fa2fe04a951",
|
||||||
"53ad5833-3455-407b-bbc6-45b4c79ab8fb"
|
"53ad5833-3455-407b-bbc6-45b4c79ab8fb",
|
||||||
|
"276cc624-87ea-4f08-ab93-f770e3790175"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -80,9 +80,11 @@ def filter_nodes(root: ET, platform="ubuntu", check_image=False):
|
|||||||
return filtered_nodes
|
return filtered_nodes
|
||||||
|
|
||||||
|
|
||||||
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
|
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sampling_ratio=1.0):
|
||||||
# Load the screenshot image
|
# Load the screenshot image
|
||||||
image = Image.open(image_file_path)
|
image = Image.open(image_file_path)
|
||||||
|
if float(down_sampling_ratio) != 1.0:
|
||||||
|
image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
|
||||||
draw = ImageDraw.Draw(image)
|
draw = ImageDraw.Draw(image)
|
||||||
marks = []
|
marks = []
|
||||||
drew_nodes = []
|
drew_nodes = []
|
||||||
@@ -108,6 +110,15 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
|
|||||||
coords = tuple(map(int, coords_str.strip('()').split(', ')))
|
coords = tuple(map(int, coords_str.strip('()').split(', ')))
|
||||||
size = tuple(map(int, size_str.strip('()').split(', ')))
|
size = tuple(map(int, size_str.strip('()').split(', ')))
|
||||||
|
|
||||||
|
import copy
|
||||||
|
original_coords = copy.deepcopy(coords)
|
||||||
|
original_size = copy.deepcopy(size)
|
||||||
|
|
||||||
|
if float(down_sampling_ratio) != 1.0:
|
||||||
|
# Downsample the coordinates and size
|
||||||
|
coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
|
||||||
|
size = tuple(int(s * down_sampling_ratio) for s in size)
|
||||||
|
|
||||||
# Check for negative sizes
|
# Check for negative sizes
|
||||||
if size[0] <= 0 or size[1] <= 0:
|
if size[0] <= 0 or size[1] <= 0:
|
||||||
raise ValueError(f"Size must be positive, got: {size}")
|
raise ValueError(f"Size must be positive, got: {size}")
|
||||||
@@ -138,7 +149,7 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
|
|||||||
draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
|
draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
|
||||||
|
|
||||||
# each mark is an x, y, w, h tuple
|
# each mark is an x, y, w, h tuple
|
||||||
marks.append([coords[0], coords[1], size[0], size[1]])
|
marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
|
||||||
drew_nodes.append(_node)
|
drew_nodes.append(_node)
|
||||||
|
|
||||||
if _node.text:
|
if _node.text:
|
||||||
|
|||||||
@@ -6,17 +6,16 @@ import re
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
import numpy as np
|
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List
|
||||||
|
|
||||||
import backoff
|
import backoff
|
||||||
import dashscope
|
import dashscope
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
import cv2
|
import tiktoken
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from google.api_core.exceptions import InvalidArgument
|
from google.api_core.exceptions import InvalidArgument
|
||||||
|
|
||||||
@@ -28,14 +27,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
|
|||||||
|
|
||||||
logger = logging.getLogger("desktopenv.agent")
|
logger = logging.getLogger("desktopenv.agent")
|
||||||
|
|
||||||
def downsample_image(img: Union[str, np.ndarray], ratio: Tuple[float, float]):
|
|
||||||
fx, fy = ratio
|
|
||||||
if isinstance(img, str):
|
|
||||||
img = cv2.imread(img)
|
|
||||||
|
|
||||||
resized = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
|
||||||
return resized
|
|
||||||
|
|
||||||
|
|
||||||
# Function to encode the image
|
# Function to encode the image
|
||||||
def encode_image(image_path):
|
def encode_image(image_path):
|
||||||
@@ -43,49 +34,49 @@ def encode_image(image_path):
|
|||||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
def linearize_accessibility_tree(accessibility_tree):
|
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
||||||
|
|
||||||
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
|
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
|
||||||
# Linearize the accessibility tree nodes into a table format
|
# Linearize the accessibility tree nodes into a table format
|
||||||
|
|
||||||
for node in filtered_nodes:
|
for node in filtered_nodes:
|
||||||
#linearized_accessibility_tree += node.tag + "\t"
|
# linearized_accessibility_tree += node.tag + "\t"
|
||||||
#linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
# linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||||
if node.text:
|
if node.text:
|
||||||
text = ( node.text if '"' not in node.text\
|
text = (node.text if '"' not in node.text \
|
||||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||||
)
|
)
|
||||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
||||||
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
||||||
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
||||||
text = (text if '"' not in text\
|
text = (text if '"' not in text \
|
||||||
else '"{:}"'.format(text.replace('"', '""'))
|
else '"{:}"'.format(text.replace('"', '""'))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
text = '""'
|
text = '""'
|
||||||
#linearized_accessibility_tree += node.attrib.get(
|
# linearized_accessibility_tree += node.attrib.get(
|
||||||
#, "") + "\t"
|
# , "") + "\t"
|
||||||
#linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
# linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||||
linearized_accessibility_tree.append(
|
linearized_accessibility_tree.append(
|
||||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||||
node.tag, node.get("name", ""), text
|
node.tag, node.get("name", ""), text
|
||||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return "\n".join(linearized_accessibility_tree)
|
return "\n".join(linearized_accessibility_tree)
|
||||||
|
|
||||||
|
|
||||||
def tag_screenshot(screenshot, accessibility_tree):
|
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
|
||||||
# Creat a tmp file to store the screenshot in random name
|
# Creat a tmp file to store the screenshot in random name
|
||||||
uuid_str = str(uuid.uuid4())
|
uuid_str = str(uuid.uuid4())
|
||||||
os.makedirs("tmp/images", exist_ok=True)
|
os.makedirs("tmp/images", exist_ok=True)
|
||||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||||
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
|
nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
|
||||||
# Make tag screenshot
|
# Make tag screenshot
|
||||||
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||||
|
|
||||||
@@ -181,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
|
|||||||
return actions
|
return actions
|
||||||
|
|
||||||
|
|
||||||
|
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
|
||||||
|
enc = tiktoken.encoding_for_model("gpt-4")
|
||||||
|
tokens = enc.encode(linearized_accessibility_tree)
|
||||||
|
if len(tokens) > max_tokens:
|
||||||
|
linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
|
||||||
|
linearized_accessibility_tree += "[...]\n"
|
||||||
|
return linearized_accessibility_tree
|
||||||
|
|
||||||
class PromptAgent:
|
class PromptAgent:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
platform="ubuntu",
|
||||||
model="gpt-4-vision-preview",
|
model="gpt-4-vision-preview",
|
||||||
max_tokens=1500,
|
max_tokens=1500,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
@@ -191,8 +191,10 @@ class PromptAgent:
|
|||||||
action_space="computer_13",
|
action_space="computer_13",
|
||||||
observation_type="screenshot_a11y_tree",
|
observation_type="screenshot_a11y_tree",
|
||||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
max_trajectory_length=3
|
max_trajectory_length=3,
|
||||||
|
a11y_tree_max_tokens=10000
|
||||||
):
|
):
|
||||||
|
self.platform = platform
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
@@ -200,6 +202,7 @@ class PromptAgent:
|
|||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
self.observation_type = observation_type
|
self.observation_type = observation_type
|
||||||
self.max_trajectory_length = max_trajectory_length
|
self.max_trajectory_length = max_trajectory_length
|
||||||
|
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||||
|
|
||||||
self.thoughts = []
|
self.thoughts = []
|
||||||
self.actions = []
|
self.actions = []
|
||||||
@@ -261,9 +264,14 @@ class PromptAgent:
|
|||||||
, "The number of observations and actions should be the same."
|
, "The number of observations and actions should be the same."
|
||||||
|
|
||||||
if len(self.observations) > self.max_trajectory_length:
|
if len(self.observations) > self.max_trajectory_length:
|
||||||
_observations = self.observations[-self.max_trajectory_length:]
|
if self.max_trajectory_length == 0:
|
||||||
_actions = self.actions[-self.max_trajectory_length:]
|
_observations = []
|
||||||
_thoughts = self.thoughts[-self.max_trajectory_length:]
|
_actions = []
|
||||||
|
_thoughts = []
|
||||||
|
else:
|
||||||
|
_observations = self.observations[-self.max_trajectory_length:]
|
||||||
|
_actions = self.actions[-self.max_trajectory_length:]
|
||||||
|
_thoughts = self.thoughts[-self.max_trajectory_length:]
|
||||||
else:
|
else:
|
||||||
_observations = self.observations
|
_observations = self.observations
|
||||||
_actions = self.actions
|
_actions = self.actions
|
||||||
@@ -360,9 +368,14 @@ class PromptAgent:
|
|||||||
# {{{1
|
# {{{1
|
||||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||||
base64_image = encode_image(obs["screenshot"])
|
base64_image = encode_image(obs["screenshot"])
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||||
|
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
|
||||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||||
|
|
||||||
|
if linearized_accessibility_tree:
|
||||||
|
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||||
|
self.a11y_tree_max_tokens)
|
||||||
|
|
||||||
if self.observation_type == "screenshot_a11y_tree":
|
if self.observation_type == "screenshot_a11y_tree":
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image,
|
||||||
@@ -394,9 +407,14 @@ class PromptAgent:
|
|||||||
]
|
]
|
||||||
})
|
})
|
||||||
elif self.observation_type == "a11y_tree":
|
elif self.observation_type == "a11y_tree":
|
||||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||||
|
platform=self.platform)
|
||||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||||
|
|
||||||
|
if linearized_accessibility_tree:
|
||||||
|
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||||
|
self.a11y_tree_max_tokens)
|
||||||
|
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": None,
|
"screenshot": None,
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
"accessibility_tree": linearized_accessibility_tree
|
||||||
@@ -414,10 +432,15 @@ class PromptAgent:
|
|||||||
})
|
})
|
||||||
elif self.observation_type == "som":
|
elif self.observation_type == "som":
|
||||||
# Add som to the screenshot
|
# Add som to the screenshot
|
||||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
|
||||||
|
"accessibility_tree"], self.platform)
|
||||||
base64_image = encode_image(tagged_screenshot)
|
base64_image = encode_image(tagged_screenshot)
|
||||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||||
|
|
||||||
|
if linearized_accessibility_tree:
|
||||||
|
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||||
|
self.a11y_tree_max_tokens)
|
||||||
|
|
||||||
self.observations.append({
|
self.observations.append({
|
||||||
"screenshot": base64_image,
|
"screenshot": base64_image,
|
||||||
"accessibility_tree": linearized_accessibility_tree
|
"accessibility_tree": linearized_accessibility_tree
|
||||||
@@ -446,7 +469,7 @@ class PromptAgent:
|
|||||||
# with open("messages.json", "w") as f:
|
# with open("messages.json", "w") as f:
|
||||||
# f.write(json.dumps(messages, indent=4))
|
# f.write(json.dumps(messages, indent=4))
|
||||||
|
|
||||||
#logger.info("PROMPT: %s", messages)
|
# logger.info("PROMPT: %s", messages)
|
||||||
|
|
||||||
response = self.call_llm({
|
response = self.call_llm({
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
@@ -567,8 +590,6 @@ class PromptAgent:
|
|||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
@@ -581,7 +602,8 @@ class PromptAgent:
|
|||||||
attempt = 0
|
attempt = 0
|
||||||
while attempt < max_attempts:
|
while attempt < max_attempts:
|
||||||
# response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
|
# response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
|
||||||
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
|
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
|
||||||
|
json=payload)
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
result = response.json()['choices'][0]['message']['content']
|
result = response.json()['choices'][0]['message']['content']
|
||||||
break
|
break
|
||||||
@@ -616,7 +638,6 @@ class PromptAgent:
|
|||||||
|
|
||||||
mistral_messages.append(mistral_message)
|
mistral_messages.append(mistral_message)
|
||||||
|
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
|
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
|
||||||
@@ -752,26 +773,30 @@ class PromptAgent:
|
|||||||
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
logger.info("Generating content with Gemini model: %s", self.model)
|
logger.info("Generating content with Gemini model: %s", self.model)
|
||||||
response = genai.GenerativeModel(self.model).generate_content(
|
request_options = {"timeout": 120}
|
||||||
gemini_messages,
|
gemini_model = genai.GenerativeModel(self.model)
|
||||||
generation_config={
|
|
||||||
"candidate_count": 1,
|
|
||||||
"max_output_tokens": max_tokens,
|
|
||||||
"top_p": top_p,
|
|
||||||
"temperature": temperature
|
|
||||||
},
|
|
||||||
safety_settings={
|
|
||||||
"harassment": "block_none",
|
|
||||||
"hate": "block_none",
|
|
||||||
"sex": "block_none",
|
|
||||||
"danger": "block_none"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
response = gemini_model.generate_content(
|
||||||
|
gemini_messages,
|
||||||
|
generation_config={
|
||||||
|
"candidate_count": 1,
|
||||||
|
"max_output_tokens": max_tokens,
|
||||||
|
"top_p": top_p,
|
||||||
|
"temperature": temperature
|
||||||
|
},
|
||||||
|
safety_settings={
|
||||||
|
"harassment": "block_none",
|
||||||
|
"hate": "block_none",
|
||||||
|
"sex": "block_none",
|
||||||
|
"danger": "block_none"
|
||||||
|
},
|
||||||
|
request_options=request_options
|
||||||
|
)
|
||||||
return response.text
|
return response.text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Meet exception when calling Gemini API, " + str(e))
|
logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
|
||||||
|
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
|
||||||
|
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
|
||||||
return ""
|
return ""
|
||||||
elif self.model.startswith("qwen"):
|
elif self.model.startswith("qwen"):
|
||||||
messages = payload["messages"]
|
messages = payload["messages"]
|
||||||
|
|||||||
Reference in New Issue
Block a user