diff --git a/desktop_env/evaluators/metrics/gimp.py b/desktop_env/evaluators/metrics/gimp.py index b513260..8d976f6 100644 --- a/desktop_env/evaluators/metrics/gimp.py +++ b/desktop_env/evaluators/metrics/gimp.py @@ -276,14 +276,13 @@ def check_triangle_position(tgt_path): img = Image.open(tgt_path) img_array = np.array(img) - # We will determine if the triangle is in the middle of the picture by checking the centroid # We assume the triangle is a different color from the background # Find the unique colors - unique_colors = np.unique(img_array.reshape(-1, img_array.shape[2]), axis=0) + unique_colors, counts = np.unique(img_array.reshape(-1, img_array.shape[2]), axis=0, return_counts=True) + unique_colors_sorted = unique_colors[np.argsort(counts)] - # Assuming the background is the most common color and the triangle is a different color, - # we identify the triangle's color as the least common one - triangle_color = unique_colors[-1] + # Assuming the background is the most common color and the triangle is a different color + triangle_color = unique_colors_sorted[1] # Create a mask where the triangle pixels are True triangle_mask = np.all(img_array == triangle_color, axis=2) @@ -502,3 +501,7 @@ if __name__ == "__main__": src_path = "../../../cache/734d6579-c07d-47a8-9ae2-13339795476b/green_background_with_object.png" tgt_path = "../../../cache/734d6579-c07d-47a8-9ae2-13339795476b/white_background_with_object.png" print(check_green_background(src_path, tgt_path)) + + tgt_path = "../../../cache/f4aec372-4fb0-4df5-a52b-79e0e2a5d6ce/Triangle_In_The_Middle.png" + print(check_triangle_position(tgt_path)) +