Refactor baselines code implementations

This commit is contained in:
Timothyxxx
2024-01-20 18:55:21 +08:00
parent 09f3e776ae
commit f88331416c
7 changed files with 204 additions and 65 deletions

View File

@@ -60,19 +60,19 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
image = Image.open(image_file_path)
draw = ImageDraw.Draw(image)
marks = []
drew_nodes = []
# TODO: change the image tagger to align with the SoM paper
# Optional: Load a font. If you don't specify a font, a default one will be used.
try:
# Adjust the path to the font file you have or use a default one
font = ImageFont.truetype("arial.ttf", 20)
font = ImageFont.truetype("arial.ttf", 15)
except IOError:
# Fallback to a basic font if the specified font can't be loaded
font = ImageFont.load_default()
index = 1
# Loop over all the visible nodes and draw their bounding boxes
for index, _node in enumerate(nodes):
for _node in nodes:
coords_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord')
size_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size')
@@ -93,22 +93,30 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
if bottom_right[0] < coords[0] or bottom_right[1] < coords[1]:
raise ValueError(f"Invalid coordinates or size, coords: {coords}, size: {size}")
# Draw rectangle on image
draw.rectangle([coords, bottom_right], outline="red", width=2)
# Check if the area only contains one color
cropped_image = image.crop((*coords, *bottom_right))
if len(set(list(cropped_image.getdata()))) == 1:
continue
# Draw index number at the bottom left of the bounding box
# Draw rectangle on image
draw.rectangle([coords, bottom_right], outline="red", width=1)
# Draw index number at the bottom left of the bounding box with black background
text_position = (coords[0], bottom_right[1]) # Adjust Y to be above the bottom right
draw.text(text_position, str(index), font=font, fill="purple")
draw.rectangle([text_position, (text_position[0] + 25, text_position[1] + 18)], fill='black')
draw.text(text_position, str(index), font=font, fill="white")
index += 1
# each mark is an [x, y, w, h] list
marks.append([coords[0], coords[1], size[0], size[1]])
drew_nodes.append(_node)
except ValueError as e:
pass
# Save the result
image.save(output_image_file_path)
return marks
return marks, drew_nodes
def print_nodes_with_indent(nodes, indent=0):
@@ -120,6 +128,10 @@ def print_nodes_with_indent(nodes, indent=0):
if __name__ == '__main__':
with open('chrome_desktop_example_1.xml', 'r', encoding='utf-8') as f:
xml_file_str = f.read()
filtered_nodes = filter_nodes(find_leaf_nodes(xml_file_str))
print(len(filtered_nodes))
masks = draw_bounding_boxes(filtered_nodes, 'screenshot.png',
'chrome_desktop_example_1_tagged_remove.png', )
nodes = ET.fromstring(xml_file_str)
print_nodes_with_indent(nodes)
# print(masks)
print(len(masks))