update SoM_agent
This commit is contained in:
6
mm_agents/task_adapter/semantic_sam/tasks/__init__.py
Normal file
6
mm_agents/task_adapter/semantic_sam/tasks/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from .interactive_idino_m2m import interactive_infer_image as interactive_infer_image_idino_m2m
|
||||
from .interactive_idino_m2m import interactive_infer_image_semantic, interactive_infer_image_3l
|
||||
from .inference_semsam_m2m_auto import inference_semsam_m2m_auto
|
||||
from .interactive_idino_1o1_box import interactive_infer_image_box as interactive_infer_image_idino_m2m_box
|
||||
from .automatic_mask_generator import prompt_switch
|
||||
from .interactive_predictor import SemanticSAMPredictor
|
||||
@@ -0,0 +1,393 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
# from
|
||||
# from .modeling import Sam
|
||||
# from .predictor import SamPredictor
|
||||
from semantic_sam.utils.sam_utils.amg import (
|
||||
MaskData,
|
||||
area_from_rle,
|
||||
batch_iterator,
|
||||
batched_mask_to_box,
|
||||
box_xyxy_to_xywh,
|
||||
build_all_layer_point_grids,
|
||||
calculate_stability_score,
|
||||
coco_encode_rle,
|
||||
generate_crop_boxes,
|
||||
is_box_near_crop_edge,
|
||||
mask_to_rle_pytorch,
|
||||
remove_small_regions,
|
||||
rle_to_mask,
|
||||
uncrop_boxes_xyxy,
|
||||
uncrop_masks,
|
||||
uncrop_points,
|
||||
)
|
||||
|
||||
|
||||
def prompt_switch(p):
|
||||
p = int(p)
|
||||
if p == 1:
|
||||
return 3
|
||||
if p == 2:
|
||||
return 2
|
||||
if p == 3:
|
||||
return 0
|
||||
if p == 4:
|
||||
return 4
|
||||
if p == 5:
|
||||
return 1
|
||||
if p == 6:
|
||||
return 5
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SemanticSamAutomaticMaskGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
points_per_side: Optional[int] = 32,
|
||||
points_per_batch: int = 200,
|
||||
pred_iou_thresh: float = 0.88,
|
||||
stability_score_thresh: float = 0.92,
|
||||
stability_score_offset: float = 1.0,
|
||||
box_nms_thresh: float = 0.7,
|
||||
crop_n_layers: int = 0,
|
||||
crop_nms_thresh: float = 0.7,
|
||||
crop_overlap_ratio: float = 512 / 1500,
|
||||
crop_n_points_downscale_factor: int = 1,
|
||||
point_grids: Optional[List[np.ndarray]] = None,
|
||||
min_mask_region_area: int = 10,
|
||||
output_mode: str = "binary_mask",
|
||||
level: list = [1, 2, 3, 4, 5, 6],
|
||||
) -> None:
|
||||
"""
|
||||
Using a SAM model, generates masks for the entire image.
|
||||
Generates a grid of point prompts over the image, then filters
|
||||
low quality and duplicate masks. The default settings are chosen
|
||||
for SAM with a ViT-H backbone.
|
||||
|
||||
Arguments:
|
||||
model (Sam): The SAM model to use for mask prediction.
|
||||
points_per_side (int or None): The number of points to be sampled
|
||||
along one side of the image. The total number of points is
|
||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
||||
point sampling.
|
||||
points_per_batch (int): Sets the number of points run simultaneously
|
||||
by the model. Higher numbers may be faster but use more GPU memory.
|
||||
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
||||
model's predicted mask quality.
|
||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
||||
the stability of the mask under changes to the cutoff used to binarize
|
||||
the model's mask predictions.
|
||||
stability_score_offset (float): The amount to shift the cutoff when
|
||||
calculated the stability score.
|
||||
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks.
|
||||
crops_n_layers (int): If >0, mask prediction will be run again on
|
||||
crops of the image. Sets the number of layers to run, where each
|
||||
layer has 2**i_layer number of image crops.
|
||||
crops_nms_thresh (float): The box IoU cutoff used by non-maximal
|
||||
suppression to filter duplicate masks between different crops.
|
||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
||||
In the first crop layer, crops will overlap by this fraction of
|
||||
the image length. Later layers with more crops scale down this overlap.
|
||||
crop_n_points_downscale_factor (int): The number of points-per-side
|
||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
||||
point_grids (list(np.ndarray) or None): A list over explicit grids
|
||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
||||
min_mask_region_area (int): If >0, postprocessing will be applied
|
||||
to remove disconnected regions and holes in masks with area smaller
|
||||
than min_mask_region_area. Requires opencv.
|
||||
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
||||
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
||||
For large resolutions, 'binary_mask' may consume large amounts of
|
||||
memory.
|
||||
"""
|
||||
self.level = [prompt_switch(l) for l in level]
|
||||
assert (points_per_side is None) != (
|
||||
point_grids is None
|
||||
), "Exactly one of points_per_side or point_grid must be provided."
|
||||
if points_per_side is not None:
|
||||
self.point_grids = build_all_layer_point_grids(
|
||||
points_per_side,
|
||||
crop_n_layers,
|
||||
crop_n_points_downscale_factor,
|
||||
)
|
||||
elif point_grids is not None:
|
||||
self.point_grids = point_grids
|
||||
else:
|
||||
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
||||
|
||||
assert output_mode in [
|
||||
"binary_mask",
|
||||
"uncompressed_rle",
|
||||
"coco_rle",
|
||||
], f"Unknown output_mode {output_mode}."
|
||||
if output_mode == "coco_rle":
|
||||
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
||||
|
||||
if min_mask_region_area > 0:
|
||||
import cv2 # type: ignore # noqa: F401
|
||||
|
||||
self.predictor = model
|
||||
self.points_per_batch = points_per_batch
|
||||
self.pred_iou_thresh = pred_iou_thresh
|
||||
self.stability_score_thresh = stability_score_thresh
|
||||
self.stability_score_offset = stability_score_offset
|
||||
self.box_nms_thresh = box_nms_thresh
|
||||
self.crop_n_layers = crop_n_layers
|
||||
self.crop_nms_thresh = crop_nms_thresh
|
||||
self.crop_overlap_ratio = crop_overlap_ratio
|
||||
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
||||
self.min_mask_region_area = min_mask_region_area
|
||||
self.output_mode = output_mode
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generates masks for the given image.
|
||||
|
||||
Arguments:
|
||||
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
||||
|
||||
Returns:
|
||||
list(dict(str, any)): A list over records for masks. Each record is
|
||||
a dict containing the following keys:
|
||||
segmentation (dict(str, any) or np.ndarray): The mask. If
|
||||
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
||||
is a dictionary containing the RLE.
|
||||
bbox (list(float)): The box around the mask, in XYWH format.
|
||||
area (int): The area in pixels of the mask.
|
||||
predicted_iou (float): The model's own prediction of the mask's
|
||||
quality. This is filtered by the pred_iou_thresh parameter.
|
||||
point_coords (list(list(float))): The point coordinates input
|
||||
to the model to generate this mask.
|
||||
stability_score (float): A measure of the mask's quality. This
|
||||
is filtered on using the stability_score_thresh parameter.
|
||||
crop_box (list(float)): The crop of the image used to generate
|
||||
the mask, given in XYWH format.
|
||||
"""
|
||||
|
||||
# Generate masks
|
||||
mask_data = self._generate_masks(image)
|
||||
|
||||
# Filter small disconnected regions and holes in masks
|
||||
if self.min_mask_region_area > 0:
|
||||
mask_data = self.postprocess_small_regions(
|
||||
mask_data,
|
||||
self.min_mask_region_area,
|
||||
max(self.box_nms_thresh, self.crop_nms_thresh),
|
||||
)
|
||||
# Encode masks
|
||||
if self.output_mode == "coco_rle":
|
||||
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
||||
elif self.output_mode == "binary_mask":
|
||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
||||
else:
|
||||
mask_data["segmentations"] = mask_data["rles"]
|
||||
|
||||
# Write mask records
|
||||
curr_anns = []
|
||||
for idx in range(len(mask_data["segmentations"])):
|
||||
ann = {
|
||||
"segmentation": mask_data["segmentations"][idx],
|
||||
"area": area_from_rle(mask_data["rles"][idx]),
|
||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
||||
"point_coords": [mask_data["points"][idx].tolist()],
|
||||
"stability_score": mask_data["stability_score"][idx].item(),
|
||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
||||
}
|
||||
curr_anns.append(ann)
|
||||
|
||||
return curr_anns
|
||||
|
||||
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
||||
orig_size = image.shape[-2:]
|
||||
crop_boxes, layer_idxs = generate_crop_boxes(
|
||||
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
||||
)
|
||||
|
||||
# Iterate over image crops
|
||||
assert len(crop_boxes)==1
|
||||
data = MaskData()
|
||||
# import ipdb; ipdb.set_trace()
|
||||
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
||||
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
||||
|
||||
data.cat(crop_data)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# Remove duplicate masks between crops
|
||||
if len(crop_boxes) > 1:
|
||||
# Prefer masks from smaller crops
|
||||
scores = 1 / box_area(data["crop_boxes"])
|
||||
scores = scores.to(data["boxes"].device)
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
scores,
|
||||
torch.zeros(len(data["boxes"])), # categories
|
||||
iou_threshold=self.crop_nms_thresh,
|
||||
)
|
||||
data.filter(keep_by_nms)
|
||||
|
||||
data.to_numpy()
|
||||
return data
|
||||
|
||||
def _process_crop(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
crop_box: List[int],
|
||||
crop_layer_idx: int,
|
||||
orig_size: Tuple[int, ...],
|
||||
) -> MaskData:
|
||||
# Crop the image and calculate embeddings
|
||||
x0, y0, x1, y1 = crop_box
|
||||
cropped_im = image#[y0:y1, x0:x1, :]
|
||||
cropped_im_size = cropped_im.shape[-2:]
|
||||
# self.predictor.set_image(cropped_im)
|
||||
|
||||
# Get points for this crop
|
||||
points_scale = np.array(cropped_im_size)[None, ::-1]
|
||||
points_for_image = self.point_grids[crop_layer_idx] #* points_scale
|
||||
|
||||
# Generate masks for this crop in batches
|
||||
data = MaskData()
|
||||
self.enc_features=None
|
||||
# import ipdb; ipdb.set_trace()
|
||||
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
||||
batch_data = self._process_batch(cropped_im,points, cropped_im_size, crop_box, orig_size)
|
||||
data.cat(batch_data)
|
||||
del batch_data
|
||||
|
||||
keep_by_nms = batched_nms(
|
||||
data["boxes"].float(),
|
||||
data["iou_preds"],
|
||||
torch.zeros(len(data["boxes"])), # categories
|
||||
iou_threshold=self.box_nms_thresh,
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
data.filter(keep_by_nms)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# Return to the original image frame
|
||||
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
||||
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
||||
|
||||
return data
|
||||
|
||||
def _process_batch(
|
||||
self,
|
||||
images,
|
||||
points: np.ndarray,
|
||||
im_size: Tuple[int, ...],
|
||||
crop_box: List[int],
|
||||
orig_size: Tuple[int, ...],
|
||||
) -> MaskData:
|
||||
orig_h, orig_w = orig_size
|
||||
|
||||
data = {"image": images, "height": orig_h, "width": orig_w}
|
||||
points=torch.tensor(points,dtype=torch.float).to(images.device)
|
||||
points = torch.cat([points, points.new_tensor([[0.005, 0.005]]).repeat(len(points), 1)], dim=-1)
|
||||
data['targets'] = [dict()]
|
||||
data['targets'][0]['points']=points
|
||||
data['targets'][0]['pb']=points.new_tensor([0.]*len(points))
|
||||
batch_inputs = [data]
|
||||
if self.enc_features is None:
|
||||
masks, iou_preds,mask_features,multi_scale_features= self.predictor.model.evaluate_demo(batch_inputs,None,None,return_features=True, level=self.level)
|
||||
self.enc_features=(mask_features,multi_scale_features)
|
||||
else:
|
||||
masks, iou_preds= self.predictor.model.evaluate_demo(batch_inputs,None,None,self.enc_features[0],self.enc_features[1], level=self.level)
|
||||
|
||||
data = MaskData(
|
||||
masks=masks,
|
||||
iou_preds=iou_preds.flatten(),
|
||||
points=torch.as_tensor(points[:,None].repeat(1,len(self.level), 1).view(-1,4)),
|
||||
)
|
||||
del masks
|
||||
# Filter by predicted IoU
|
||||
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Calculate stability score
|
||||
data["stability_score"] = calculate_stability_score(
|
||||
data["masks"], 0.0, self.stability_score_offset
|
||||
)
|
||||
# if self.stability_score_thresh > 0.0:
|
||||
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Threshold masks and calculate boxes
|
||||
data["masks"] = data["masks"] > 0.0
|
||||
data["boxes"] = batched_mask_to_box(data["masks"])
|
||||
|
||||
# Filter boxes that touch crop boundaries
|
||||
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
||||
if not torch.all(keep_mask):
|
||||
data.filter(keep_mask)
|
||||
|
||||
# Compress to RLE
|
||||
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
||||
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
||||
del data["masks"]
|
||||
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def postprocess_small_regions(
|
||||
mask_data: MaskData, min_area: int, nms_thresh: float
|
||||
) -> MaskData:
|
||||
"""
|
||||
Removes small disconnected regions and holes in masks, then reruns
|
||||
box NMS to remove any new duplicates.
|
||||
|
||||
Edits mask_data in place.
|
||||
|
||||
Requires open-cv as a dependency.
|
||||
"""
|
||||
if len(mask_data["rles"]) == 0:
|
||||
return mask_data
|
||||
|
||||
# Filter small disconnected regions and holes
|
||||
new_masks = []
|
||||
scores = []
|
||||
for rle in mask_data["rles"]:
|
||||
mask = rle_to_mask(rle)
|
||||
|
||||
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
||||
unchanged = not changed
|
||||
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
||||
unchanged = unchanged and not changed
|
||||
|
||||
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
||||
# Give score=0 to changed masks and score=1 to unchanged masks
|
||||
# so NMS will prefer ones that didn't need postprocessing
|
||||
scores.append(float(unchanged))
|
||||
|
||||
# Recalculate boxes and remove any new duplicates
|
||||
masks = torch.cat(new_masks, dim=0)
|
||||
boxes = batched_mask_to_box(masks)
|
||||
keep_by_nms = batched_nms(
|
||||
boxes.float(),
|
||||
torch.as_tensor(scores),
|
||||
torch.zeros(len(boxes)), # categories
|
||||
iou_threshold=nms_thresh,
|
||||
)
|
||||
|
||||
# Only recalculate RLEs for masks that have changed
|
||||
for i_mask in keep_by_nms:
|
||||
if scores[i_mask] == 0.0:
|
||||
mask_torch = masks[i_mask].unsqueeze(0)
|
||||
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
||||
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
||||
mask_data.filter(keep_by_nms)
|
||||
|
||||
return mask_data
|
||||
@@ -0,0 +1,108 @@
|
||||
# --------------------------------------------------------
|
||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
||||
# Copyright (c) 2023 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
from task_adapter.utils.visualizer import Visualizer
|
||||
from typing import Tuple
|
||||
from PIL import Image
|
||||
from detectron2.data import MetadataCatalog
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
import io
|
||||
from .automatic_mask_generator import SemanticSamAutomaticMaskGenerator
|
||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
||||
|
||||
def inference_semsam_m2m_auto(model, image, level, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
||||
t = []
|
||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
||||
transform1 = transforms.Compose(t)
|
||||
image_ori = transform1(image)
|
||||
|
||||
image_ori = np.asarray(image_ori)
|
||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
||||
|
||||
mask_generator = SemanticSamAutomaticMaskGenerator(model,points_per_side=32,
|
||||
pred_iou_thresh=0.88,
|
||||
stability_score_thresh=0.92,
|
||||
min_mask_region_area=10,
|
||||
level=level,
|
||||
)
|
||||
outputs = mask_generator.generate(images)
|
||||
|
||||
from task_adapter.utils.visualizer import Visualizer
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
||||
label = 1
|
||||
# for ann in sorted_anns:
|
||||
# mask = ann['segmentation']
|
||||
# color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
# # color_mask = [int(c*255) for c in color_mask]
|
||||
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
||||
# label += 1
|
||||
# im = demo.get_image()
|
||||
|
||||
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
|
||||
for i, ann in enumerate(sorted_anns):
|
||||
mask = ann['segmentation']
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
# color_mask = [int(c*255) for c in color_mask]
|
||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
||||
# assign the mask to the mask_map
|
||||
mask_map[mask == 1] = label
|
||||
label += 1
|
||||
im = demo.get_image()
|
||||
# fig=plt.figure(figsize=(10, 10))
|
||||
# plt.imshow(image_ori)
|
||||
# show_anns(outputs)
|
||||
# fig.canvas.draw()
|
||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
||||
return im, sorted_anns
|
||||
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
|
||||
def show_anns(anns):
|
||||
if len(anns) == 0:
|
||||
return
|
||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
||||
ax = plt.gca()
|
||||
ax.set_autoscale_on(False)
|
||||
polygons = []
|
||||
color = []
|
||||
for ann in sorted_anns:
|
||||
m = ann['segmentation']
|
||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
||||
for i in range(3):
|
||||
img[:,:,i] = color_mask[i]
|
||||
ax.imshow(np.dstack((img, m*0.35)))
|
||||
@@ -0,0 +1,144 @@
|
||||
# --------------------------------------------------------
|
||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
||||
# Copyright (c) 2023 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
from task_adapter.utils.visualizer import Visualizer
|
||||
from typing import Tuple
|
||||
from PIL import Image
|
||||
from detectron2.data import MetadataCatalog
|
||||
from detectron2.structures import BitMasks
|
||||
from semantic_sam.utils import box_ops
|
||||
|
||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
||||
|
||||
def interactive_infer_image_box(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
|
||||
t = []
|
||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
||||
transform1 = transforms.Compose(t)
|
||||
image_ori = transform1(image['image'])
|
||||
mask_ori = transform1(image['mask'])
|
||||
width = image_ori.size[0]
|
||||
height = image_ori.size[1]
|
||||
image_ori = np.asarray(image_ori)
|
||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
||||
|
||||
|
||||
data = {"image": images, "height": height, "width": width}
|
||||
|
||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
||||
flaten_mask = mask_ori.unsqueeze(0)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
points=mask_ori.nonzero().float().to(images.device)
|
||||
if len(points)==0:
|
||||
point_=point=points.new_tensor([[0.5,0.5,0.5,0.5]])
|
||||
else:
|
||||
mean_point=points.mean(0)[None]
|
||||
box_xyxy = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor
|
||||
h = mask_ori.shape[0]
|
||||
w = mask_ori.shape[1]
|
||||
box_xywh = (box_ops.box_xyxy_to_cxcywh(box_xyxy) / torch.as_tensor([w, h, w, h])).cuda()
|
||||
|
||||
# point_=points.mean(0)[None]
|
||||
# point=point_.clone()
|
||||
# point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
||||
# point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
||||
# point = point[:, [1, 0]]
|
||||
point=box_xywh
|
||||
data['targets'] = [dict()]
|
||||
data['targets'][0]['points']=point
|
||||
data['targets'][0]['pb']=point.new_tensor([1.])
|
||||
|
||||
|
||||
batch_inputs = [data]
|
||||
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts, task='demo_box')
|
||||
|
||||
pred_masks_poses = masks
|
||||
reses=[]
|
||||
ious=ious[0,0]
|
||||
ids=torch.argsort(ious,descending=True)
|
||||
|
||||
text_res=''
|
||||
try:
|
||||
thresh=float(thresh)
|
||||
except Exception:
|
||||
thresh=0.0
|
||||
mask_ls=[]
|
||||
ious_res=[]
|
||||
areas=[]
|
||||
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
|
||||
iou=round(float(iou),2)
|
||||
texts=f'{iou}'
|
||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
||||
area=mask.sum()
|
||||
conti=False
|
||||
if iou<thresh:
|
||||
conti=True
|
||||
for m in mask_ls:
|
||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
||||
conti=True
|
||||
break
|
||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
||||
conti=False
|
||||
if conti:
|
||||
continue
|
||||
ious_res.append(iou)
|
||||
mask_ls.append(mask)
|
||||
areas.append(area)
|
||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
||||
mask=(mask).astype(np.float)
|
||||
out_txt = texts
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
color=[0.,0.,1.0]
|
||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
||||
demo = visual.draw_box(box_xyxy[0])
|
||||
res = demo.get_image()
|
||||
# point_x0=max(0,int(point_[0, 1])-3)
|
||||
# point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
||||
# point_y0 = max(0, int(point_[0, 0]) - 3)
|
||||
# point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
||||
# res[point_y0:point_y1,point_x0:point_x1,0]=255
|
||||
# res[point_y0:point_y1,point_x0:point_x1,1]=0
|
||||
# res[point_y0:point_y1,point_x0:point_x1,2]=0
|
||||
reses.append(Image.fromarray(res))
|
||||
text_res=text_res+';'+out_txt
|
||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
||||
ids = [int(i) for i in ids]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return reses,[reses[i] for i in ids]
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
@@ -0,0 +1,322 @@
|
||||
# --------------------------------------------------------
|
||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
||||
# Copyright (c) 2023 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
from task_adapter.utils.visualizer import Visualizer
|
||||
from typing import Tuple
|
||||
from PIL import Image
|
||||
from detectron2.data import MetadataCatalog
|
||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
||||
|
||||
def interactive_infer_image(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
||||
t = []
|
||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
||||
transform1 = transforms.Compose(t)
|
||||
image_ori = transform1(image['image'])
|
||||
mask_ori = transform1(image['mask'])
|
||||
width = image_ori.size[0]
|
||||
height = image_ori.size[1]
|
||||
image_ori = np.asarray(image_ori)
|
||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
||||
|
||||
|
||||
data = {"image": images, "height": height, "width": width}
|
||||
|
||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
||||
points=mask_ori.nonzero().float().to(images.device)
|
||||
if len(points)==0:
|
||||
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
|
||||
else:
|
||||
point_=points.mean(0)[None]
|
||||
point=point_.clone()
|
||||
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
||||
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
||||
point = point[:, [1, 0]]
|
||||
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
|
||||
data['targets'] = [dict()]
|
||||
data['targets'][0]['points']=point
|
||||
data['targets'][0]['pb']=point.new_tensor([0.])
|
||||
|
||||
|
||||
batch_inputs = [data]
|
||||
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts)
|
||||
|
||||
pred_masks_poses = masks
|
||||
reses=[]
|
||||
ious=ious[0,0]
|
||||
ids=torch.argsort(ious,descending=True)
|
||||
|
||||
text_res=''
|
||||
try:
|
||||
thresh=float(thresh)
|
||||
except Exception:
|
||||
thresh=0.0
|
||||
mask_ls=[]
|
||||
ious_res=[]
|
||||
areas=[]
|
||||
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
|
||||
iou=round(float(iou),2)
|
||||
texts=f'{iou}'
|
||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
||||
area=mask.sum()
|
||||
conti=False
|
||||
if iou<thresh:
|
||||
conti=True
|
||||
for m in mask_ls:
|
||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
||||
conti=True
|
||||
break
|
||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
||||
conti=False
|
||||
if conti:
|
||||
continue
|
||||
ious_res.append(iou)
|
||||
mask_ls.append(mask)
|
||||
areas.append(area)
|
||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
||||
mask=(mask).astype(np.float)
|
||||
out_txt = texts
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
color=[0.,0.,1.0]
|
||||
# demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
||||
res = demo.get_image()
|
||||
point_x0=max(0,int(point_[0, 1])-3)
|
||||
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
||||
point_y0 = max(0, int(point_[0, 0]) - 3)
|
||||
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
||||
# res[point_y0:point_y1,point_x0:point_x1,0]=255
|
||||
# res[point_y0:point_y1,point_x0:point_x1,1]=0
|
||||
# res[point_y0:point_y1,point_x0:point_x1,2]=0
|
||||
reses.append(Image.fromarray(res))
|
||||
text_res=text_res+';'+out_txt
|
||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
||||
ids = [int(i) for i in ids]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return reses,[reses[i] for i in ids]
|
||||
|
||||
def interactive_infer_image_3l(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
|
||||
t = []
|
||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
||||
transform1 = transforms.Compose(t)
|
||||
image_ori = transform1(image['image'])
|
||||
mask_ori = transform1(image['mask'])
|
||||
width = image_ori.size[0]
|
||||
height = image_ori.size[1]
|
||||
image_ori = np.asarray(image_ori)
|
||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
||||
|
||||
|
||||
data = {"image": images, "height": height, "width": width}
|
||||
|
||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
||||
points=mask_ori.nonzero().float().to(images.device)
|
||||
if len(points)==0:
|
||||
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
|
||||
else:
|
||||
point_=points.mean(0)[None]
|
||||
point=point_.clone()
|
||||
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
||||
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
||||
point = point[:, [1, 0]]
|
||||
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
|
||||
data['targets'] = [dict()]
|
||||
data['targets'][0]['points']=point
|
||||
data['targets'][0]['pb']=point.new_tensor([0.])
|
||||
|
||||
|
||||
batch_inputs = [data]
|
||||
masks, ious, pred_class, pred_class_score = model.model.evaluate_demo(batch_inputs,all_classes,all_parts, level=[0,1,2])
|
||||
|
||||
pred_masks_poses = masks
|
||||
reses=[]
|
||||
ious=ious[0,0]
|
||||
ids=torch.argsort(ious,descending=True)
|
||||
|
||||
text_res=''
|
||||
try:
|
||||
thresh=float(thresh)
|
||||
except Exception:
|
||||
thresh=0.0
|
||||
mask_ls=[]
|
||||
ious_res=[]
|
||||
areas=[]
|
||||
new_pred_class = []
|
||||
new_pred_class_score = []
|
||||
for i in ids:
|
||||
new_pred_class_score.append(pred_class_score[i])
|
||||
new_pred_class.append(pred_class[i])
|
||||
# import ipdb; ipdb.set_trace()
|
||||
for i,(pred_masks_pos,iou, cls_name, cls_score) in enumerate(zip(pred_masks_poses[ids],ious[ids], new_pred_class, new_pred_class_score)):
|
||||
iou=round(float(iou),2)
|
||||
texts=f'{iou}_{cls_name}_{cls_score}'
|
||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
||||
area=mask.sum()
|
||||
conti=False
|
||||
if iou<thresh:
|
||||
conti=True
|
||||
for m in mask_ls:
|
||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
||||
conti=True
|
||||
break
|
||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
||||
conti=False
|
||||
if conti:
|
||||
continue
|
||||
ious_res.append(iou)
|
||||
mask_ls.append(mask)
|
||||
areas.append(area)
|
||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
||||
mask=(mask).astype(np.float)
|
||||
out_txt = texts
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
color=[0.,0.,1.0]
|
||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
||||
res = demo.get_image()
|
||||
point_x0=max(0,int(point_[0, 1])-3)
|
||||
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
||||
point_y0 = max(0, int(point_[0, 0]) - 3)
|
||||
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
||||
res[point_y0:point_y1,point_x0:point_x1,0]=255
|
||||
res[point_y0:point_y1,point_x0:point_x1,1]=0
|
||||
res[point_y0:point_y1,point_x0:point_x1,2]=0
|
||||
reses.append(Image.fromarray(res))
|
||||
text_res=text_res+';'+out_txt
|
||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
||||
ids = [int(i) for i in ids]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return reses,[reses[i] for i in ids]
|
||||
|
||||
def interactive_infer_image_semantic(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
|
||||
t = []
|
||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
||||
transform1 = transforms.Compose(t)
|
||||
image_ori = transform1(image['image'])
|
||||
mask_ori = transform1(image['mask'])
|
||||
width = image_ori.size[0]
|
||||
height = image_ori.size[1]
|
||||
image_ori = np.asarray(image_ori)
|
||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
||||
|
||||
|
||||
data = {"image": images, "height": height, "width": width}
|
||||
|
||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
||||
points=mask_ori.nonzero().float().to(images.device)
|
||||
if len(points)==0:
|
||||
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
|
||||
else:
|
||||
point_=points.mean(0)[None]
|
||||
point=point_.clone()
|
||||
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
||||
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
||||
point = point[:, [1, 0]]
|
||||
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
|
||||
data['targets'] = [dict()]
|
||||
data['targets'][0]['points']=point
|
||||
data['targets'][0]['pb']=point.new_tensor([0.])
|
||||
data['targets'][0]['pb']=point.new_tensor([1.])
|
||||
|
||||
|
||||
batch_inputs = [data]
|
||||
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts)
|
||||
|
||||
pred_masks_poses = masks
|
||||
reses=[]
|
||||
ious=ious[0,0]
|
||||
ids=torch.argsort(ious,descending=True)
|
||||
|
||||
text_res=''
|
||||
try:
|
||||
thresh=float(thresh)
|
||||
except Exception:
|
||||
thresh=0.0
|
||||
mask_ls=[]
|
||||
ious_res=[]
|
||||
areas=[]
|
||||
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
|
||||
iou=round(float(iou),2)
|
||||
texts=f'{iou}'
|
||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
||||
area=mask.sum()
|
||||
conti=False
|
||||
if iou<thresh:
|
||||
conti=True
|
||||
for m in mask_ls:
|
||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
||||
conti=True
|
||||
break
|
||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
||||
conti=False
|
||||
if conti:
|
||||
continue
|
||||
ious_res.append(iou)
|
||||
mask_ls.append(mask)
|
||||
areas.append(area)
|
||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
||||
mask=(mask).astype(np.float)
|
||||
out_txt = texts
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
color=[0.,0.,1.0]
|
||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
||||
res = demo.get_image()
|
||||
point_x0=max(0,int(point_[0, 1])-3)
|
||||
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
||||
point_y0 = max(0, int(point_[0, 0]) - 3)
|
||||
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
||||
res[point_y0:point_y1,point_x0:point_x1,0]=255
|
||||
res[point_y0:point_y1,point_x0:point_x1,1]=0
|
||||
res[point_y0:point_y1,point_x0:point_x1,2]=0
|
||||
reses.append(Image.fromarray(res))
|
||||
text_res=text_res+';'+out_txt
|
||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
||||
ids = [int(i) for i in ids]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return reses,[reses[i] for i in ids]
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
@@ -0,0 +1,139 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
from task_adapter.utils.visualizer import Visualizer
|
||||
from typing import Tuple
|
||||
from PIL import Image
|
||||
from detectron2.data import MetadataCatalog
|
||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
||||
|
||||
|
||||
class SemanticSAMPredictor:
|
||||
def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100):
|
||||
"""
|
||||
thresh: iou thresh to filter low confidence objects
|
||||
text_size: resize the input image short edge for the model to process
|
||||
hole_scale: fill in small holes as in SAM
|
||||
island_scale: remove small regions as in SAM
|
||||
"""
|
||||
self.model = model
|
||||
self.thresh = thresh
|
||||
self.text_size = hole_scale
|
||||
self.hole_scale = hole_scale
|
||||
self.island_scale = island_scale
|
||||
self.point = None
|
||||
|
||||
def predict(self, image_ori, image, point=None):
|
||||
"""
|
||||
produce up to 6 prediction results for each click
|
||||
"""
|
||||
width = image_ori.shape[0]
|
||||
height = image_ori.shape[1]
|
||||
|
||||
data = {"image": image, "height": height, "width": width}
|
||||
# import ipdb; ipdb.set_trace()
|
||||
if point is None:
|
||||
point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda()
|
||||
else:
|
||||
point = torch.tensor(point).cuda()
|
||||
point_ = point
|
||||
point = point_.clone()
|
||||
point[0, 0] = point_[0, 0]
|
||||
point[0, 1] = point_[0, 1]
|
||||
# point = point[:, [1, 0]]
|
||||
point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)
|
||||
|
||||
self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point))
|
||||
|
||||
data['targets'] = [dict()]
|
||||
data['targets'][0]['points'] = point
|
||||
data['targets'][0]['pb'] = point.new_tensor([0.])
|
||||
|
||||
batch_inputs = [data]
|
||||
masks, ious = self.model.model.evaluate_demo(batch_inputs)
|
||||
|
||||
return masks, ious
|
||||
|
||||
def process_multi_mask(self, masks, ious, image_ori):
|
||||
pred_masks_poses = masks
|
||||
reses = []
|
||||
ious = ious[0, 0]
|
||||
ids = torch.argsort(ious, descending=True)
|
||||
|
||||
text_res = ''
|
||||
mask_ls = []
|
||||
ious_res = []
|
||||
areas = []
|
||||
for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
|
||||
iou = round(float(iou), 2)
|
||||
texts = f'{iou}'
|
||||
mask = (pred_masks_pos > 0.0).cpu().numpy()
|
||||
area = mask.sum()
|
||||
conti = False
|
||||
if iou < self.thresh:
|
||||
conti = True
|
||||
for m in mask_ls:
|
||||
if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
|
||||
conti = True
|
||||
break
|
||||
if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
|
||||
conti = False
|
||||
if conti:
|
||||
continue
|
||||
ious_res.append(iou)
|
||||
mask_ls.append(mask)
|
||||
areas.append(area)
|
||||
mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes")
|
||||
mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands")
|
||||
mask = (mask).astype(np.float)
|
||||
out_txt = texts
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
color = [0., 0., 1.0]
|
||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
||||
res = demo.get_image()
|
||||
point_x0 = max(0, int(self.point[0, 0]) - 3)
|
||||
point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3)
|
||||
point_y0 = max(0, int(self.point[0, 1]) - 3)
|
||||
point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3)
|
||||
res[point_y0:point_y1, point_x0:point_x1, 0] = 255
|
||||
res[point_y0:point_y1, point_x0:point_x1, 1] = 0
|
||||
res[point_y0:point_y1, point_x0:point_x1, 2] = 0
|
||||
reses.append(Image.fromarray(res))
|
||||
text_res = text_res + ';' + out_txt
|
||||
ids = list(torch.argsort(torch.tensor(areas), descending=False))
|
||||
ids = [int(i) for i in ids]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return reses, [reses[i] for i in ids]
|
||||
|
||||
def predict_masks(self, image_ori, image, point=None):
|
||||
masks, ious = self.predict(image_ori, image, point)
|
||||
return self.process_multi_mask(masks, ious, image_ori)
|
||||
|
||||
@staticmethod
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
Reference in New Issue
Block a user