update SoM_agent
This commit is contained in:
164
mm_agents/task_adapter/seem/tasks/inference_seem_pano.py
Normal file
164
mm_agents/task_adapter/seem/tasks/inference_seem_pano.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# --------------------------------------------------------
|
||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
||||
# Copyright (c) 2023 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
from task_adapter.utils.visualizer import Visualizer
|
||||
from typing import Tuple
|
||||
from PIL import Image
|
||||
from detectron2.data import MetadataCatalog
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import io
|
||||
from .automatic_mask_generator import SeemAutomaticMaskGenerator
|
||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
||||
|
||||
from segment_anything.utils.amg import (
|
||||
MaskData,
|
||||
area_from_rle,
|
||||
batch_iterator,
|
||||
batched_mask_to_box,
|
||||
box_xyxy_to_xywh,
|
||||
build_all_layer_point_grids,
|
||||
calculate_stability_score,
|
||||
coco_encode_rle,
|
||||
generate_crop_boxes,
|
||||
is_box_near_crop_edge,
|
||||
mask_to_rle_pytorch,
|
||||
remove_small_regions,
|
||||
rle_to_mask,
|
||||
uncrop_boxes_xyxy,
|
||||
uncrop_masks,
|
||||
uncrop_points,
|
||||
)
|
||||
|
||||
|
||||
def inference_seem_pano(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
||||
t = []
|
||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
||||
transform1 = transforms.Compose(t)
|
||||
image_ori = transform1(image)
|
||||
|
||||
image_ori = np.asarray(image_ori)
|
||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
||||
|
||||
orig_size = images.shape[-2:]
|
||||
orig_h, orig_w = orig_size
|
||||
crop_box = [0,0,orig_w,orig_h]
|
||||
|
||||
data = {"image": images, "height": orig_h, "width": orig_w}
|
||||
batch_inputs = [data]
|
||||
|
||||
model.model.metadata = metadata
|
||||
outputs = model.model.evaluate(batch_inputs)
|
||||
|
||||
pano_mask = outputs[0]['panoptic_seg'][0]
|
||||
pano_info = outputs[0]['panoptic_seg'][1]
|
||||
|
||||
masks = []
|
||||
for seg_info in pano_info:
|
||||
masks += [pano_mask == seg_info['id']]
|
||||
masks = torch.stack(masks, dim=0)
|
||||
iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
|
||||
points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)
|
||||
|
||||
mask_data = MaskData(
|
||||
masks=masks,
|
||||
iou_preds=iou_preds,
|
||||
points=points,
|
||||
)
|
||||
mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
|
||||
del masks
|
||||
|
||||
mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
|
||||
mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])
|
||||
|
||||
# Compress to RLE
|
||||
mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
|
||||
mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
|
||||
del mask_data["masks"]
|
||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
||||
|
||||
# Write mask records
|
||||
outputs = []
|
||||
for idx in range(len(mask_data["segmentations"])):
|
||||
ann = {
|
||||
"segmentation": mask_data["segmentations"][idx],
|
||||
"area": area_from_rle(mask_data["rles"][idx]),
|
||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
||||
"point_coords": [mask_data["points"][idx].tolist()],
|
||||
"stability_score": mask_data["stability_score"][idx].item(),
|
||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
||||
}
|
||||
outputs.append(ann)
|
||||
|
||||
from task_adapter.utils.visualizer import Visualizer
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
# create a full zero image as the image_orig
|
||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
||||
label = 1
|
||||
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
mask = ann['segmentation']
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
# color_mask = [int(c*255) for c in color_mask]
|
||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
||||
# assign the mask to the mask_map
|
||||
mask_map[mask == 1] = label
|
||||
label += 1
|
||||
im = demo.get_image()
|
||||
# fig=plt.figure(figsize=(10, 10))
|
||||
# plt.imshow(image_ori)
|
||||
# show_anns(outputs)
|
||||
# fig.canvas.draw()
|
||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
||||
return im, sorted_anns
|
||||
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
|
||||
def show_anns(anns):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
||||
ax = plt.gca()
|
||||
ax.set_autoscale_on(False)
|
||||
polygons = []
|
||||
color = []
|
||||
for ann in sorted_anns:
|
||||
m = ann['segmentation']
|
||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
for i in range(3):
|
||||
img[:,:,i] = color_mask[i]
|
||||
ax.imshow(np.dstack((img, m*0.35)))
|
||||
Reference in New Issue
Block a user