update SoM_agent

This commit is contained in:
Hilbert-Johnson
2023-12-31 19:13:17 +08:00
parent f04e625ad9
commit 7560f4dc46
19 changed files with 3729 additions and 49 deletions

View File

@@ -3,10 +3,10 @@
import os
import re
import base64
import PIL.Image
from desktop_env.envs.desktop_env import Action, MouseClick
import json
import requests
from mm_agents.gpt_4v_prompt import SYS_PROMPT
import torch
import argparse
@@ -33,9 +33,37 @@ from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m
from scipy.ndimage import label
import numpy as np
SYS_PROMPT = '''
You will act as an agent that follows my instructions and performs desktop computer tasks as instructed. You must have good knowledge of computers and a good internet connection.
For each step, you will receive an observation in the form of an image, which is a screenshot of the computer screen, and you will predict the next action to perform based on that image.
First, you need to predict the class of your action, selecting one from below:
- **CLICK**: click on the screen with the specified integer label
- **TYPE**: type a string on the keyboard
- For CLICK, you need to predict the correct integer label shown on the screenshot.
For example, format as:
```
{
"action_type": "MOUSE_MOVE",
"label": 7
}
```
- For TYPE, you need to specify the text you want to type.
For example, format as:
```
{
"action_type": "TYPE",
"text": "hello world"
}
```
For every step, you should return only the action_type and the parameters of your action as a dict, without anything else. You MUST wrap the dict with backticks (\`).
You may plan ahead over multiple actions, but you must return exactly one action for each step.
You MUST choose and ONLY CHOOSE from the action space above, otherwise your action will be considered as invalid and you will get a penalty.
'''
build args
'''
# build args
semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"
@@ -47,9 +75,7 @@ opt_semsam = load_opt_from_config_file(semsam_cfg)
opt_seem = load_opt_from_config_file(seem_cfg)
opt_seem = init_distributed_seem(opt_seem)
'''
build model
'''
# build model
model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()
@@ -65,65 +91,46 @@ def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs
elif slider > 2.5:
model_name = 'sam'
else:
if mode == 'Automatic':
model_name = 'semantic-sam'
if slider < 1.5 + 0.14:
level = [1]
elif slider < 1.5 + 0.28:
level = [2]
elif slider < 1.5 + 0.42:
level = [3]
elif slider < 1.5 + 0.56:
level = [4]
elif slider < 1.5 + 0.70:
level = [5]
elif slider < 1.5 + 0.84:
level = [6]
else:
level = [6, 1, 2, 3, 4, 5]
model_name = 'semantic-sam'
if slider < 1.5 + 0.14:
level = [1]
elif slider < 1.5 + 0.28:
level = [2]
elif slider < 1.5 + 0.42:
level = [3]
elif slider < 1.5 + 0.56:
level = [4]
elif slider < 1.5 + 0.70:
level = [5]
elif slider < 1.5 + 0.84:
level = [6]
else:
model_name = 'sam'
level = [6, 1, 2, 3, 4, 5]
if label_mode == 'Alphabet':
label_mode = 'a'
else:
label_mode = '1'
text_size, hole_scale, island_scale = 640, 100, 100
label_mode = 'a' if label_mode == 'Alphabet' else '1'
text_size, hole_scale, island_scale = 1280, 100, 100
text, text_part, text_thresh = '', '', '0.0'
with torch.autocast(device_type='cuda', dtype=torch.float16):
semantic = False
if mode == "Interactive":
labeled_array, num_features = label(np.asarray(image['mask'].convert('L')))
spatial_masks = torch.stack([torch.from_numpy(labeled_array == i+1) for i in range(num_features)])
if model_name == 'semantic-sam':
model = model_semsam
output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
output, mask = inference_semsam_m2m_auto(model, image, level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
elif model_name == 'sam':
model = model_sam
if mode == "Automatic":
output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
elif mode == "Interactive":
output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
output, mask = inference_sam_m2m_auto(model, image, text_size, label_mode, alpha, anno_mode)
elif model_name == 'seem':
model = model_seem
if mode == "Automatic":
output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
elif mode == "Interactive":
output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
output, mask = inference_seem_pano(model, image, text_size, label_mode, alpha, anno_mode)
return output
# Function to encode the image
def encode_image(image_path):
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def encode_image(image):
return base64.b64encode(image).decode('utf-8')
def parse_actions_from_string(input_string):
# Search for a JSON string within the input string
@@ -156,7 +163,6 @@ def parse_actions_from_string(input_string):
except json.JSONDecodeError as e:
raise ValueError("Invalid response format: " + input_string)
class GPT4v_Agent:
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
self.instruction = instruction
@@ -181,7 +187,7 @@ class GPT4v_Agent:
]
def predict(self, obs):
obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Alphabet", anno_mode=["Mask", "Mark"])
obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Number", anno_mode=["Mark", "Box"])
base64_image = encode_image(obs)
self.trajectory.append({
"role": "user",
@@ -274,4 +280,5 @@ if __name__ == '__main__':
api_key = os.environ.get("OPENAI_API_KEY")
agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
print(agent.predict(obs="stackoverflow.png"))
obs = PIL.Image.open('stackoverflow.png')
print(agent.predict(obs=obs))
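Note that the updated predict() hands the annotated image returned by inference() (an array from the visualizer's get_image()) to encode_image(), which now calls base64.b64encode() directly and therefore expects raw bytes. A hedged sketch of one way to PNG-encode such an array first (encode_image_array is a hypothetical name, not part of this commit):
```
import base64
import io

import numpy as np
import PIL.Image

def encode_image_array(image_array):
    # Assumption: image_array is an HWC uint8 RGB array, e.g. the output of get_image().
    buffer = io.BytesIO()
    PIL.Image.fromarray(np.asarray(image_array, dtype=np.uint8)).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```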

mm_agents/gemini_test.py Normal file
View File

@@ -0,0 +1,19 @@
import os
import PIL.Image
import google.generativeai as genai
# Read the API key from the environment rather than hard-coding a secret in the repo.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
# for m in genai.list_models():
# if 'generateContent' in m.supported_generation_methods:
# print(m.name)
model = genai.GenerativeModel('gemini-pro-vision')
img = PIL.Image.open('image.jpg')
messages = [
{'role':'user',
'parts': ["Explain this image.", img]}
]
response = model.generate_content(messages)
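The test script builds a response but never surfaces it; response.text is the standard accessor on the object returned by generate_content:
```
print(response.text)
```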

View File

@@ -0,0 +1,2 @@
from .inference_sam_m2m_auto import *
from .inference_sam_m2m_interactive import *

View File

@@ -0,0 +1,103 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from segment_anything import SamAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
def inference_sam_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image)
image_ori = np.asarray(image_ori)
mask_generator = SamAutomaticMaskGenerator(model)
outputs = mask_generator.generate(image_ori)
from task_adapter.utils.visualizer import Visualizer
visual = Visualizer(image_ori, metadata=metadata)
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
label = 1
# for ann in sorted_anns:
# mask = ann['segmentation']
# color_mask = np.random.random((1, 3)).tolist()[0]
# # color_mask = [int(c*255) for c in color_mask]
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# label += 1
# im = demo.get_image()
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
for i, ann in enumerate(sorted_anns):
mask = ann['segmentation']
color_mask = np.random.random((1, 3)).tolist()[0]
# color_mask = [int(c*255) for c in color_mask]
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# assign the mask to the mask_map
mask_map[mask == 1] = label
label += 1
im = demo.get_image()
# fig=plt.figure(figsize=(10, 10))
# plt.imshow(image_ori)
# show_anns(outputs)
# fig.canvas.draw()
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
return im, sorted_anns
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
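A small worked example of the remove_small_regions helper defined above, to make the holes/islands behaviour concrete (toy data, assumed for illustration):
```
import numpy as np

# 5x5 mask: a 3x3 block plus one isolated pixel; the 1-px "island" falls below
# the area threshold and is removed, so the function reports a modification.
toy = np.zeros((5, 5), dtype=bool)
toy[1:4, 1:4] = True
toy[0, 4] = True
cleaned, changed = remove_small_regions(toy, area_thresh=2, mode="islands")
assert changed and not cleaned[0, 4] and cleaned[2, 2]
```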
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))
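For context, a hedged sketch of how inference_sam_m2m_auto above might be invoked on its own; the checkpoint filename and screenshot path are assumed examples, not taken from this commit:
```
from PIL import Image
from segment_anything import sam_model_registry

# Load the ViT-H SAM backbone (checkpoint path is an assumed example).
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").eval().cuda()

screenshot = Image.open("screenshot.png").convert("RGB")
annotated, anns = inference_sam_m2m_auto(
    sam, screenshot, text_size=1280,
    label_mode="1", alpha=0.1, anno_mode=["Mark", "Box"],
)
```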

View File

@@ -0,0 +1,221 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
from kornia.contrib import distance_transform
import matplotlib.pyplot as plt
import cv2
import io
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
from segment_anything import SamAutomaticMaskGenerator
from segment_anything.utils.amg import (
MaskData,
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
def sam_interactive_mask(mask_generator, points, in_points, in_labels, mask_input):
masks, iou_preds, _ = mask_generator.predictor.predict_torch(
in_points,
in_labels,
mask_input=mask_input,
multimask_output=True,
return_logits=True,
)
nm,_,h,w = masks.shape
# Serialize predictions and store in MaskData
data = MaskData(
masks=masks.flatten(0, 1),
iou_preds=iou_preds.flatten(0, 1),
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
)
del masks
# Calculate stability score
data["stability_score"] = calculate_stability_score(
data["masks"], mask_generator.predictor.model.mask_threshold, mask_generator.stability_score_offset
)
masks = data["masks"].reshape(nm, -1, h, w)
scores = (data['iou_preds'] + data['stability_score']).reshape(nm, -1)
index = torch.stack([torch.arange(nm).cuda(), scores.argmax(dim=1)]).tolist()
return masks[index]
def inference_sam_m2m_interactive(model, image, spatial_masks, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image)
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
orig_size = images.shape[-2:]
orig_h, orig_w = orig_size
crop_box = [0,0,orig_w,orig_h]
spatial_masks = spatial_masks[:, None].float().cuda()
spatial_masks = F.interpolate(spatial_masks, size=(orig_h, orig_w), mode='bicubic', align_corners=False) > 0
# generate single center point
# n,_,h,w = spatial_masks.shape
# mask_dt = (distance_transform((~F.pad(spatial_masks, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
# max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
# next_mask = torch.zeros(spatial_masks.shape, device=torch.cuda.current_device()).bool()
# next_mask = next_mask.view(n,-1)
# next_mask[max_xy_idx] = True
# next_mask = next_mask.reshape((n,1,h,w))
# points = next_mask.nonzero()[:,2:].flip(dims=[1]).cpu().numpy()
# stack sampled points
acc_points = []
for i in range(len(spatial_masks)):
points = spatial_masks[i:i+1].nonzero()[:,2:].flip(dims=[1]).cpu().numpy()
rand_ids = np.random.choice(points.shape[0], size=40, replace=True)
points = points[rand_ids]
acc_points.append(points)
_np = len(acc_points)
points = np.concatenate(acc_points)
mask_generator = SamAutomaticMaskGenerator(model)
mask_generator.predictor.set_image(image_ori)
im_size = image_ori.shape[:-1]
transformed_points = mask_generator.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=mask_generator.predictor.device).reshape(_np,-1,2).transpose(0,1)
in_labels = torch.ones((in_points.shape[0], _np), dtype=torch.int, device=mask_generator.predictor.device)
masks = sam_interactive_mask(mask_generator, points, in_points.transpose(0,1), in_labels.transpose(0,1), None)
masks = masks > 0.0
iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)
mask_data = MaskData(
masks=masks,
iou_preds=iou_preds,
points=points,
)
mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
del masks
mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])
# Compress to RLE
mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
del mask_data["masks"]
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
# Write mask records
outputs = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item(),
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
}
outputs.append(ann)
from task_adapter.utils.visualizer import Visualizer
visual = Visualizer(image_ori, metadata=metadata)
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
label = 1
# for ann in sorted_anns:
# mask = ann['segmentation']
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# label += 1
# im = demo.get_image()
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
for i, ann in enumerate(sorted_anns):
mask = ann['segmentation']
color_mask = np.random.random((1, 3)).tolist()[0]
# color_mask = [int(c*255) for c in color_mask]
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# assign the mask to the mask_map
mask_map[mask == 1] = label
label += 1
im = demo.get_image()
# fig=plt.figure(figsize=(10, 10))
# plt.imshow(image_ori)
# show_anns(outputs)
# fig.canvas.draw()
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
return im, sorted_anns
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))

View File

@@ -0,0 +1,3 @@
from .interactive_seem_m2m_auto import *
from .inference_seem_pano import *
from .inference_seem_interactive import *

View File

@@ -0,0 +1,382 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.nn as nn
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from typing import Any, Dict, List, Optional, Tuple
from segment_anything.modeling import Sam
from segment_anything.utils.amg import (
MaskData,
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
class SeemAutomaticMaskGenerator:
def __init__(
self,
model: Sam,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.9,
stability_score_thresh: float = 0.5,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",
) -> None:
"""
Using a SAM model, generates masks for the entire image.
Generates a grid of point prompts over the image, then filters
low quality and duplicate masks. The default settings are chosen
for SAM with a ViT-H backbone.
Arguments:
model (Sam): The SAM model to use for mask prediction.
points_per_side (int or None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculating the stability score.
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray) or None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove disconnected regions and holes in masks with area smaller
than min_mask_region_area. Requires opencv.
output_mode (str): The form masks are returned in. Can be 'binary_mask',
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
"""
assert (points_per_side is None) != (
point_grids is None
), "Exactly one of points_per_side or point_grid must be provided."
if points_per_side is not None:
self.point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layers,
crop_n_points_downscale_factor,
)
elif point_grids is not None:
self.point_grids = point_grids
else:
raise ValueError("Can't have both points_per_side and point_grid be None.")
assert output_mode in [
"binary_mask",
"uncompressed_rle",
"coco_rle",
], f"Unknown output_mode {output_mode}."
if output_mode == "coco_rle":
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
if min_mask_region_area > 0:
import cv2 # type: ignore # noqa: F401
self.predictor = model
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
self.stability_score_offset = stability_score_offset
self.box_nms_thresh = box_nms_thresh
self.crop_n_layers = crop_n_layers
self.crop_nms_thresh = crop_nms_thresh
self.crop_overlap_ratio = crop_overlap_ratio
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
# dilate conv
self.dilation = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3, bias=False)
self.dilation.weight.data.fill_(1.0)
self.dilation.cuda()
@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
Generates masks for the given image.
Arguments:
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
Returns:
list(dict(str, any)): A list over records for masks. Each record is
a dict containing the following keys:
segmentation (dict(str, any) or np.ndarray): The mask. If
output_mode='binary_mask', is an array of shape HW. Otherwise,
is a dictionary containing the RLE.
bbox (list(float)): The box around the mask, in XYWH format.
area (int): The area in pixels of the mask.
predicted_iou (float): The model's own prediction of the mask's
quality. This is filtered by the pred_iou_thresh parameter.
point_coords (list(list(float))): The point coordinates input
to the model to generate this mask.
stability_score (float): A measure of the mask's quality. This
is filtered on using the stability_score_thresh parameter.
crop_box (list(float)): The crop of the image used to generate
the mask, given in XYWH format.
"""
# Generate masks
mask_data = self._generate_masks(image)
# Filter small disconnected regions and holes in masks
if self.min_mask_region_area > 0:
mask_data = self.postprocess_small_regions(
mask_data,
self.min_mask_region_area,
max(self.box_nms_thresh, self.crop_nms_thresh),
)
# Encode masks
if self.output_mode == "coco_rle":
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else:
mask_data["segmentations"] = mask_data["rles"]
# Write mask records
curr_anns = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item(),
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
}
curr_anns.append(ann)
return curr_anns
def _generate_masks(self, image: np.ndarray) -> MaskData:
orig_size = image.shape[-2:]
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
)
# Iterate over image crops
data = MaskData()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
data.cat(crop_data)
# Remove duplicate masks between crops
if len(crop_boxes) > 1:
# Prefer masks from smaller crops
scores = 1 / box_area(data["crop_boxes"])
scores = scores.to(data["boxes"].device)
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
data.to_numpy()
return data
def _process_crop(
self,
image: np.ndarray,
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
) -> MaskData:
# Crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
cropped_im = image#[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[-2:]
# self.predictor.set_image(cropped_im)
# Get points for this crop
points_scale = np.array(cropped_im_size)[None, ::-1]
points_for_image = self.point_grids[crop_layer_idx] #* points_scale
# Generate masks for this crop in batches
data = MaskData()
self.enc_features=None
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(cropped_im, points, cropped_im_size, crop_box, orig_size)
data.cat(batch_data)
del batch_data
# Remove duplicates within this crop.
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros(len(data["boxes"])), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
# Return to the original image frame
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
return data
def _process_batch(
self,
images,
points: np.ndarray,
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
) -> MaskData:
orig_h, orig_w = orig_size
data = {"image": images, "height": orig_h, "width": orig_w}
points = torch.tensor(points,dtype=torch.float).to(images.device)
# prepare interactive mask for seem
abs_points = (points * torch.tensor(orig_size)[None,:].to(points.device)).long()
abs_masks = torch.zeros((len(points), orig_h, orig_w), dtype=torch.bool).to(device=points.device)
abs_masks[torch.arange(0, abs_points.size(0))[:,None], abs_points[:,0:1], abs_points[:,1:2]] = True
abs_masks = self.dilation(abs_masks[:,None].float())[:,0] > 0
data['spatial_query'] = {'rand_shape': abs_masks[:,None]}
batch_inputs = [data]
if self.enc_features is None:
masks, iou_preds, mask_features, transformer_encoder_features, multi_scale_features = self.predictor.model.evaluate_demo(batch_inputs, None, None, return_features=True)
self.enc_features = (mask_features, transformer_encoder_features, multi_scale_features)
else:
masks, iou_preds = self.predictor.model.evaluate_demo(batch_inputs, self.enc_features[0], self.enc_features[1], self.enc_features[2])
data = MaskData(
masks=masks,
iou_preds=iou_preds,
points=points,
)
del masks
# Filter by predicted IoU
if self.pred_iou_thresh > 0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh
data.filter(keep_mask)
# Calculate stability score
data["stability_score"] = calculate_stability_score(
data["masks"], 0.0, self.stability_score_offset
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
# Threshold masks and calculate boxes
data["masks"] = data["masks"] > 0.0
data["boxes"] = batched_mask_to_box(data["masks"])
# Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
if not torch.all(keep_mask):
data.filter(keep_mask)
# Compress to RLE
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = mask_to_rle_pytorch(data["masks"])
del data["masks"]
return data
@staticmethod
def postprocess_small_regions(
mask_data: MaskData, min_area: int, nms_thresh: float
) -> MaskData:
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates.
Edits mask_data in place.
Requires open-cv as a dependency.
"""
if len(mask_data["rles"]) == 0:
return mask_data
# Filter small disconnected regions and holes
new_masks = []
scores = []
for rle in mask_data["rles"]:
mask = rle_to_mask(rle)
mask, changed = remove_small_regions(mask, min_area, mode="holes")
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode="islands")
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros_like(boxes[:, 0]), # categories
iou_threshold=nms_thresh,
)
# Only recalculate RLEs for masks that have changed
for i_mask in keep_by_nms:
if scores[i_mask] == 0.0:
mask_torch = masks[i_mask].unsqueeze(0)
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
mask_data.filter(keep_by_nms)
return mask_data
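A hedged usage sketch for SeemAutomaticMaskGenerator: as in interactive_seem_m2m_auto below, it is driven with a CHW CUDA tensor rather than the HWC uint8 array the docstring mentions. Here model_seem stands for a SEEM model loaded as in the agent file above, and the screenshot path is an assumed example:
```
import numpy as np
import torch
from PIL import Image

image_hwc = np.asarray(Image.open("screenshot.png").convert("RGB"))
images = torch.from_numpy(image_hwc.copy()).permute(2, 0, 1).cuda()

generator = SeemAutomaticMaskGenerator(model_seem, points_per_side=32)
anns = generator.generate(images)  # list of dicts: 'segmentation', 'bbox', 'area', ...
```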

View File

@@ -0,0 +1,169 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SeemAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
from segment_anything.utils.amg import (
MaskData,
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
def inference_seem_interactive(model, image, spatial_masks, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image)
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
orig_size = images.shape[-2:]
orig_h, orig_w = orig_size
crop_box = [0,0,orig_w,orig_h]
data = {"image": images, "height": orig_h, "width": orig_w}
spatial_masks = spatial_masks[:, None].float().cuda()
spatial_masks = F.interpolate(spatial_masks, size=(orig_h, orig_w), mode='bicubic', align_corners=False) > 0
data['spatial_query'] = {'rand_shape': spatial_masks}
model.model.metadata = metadata
masks, _ = model.model.evaluate_demo([data])
masks = masks > 0.0
iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)
mask_data = MaskData(
masks=masks,
iou_preds=iou_preds,
points=points,
)
mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
del masks
mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])
# Compress to RLE
mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
del mask_data["masks"]
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
# Write mask records
outputs = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item(),
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
}
outputs.append(ann)
from task_adapter.utils.visualizer import Visualizer
visual = Visualizer(image_ori, metadata=metadata)
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
label = 1
# for ann in sorted_anns:
# mask = ann['segmentation']
# color_mask = np.random.random((1, 3)).tolist()[0]
# # color_mask = [int(c*255) for c in color_mask]
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# label += 1
# im = demo.get_image()
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
for i, ann in enumerate(sorted_anns):
mask = ann['segmentation']
color_mask = np.random.random((1, 3)).tolist()[0]
# color_mask = [int(c*255) for c in color_mask]
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# assign the mask to the mask_map
mask_map[mask == 1] = label
label += 1
im = demo.get_image()
# fig=plt.figure(figsize=(10, 10))
# plt.imshow(image_ori)
# show_anns(outputs)
# fig.canvas.draw()
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
return im, sorted_anns
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))

View File

@@ -0,0 +1,164 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SeemAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
from segment_anything.utils.amg import (
MaskData,
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
def inference_seem_pano(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image)
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
orig_size = images.shape[-2:]
orig_h, orig_w = orig_size
crop_box = [0,0,orig_w,orig_h]
data = {"image": images, "height": orig_h, "width": orig_w}
batch_inputs = [data]
model.model.metadata = metadata
outputs = model.model.evaluate(batch_inputs)
pano_mask = outputs[0]['panoptic_seg'][0]
pano_info = outputs[0]['panoptic_seg'][1]
masks = []
for seg_info in pano_info:
masks += [pano_mask == seg_info['id']]
masks = torch.stack(masks, dim=0)
iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)
mask_data = MaskData(
masks=masks,
iou_preds=iou_preds,
points=points,
)
mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
del masks
mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])
# Compress to RLE
mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
del mask_data["masks"]
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
# Write mask records
outputs = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item(),
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
}
outputs.append(ann)
from task_adapter.utils.visualizer import Visualizer
visual = Visualizer(image_ori, metadata=metadata)
# create a full zero image as the image_orig
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
label = 1
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
for i, ann in enumerate(sorted_anns):
mask = ann['segmentation']
color_mask = np.random.random((1, 3)).tolist()[0]
# color_mask = [int(c*255) for c in color_mask]
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# assign the mask to the mask_map
mask_map[mask == 1] = label
label += 1
im = demo.get_image()
# fig=plt.figure(figsize=(10, 10))
# plt.imshow(image_ori)
# show_anns(outputs)
# fig.canvas.draw()
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
return im, sorted_anns
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))

View File

@@ -0,0 +1,93 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SeemAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
def interactive_seem_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image)
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
mask_generator = SeemAutomaticMaskGenerator(model)
outputs = mask_generator.generate(images)
from task_adapter.utils.visualizer import Visualizer
visual = Visualizer(image_ori, metadata=metadata)
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
label = 1
for ann in sorted_anns:
mask = ann['segmentation']
color_mask = np.random.random((1, 3)).tolist()[0]
# color_mask = [int(c*255) for c in color_mask]
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
label += 1
im = demo.get_image()
# fig=plt.figure(figsize=(10, 10))
# plt.imshow(image_ori)
# show_anns(outputs)
# fig.canvas.draw()
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
return im
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))

View File

@@ -0,0 +1,6 @@
from .interactive_idino_m2m import interactive_infer_image as interactive_infer_image_idino_m2m
from .interactive_idino_m2m import interactive_infer_image_semantic, interactive_infer_image_3l
from .inference_semsam_m2m_auto import inference_semsam_m2m_auto
from .interactive_idino_1o1_box import interactive_infer_image_box as interactive_infer_image_idino_m2m_box
from .automatic_mask_generator import prompt_switch
from .interactive_predictor import SemanticSAMPredictor

View File

@@ -0,0 +1,393 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
from typing import Any, Dict, List, Optional, Tuple
# from
# from .modeling import Sam
# from .predictor import SamPredictor
from semantic_sam.utils.sam_utils.amg import (
MaskData,
area_from_rle,
batch_iterator,
batched_mask_to_box,
box_xyxy_to_xywh,
build_all_layer_point_grids,
calculate_stability_score,
coco_encode_rle,
generate_crop_boxes,
is_box_near_crop_edge,
mask_to_rle_pytorch,
remove_small_regions,
rle_to_mask,
uncrop_boxes_xyxy,
uncrop_masks,
uncrop_points,
)
def prompt_switch(p):
p = int(p)
if p == 1:
return 3
if p == 2:
return 2
if p == 3:
return 0
if p == 4:
return 4
if p == 5:
return 1
if p == 6:
return 5
else:
raise NotImplementedError
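For reference, the mapping prompt_switch applies to the UI granularity levels 1-6, read directly off the branches above:
```
# UI level -> internal prompt index, per the branches in prompt_switch.
assert [prompt_switch(p) for p in [1, 2, 3, 4, 5, 6]] == [3, 2, 0, 4, 1, 5]
```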
class SemanticSamAutomaticMaskGenerator:
def __init__(
self,
model,
points_per_side: Optional[int] = 32,
points_per_batch: int = 200,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.92,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 10,
output_mode: str = "binary_mask",
level: list = [1, 2, 3, 4, 5, 6],
) -> None:
"""
Using a SAM model, generates masks for the entire image.
Generates a grid of point prompts over the image, then filters
low quality and duplicate masks. The default settings are chosen
for SAM with a ViT-H backbone.
Arguments:
model (Sam): The SAM model to use for mask prediction.
points_per_side (int or None): The number of points to be sampled
along one side of the image. The total number of points is
points_per_side**2. If None, 'point_grids' must provide explicit
point sampling.
points_per_batch (int): Sets the number of points run simultaneously
by the model. Higher numbers may be faster but use more GPU memory.
pred_iou_thresh (float): A filtering threshold in [0,1], using the
model's predicted mask quality.
stability_score_thresh (float): A filtering threshold in [0,1], using
the stability of the mask under changes to the cutoff used to binarize
the model's mask predictions.
stability_score_offset (float): The amount to shift the cutoff when
calculating the stability score.
box_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks.
crop_n_layers (int): If >0, mask prediction will be run again on
crops of the image. Sets the number of layers to run, where each
layer has 2**i_layer number of image crops.
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
suppression to filter duplicate masks between different crops.
crop_overlap_ratio (float): Sets the degree to which crops overlap.
In the first crop layer, crops will overlap by this fraction of
the image length. Later layers with more crops scale down this overlap.
crop_n_points_downscale_factor (int): The number of points-per-side
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
point_grids (list(np.ndarray) or None): A list over explicit grids
of points used for sampling, normalized to [0,1]. The nth grid in the
list is used in the nth crop layer. Exclusive with points_per_side.
min_mask_region_area (int): If >0, postprocessing will be applied
to remove disconnected regions and holes in masks with area smaller
than min_mask_region_area. Requires opencv.
output_mode (str): The form masks are returned in. Can be 'binary_mask',
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
For large resolutions, 'binary_mask' may consume large amounts of
memory.
"""
self.level = [prompt_switch(l) for l in level]
assert (points_per_side is None) != (
point_grids is None
), "Exactly one of points_per_side or point_grid must be provided."
if points_per_side is not None:
self.point_grids = build_all_layer_point_grids(
points_per_side,
crop_n_layers,
crop_n_points_downscale_factor,
)
elif point_grids is not None:
self.point_grids = point_grids
else:
raise ValueError("Can't have both points_per_side and point_grid be None.")
assert output_mode in [
"binary_mask",
"uncompressed_rle",
"coco_rle",
], f"Unknown output_mode {output_mode}."
if output_mode == "coco_rle":
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
if min_mask_region_area > 0:
import cv2 # type: ignore # noqa: F401
self.predictor = model
self.points_per_batch = points_per_batch
self.pred_iou_thresh = pred_iou_thresh
self.stability_score_thresh = stability_score_thresh
self.stability_score_offset = stability_score_offset
self.box_nms_thresh = box_nms_thresh
self.crop_n_layers = crop_n_layers
self.crop_nms_thresh = crop_nms_thresh
self.crop_overlap_ratio = crop_overlap_ratio
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
self.min_mask_region_area = min_mask_region_area
self.output_mode = output_mode
@torch.no_grad()
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
Generates masks for the given image.
Arguments:
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
Returns:
list(dict(str, any)): A list over records for masks. Each record is
a dict containing the following keys:
segmentation (dict(str, any) or np.ndarray): The mask. If
output_mode='binary_mask', is an array of shape HW. Otherwise,
is a dictionary containing the RLE.
bbox (list(float)): The box around the mask, in XYWH format.
area (int): The area in pixels of the mask.
predicted_iou (float): The model's own prediction of the mask's
quality. This is filtered by the pred_iou_thresh parameter.
point_coords (list(list(float))): The point coordinates input
to the model to generate this mask.
stability_score (float): A measure of the mask's quality. This
is filtered on using the stability_score_thresh parameter.
crop_box (list(float)): The crop of the image used to generate
the mask, given in XYWH format.
"""
# Generate masks
mask_data = self._generate_masks(image)
# Filter small disconnected regions and holes in masks
if self.min_mask_region_area > 0:
mask_data = self.postprocess_small_regions(
mask_data,
self.min_mask_region_area,
max(self.box_nms_thresh, self.crop_nms_thresh),
)
# Encode masks
if self.output_mode == "coco_rle":
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else:
mask_data["segmentations"] = mask_data["rles"]
# Write mask records
curr_anns = []
for idx in range(len(mask_data["segmentations"])):
ann = {
"segmentation": mask_data["segmentations"][idx],
"area": area_from_rle(mask_data["rles"][idx]),
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
"predicted_iou": mask_data["iou_preds"][idx].item(),
"point_coords": [mask_data["points"][idx].tolist()],
"stability_score": mask_data["stability_score"][idx].item(),
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
}
curr_anns.append(ann)
return curr_anns
def _generate_masks(self, image: np.ndarray) -> MaskData:
orig_size = image.shape[-2:]
crop_boxes, layer_idxs = generate_crop_boxes(
orig_size, self.crop_n_layers, self.crop_overlap_ratio
)
# Iterate over image crops
assert len(crop_boxes)==1
data = MaskData()
# import ipdb; ipdb.set_trace()
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
data.cat(crop_data)
# import ipdb; ipdb.set_trace()
# Remove duplicate masks between crops
if len(crop_boxes) > 1:
# Prefer masks from smaller crops
scores = 1 / box_area(data["crop_boxes"])
scores = scores.to(data["boxes"].device)
keep_by_nms = batched_nms(
data["boxes"].float(),
scores,
torch.zeros(len(data["boxes"])), # categories
iou_threshold=self.crop_nms_thresh,
)
data.filter(keep_by_nms)
data.to_numpy()
return data
def _process_crop(
self,
image: np.ndarray,
crop_box: List[int],
crop_layer_idx: int,
orig_size: Tuple[int, ...],
) -> MaskData:
# Crop the image and calculate embeddings
x0, y0, x1, y1 = crop_box
cropped_im = image#[y0:y1, x0:x1, :]
cropped_im_size = cropped_im.shape[-2:]
# self.predictor.set_image(cropped_im)
# Get points for this crop
points_scale = np.array(cropped_im_size)[None, ::-1]
points_for_image = self.point_grids[crop_layer_idx] #* points_scale
# Generate masks for this crop in batches
data = MaskData()
self.enc_features=None
# import ipdb; ipdb.set_trace()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self._process_batch(cropped_im,points, cropped_im_size, crop_box, orig_size)
data.cat(batch_data)
del batch_data
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros(len(data["boxes"])), # categories
iou_threshold=self.box_nms_thresh,
)
# import ipdb; ipdb.set_trace()
data.filter(keep_by_nms)
# import ipdb; ipdb.set_trace()
# Return to the original image frame
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
return data
def _process_batch(
self,
images,
points: np.ndarray,
im_size: Tuple[int, ...],
crop_box: List[int],
orig_size: Tuple[int, ...],
) -> MaskData:
orig_h, orig_w = orig_size
data = {"image": images, "height": orig_h, "width": orig_w}
points=torch.tensor(points,dtype=torch.float).to(images.device)
points = torch.cat([points, points.new_tensor([[0.005, 0.005]]).repeat(len(points), 1)], dim=-1)
data['targets'] = [dict()]
data['targets'][0]['points']=points
data['targets'][0]['pb']=points.new_tensor([0.]*len(points))
batch_inputs = [data]
if self.enc_features is None:
masks, iou_preds,mask_features,multi_scale_features= self.predictor.model.evaluate_demo(batch_inputs,None,None,return_features=True, level=self.level)
self.enc_features=(mask_features,multi_scale_features)
else:
masks, iou_preds= self.predictor.model.evaluate_demo(batch_inputs,None,None,self.enc_features[0],self.enc_features[1], level=self.level)
data = MaskData(
masks=masks,
iou_preds=iou_preds.flatten(),
points=torch.as_tensor(points[:,None].repeat(1,len(self.level), 1).view(-1,4)),
)
del masks
# Filter by predicted IoU
keep_mask = data["iou_preds"] > self.pred_iou_thresh
data.filter(keep_mask)
# Calculate stability score
data["stability_score"] = calculate_stability_score(
data["masks"], 0.0, self.stability_score_offset
)
# if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
# Threshold masks and calculate boxes
data["masks"] = data["masks"] > 0.0
data["boxes"] = batched_mask_to_box(data["masks"])
# Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
if not torch.all(keep_mask):
data.filter(keep_mask)
# Compress to RLE
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
data["rles"] = mask_to_rle_pytorch(data["masks"])
del data["masks"]
return data
@staticmethod
def postprocess_small_regions(
mask_data: MaskData, min_area: int, nms_thresh: float
) -> MaskData:
"""
Removes small disconnected regions and holes in masks, then reruns
box NMS to remove any new duplicates.
Edits mask_data in place.
Requires OpenCV (cv2) as a dependency.
"""
if len(mask_data["rles"]) == 0:
return mask_data
# Filter small disconnected regions and holes
new_masks = []
scores = []
for rle in mask_data["rles"]:
mask = rle_to_mask(rle)
mask, changed = remove_small_regions(mask, min_area, mode="holes")
unchanged = not changed
mask, changed = remove_small_regions(mask, min_area, mode="islands")
unchanged = unchanged and not changed
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
# Give score=0 to changed masks and score=1 to unchanged masks
# so NMS will prefer ones that didn't need postprocessing
scores.append(float(unchanged))
# Recalculate boxes and remove any new duplicates
masks = torch.cat(new_masks, dim=0)
boxes = batched_mask_to_box(masks)
keep_by_nms = batched_nms(
boxes.float(),
torch.as_tensor(scores),
torch.zeros(len(boxes)), # categories
iou_threshold=nms_thresh,
)
# Only recalculate RLEs for masks that have changed
for i_mask in keep_by_nms:
if scores[i_mask] == 0.0:
mask_torch = masks[i_mask].unsqueeze(0)
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
mask_data.filter(keep_by_nms)
return mask_data

View File

@@ -0,0 +1,108 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SemanticSamAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
def inference_semsam_m2m_auto(model, image, level, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image)
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
mask_generator = SemanticSamAutomaticMaskGenerator(model,points_per_side=32,
pred_iou_thresh=0.88,
stability_score_thresh=0.92,
min_mask_region_area=10,
level=level,
)
outputs = mask_generator.generate(images)
visual = Visualizer(image_ori, metadata=metadata)
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
label = 1
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
for i, ann in enumerate(sorted_anns):
mask = ann['segmentation']
color_mask = np.random.random((1, 3)).tolist()[0]
# color_mask = [int(c*255) for c in color_mask]
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
# assign the mask to the mask_map
mask_map[mask == 1] = label
label += 1
im = demo.get_image()
return im, sorted_anns
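# A minimal usage sketch (illustrative, not part of the original file): wrap the
# auto-mask labeling into a helper that produces a Set-of-Mark overlay for one
# screenshot. `semsam_model` is assumed to be an already-loaded Semantic-SAM
# BaseModel; the granularity level, file names, and drawing options are example values.
def make_som_overlay_example(semsam_model, image_path, out_path="som_overlay.png"):
    screenshot = Image.open(image_path).convert("RGB")
    marked_image, annotations = inference_semsam_m2m_auto(
        semsam_model, screenshot, level=[3],
        all_classes=None, all_parts=None, thresh='0.0',
        text_size=640, hole_scale=100, island_scale=100, semantic=False,
        label_mode='1', alpha=0.1, anno_mode=['Mask', 'Mark'],
    )
    # marked_image is a numpy array with numeric marks drawn on each region;
    # annotations is the area-sorted list of mask records from the generator.
    Image.fromarray(marked_image).save(out_path)
    return marked_image, annotations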
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of whether the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
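# A small self-contained sketch of remove_small_regions (illustrative only):
# a 4-pixel hole inside a solid mask is below the area threshold, so it gets filled.
def _remove_small_regions_example():
    demo_mask = np.ones((64, 64), dtype=bool)
    demo_mask[30:32, 30:32] = False                 # tiny hole
    filled, changed = remove_small_regions(demo_mask, area_thresh=10, mode="holes")
    assert changed and filled.all()                 # the hole has been filled
    # "islands" works the same way in reverse: small foreground blobs are removed.
    return filled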
def show_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []
for ann in sorted_anns:
m = ann['segmentation']
img = np.ones((m.shape[0], m.shape[1], 3))
color_mask = np.random.random((1, 3)).tolist()[0]
for i in range(3):
img[:,:,i] = color_mask[i]
ax.imshow(np.dstack((img, m*0.35)))

View File

@@ -0,0 +1,144 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
from detectron2.structures import BitMasks
from semantic_sam.utils import box_ops
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
def interactive_infer_image_box(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image['image'])
mask_ori = transform1(image['mask'])
width = image_ori.size[0]
height = image_ori.size[1]
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
data = {"image": images, "height": height, "width": width}
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
flaten_mask = mask_ori.unsqueeze(0)
# import ipdb; ipdb.set_trace()
points=mask_ori.nonzero().float().to(images.device)
if len(points)==0:
point_=point=points.new_tensor([[0.5,0.5,0.5,0.5]])
else:
mean_point=points.mean(0)[None]
box_xyxy = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor
h = mask_ori.shape[0]
w = mask_ori.shape[1]
box_xywh = (box_ops.box_xyxy_to_cxcywh(box_xyxy) / torch.as_tensor([w, h, w, h])).cuda()
# point_=points.mean(0)[None]
# point=point_.clone()
# point[0, 0] = point_[0, 0] / mask_ori.shape[0]
# point[0, 1] = point_[0, 1] / mask_ori.shape[1]
# point = point[:, [1, 0]]
point=box_xywh
data['targets'] = [dict()]
data['targets'][0]['points']=point
data['targets'][0]['pb']=point.new_tensor([1.])
batch_inputs = [data]
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts, task='demo_box')
pred_masks_poses = masks
reses=[]
ious=ious[0,0]
ids=torch.argsort(ious,descending=True)
text_res=''
try:
thresh=float(thresh)
except Exception:
thresh=0.0
mask_ls=[]
ious_res=[]
areas=[]
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
iou=round(float(iou),2)
texts=f'{iou}'
mask=(pred_masks_pos>0.0).cpu().numpy()
area=mask.sum()
conti=False
if iou<thresh:
conti=True
for m in mask_ls:
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
conti=True
break
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
conti=False
if conti:
continue
ious_res.append(iou)
mask_ls.append(mask)
areas.append(area)
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
mask = mask.astype(float)
out_txt = texts
visual = Visualizer(image_ori, metadata=metadata)
color=[0.,0.,1.0]
demo = visual.draw_binary_mask(mask, color=color, text=texts)
demo = visual.draw_box(box_xyxy[0])
res = demo.get_image()
# point_x0=max(0,int(point_[0, 1])-3)
# point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
# point_y0 = max(0, int(point_[0, 0]) - 3)
# point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
# res[point_y0:point_y1,point_x0:point_x1,0]=255
# res[point_y0:point_y1,point_x0:point_x1,1]=0
# res[point_y0:point_y1,point_x0:point_x1,2]=0
reses.append(Image.fromarray(res))
text_res=text_res+';'+out_txt
ids=list(torch.argsort(torch.tensor(areas),descending=False))
ids = [int(i) for i in ids]
torch.cuda.empty_cache()
return reses,[reses[i] for i in ids]
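# Illustrative usage sketch (not part of the original file): prompt the model with
# a box derived from a rough scribble mask. `semsam_model` is assumed to be a
# loaded Semantic-SAM BaseModel; the file names are placeholders.
def box_prompt_example(semsam_model, image_path, scribble_path):
    sample = {
        'image': Image.open(image_path).convert("RGB"),
        'mask': Image.open(scribble_path).convert("RGB"),  # nonzero pixels define the box
    }
    results, results_by_area = interactive_infer_image_box(
        semsam_model, sample, all_classes="", all_parts="",
        thresh='0.0', text_size=640, hole_scale=100, island_scale=100,
        semantic=False,
    )
    # each entry of `results` is a PIL image with one candidate mask and its IoU drawn
    return results, results_by_area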
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of whether the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True

View File

@@ -0,0 +1,322 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
def interactive_infer_image(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image['image'])
mask_ori = transform1(image['mask'])
width = image_ori.size[0]
height = image_ori.size[1]
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
data = {"image": images, "height": height, "width": width}
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
points=mask_ori.nonzero().float().to(images.device)
if len(points)==0:
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
else:
point_=points.mean(0)[None]
point=point_.clone()
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
point = point[:, [1, 0]]
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
data['targets'] = [dict()]
data['targets'][0]['points']=point
data['targets'][0]['pb']=point.new_tensor([0.])
batch_inputs = [data]
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts)
pred_masks_poses = masks
reses=[]
ious=ious[0,0]
ids=torch.argsort(ious,descending=True)
text_res=''
try:
thresh=float(thresh)
except Exception:
thresh=0.0
mask_ls=[]
ious_res=[]
areas=[]
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
iou=round(float(iou),2)
texts=f'{iou}'
mask=(pred_masks_pos>0.0).cpu().numpy()
area=mask.sum()
conti=False
if iou<thresh:
conti=True
for m in mask_ls:
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
conti=True
break
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
conti=False
if conti:
continue
ious_res.append(iou)
mask_ls.append(mask)
areas.append(area)
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
mask = mask.astype(float)
out_txt = texts
visual = Visualizer(image_ori, metadata=metadata)
color=[0.,0.,1.0]
label = len(mask_ls)  # 1-based order in which this mask was accepted; used as its numeric mark
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
res = demo.get_image()
point_x0=max(0,int(point_[0, 1])-3)
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
point_y0 = max(0, int(point_[0, 0]) - 3)
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
# res[point_y0:point_y1,point_x0:point_x1,0]=255
# res[point_y0:point_y1,point_x0:point_x1,1]=0
# res[point_y0:point_y1,point_x0:point_x1,2]=0
reses.append(Image.fromarray(res))
text_res=text_res+';'+out_txt
ids=list(torch.argsort(torch.tensor(areas),descending=False))
ids = [int(i) for i in ids]
torch.cuda.empty_cache()
return reses,[reses[i] for i in ids]
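# Illustrative usage sketch (not part of the original file): point-prompt the model
# with the center of a user scribble and draw a numbered mark on each returned mask.
# `semsam_model` is assumed to be a loaded Semantic-SAM BaseModel; paths are placeholders.
def point_prompt_example(semsam_model, image_path, scribble_path):
    sample = {
        'image': Image.open(image_path).convert("RGB"),
        'mask': Image.open(scribble_path).convert("RGB"),
    }
    results, results_by_area = interactive_infer_image(
        semsam_model, sample, all_classes="", all_parts="",
        thresh='0.0', text_size=640, hole_scale=100, island_scale=100,
        semantic=False, label_mode='1', alpha=0.1, anno_mode=['Mask', 'Mark'],
    )
    return results, results_by_area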
def interactive_infer_image_3l(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image['image'])
mask_ori = transform1(image['mask'])
width = image_ori.size[0]
height = image_ori.size[1]
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
data = {"image": images, "height": height, "width": width}
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
points=mask_ori.nonzero().float().to(images.device)
if len(points)==0:
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
else:
point_=points.mean(0)[None]
point=point_.clone()
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
point = point[:, [1, 0]]
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
data['targets'] = [dict()]
data['targets'][0]['points']=point
data['targets'][0]['pb']=point.new_tensor([0.])
batch_inputs = [data]
masks, ious, pred_class, pred_class_score = model.model.evaluate_demo(batch_inputs,all_classes,all_parts, level=[0,1,2])
pred_masks_poses = masks
reses=[]
ious=ious[0,0]
ids=torch.argsort(ious,descending=True)
text_res=''
try:
thresh=float(thresh)
except Exception:
thresh=0.0
mask_ls=[]
ious_res=[]
areas=[]
new_pred_class = []
new_pred_class_score = []
for i in ids:
new_pred_class_score.append(pred_class_score[i])
new_pred_class.append(pred_class[i])
for i,(pred_masks_pos,iou, cls_name, cls_score) in enumerate(zip(pred_masks_poses[ids],ious[ids], new_pred_class, new_pred_class_score)):
iou=round(float(iou),2)
texts=f'{iou}_{cls_name}_{cls_score}'
mask=(pred_masks_pos>0.0).cpu().numpy()
area=mask.sum()
conti=False
if iou<thresh:
conti=True
for m in mask_ls:
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
conti=True
break
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
conti=False
if conti:
continue
ious_res.append(iou)
mask_ls.append(mask)
areas.append(area)
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
mask = mask.astype(float)
out_txt = texts
visual = Visualizer(image_ori, metadata=metadata)
color=[0.,0.,1.0]
demo = visual.draw_binary_mask(mask, color=color, text=texts)
res = demo.get_image()
point_x0=max(0,int(point_[0, 1])-3)
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
point_y0 = max(0, int(point_[0, 0]) - 3)
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
res[point_y0:point_y1,point_x0:point_x1,0]=255
res[point_y0:point_y1,point_x0:point_x1,1]=0
res[point_y0:point_y1,point_x0:point_x1,2]=0
reses.append(Image.fromarray(res))
text_res=text_res+';'+out_txt
ids=list(torch.argsort(torch.tensor(areas),descending=False))
ids = [int(i) for i in ids]
torch.cuda.empty_cache()
return reses,[reses[i] for i in ids]
def interactive_infer_image_semantic(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
transform1 = transforms.Compose(t)
image_ori = transform1(image['image'])
mask_ori = transform1(image['mask'])
width = image_ori.size[0]
height = image_ori.size[1]
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
data = {"image": images, "height": height, "width": width}
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
points=mask_ori.nonzero().float().to(images.device)
if len(points)==0:
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
else:
point_=points.mean(0)[None]
point=point_.clone()
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
point = point[:, [1, 0]]
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
data['targets'] = [dict()]
data['targets'][0]['points']=point
data['targets'][0]['pb']=point.new_tensor([1.])
batch_inputs = [data]
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts)
pred_masks_poses = masks
reses=[]
ious=ious[0,0]
ids=torch.argsort(ious,descending=True)
text_res=''
try:
thresh=float(thresh)
except Exception:
thresh=0.0
mask_ls=[]
ious_res=[]
areas=[]
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
iou=round(float(iou),2)
texts=f'{iou}'
mask=(pred_masks_pos>0.0).cpu().numpy()
area=mask.sum()
conti=False
if iou<thresh:
conti=True
for m in mask_ls:
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
conti=True
break
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
conti=False
if conti:
continue
ious_res.append(iou)
mask_ls.append(mask)
areas.append(area)
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
mask = mask.astype(float)
out_txt = texts
visual = Visualizer(image_ori, metadata=metadata)
color=[0.,0.,1.0]
demo = visual.draw_binary_mask(mask, color=color, text=texts)
res = demo.get_image()
point_x0=max(0,int(point_[0, 1])-3)
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
point_y0 = max(0, int(point_[0, 0]) - 3)
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
res[point_y0:point_y1,point_x0:point_x1,0]=255
res[point_y0:point_y1,point_x0:point_x1,1]=0
res[point_y0:point_y1,point_x0:point_x1,2]=0
reses.append(Image.fromarray(res))
text_res=text_res+';'+out_txt
ids=list(torch.argsort(torch.tensor(areas),descending=False))
ids = [int(i) for i in ids]
torch.cuda.empty_cache()
return reses,[reses[i] for i in ids]
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of whether the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True

View File

@@ -0,0 +1,139 @@
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
class SemanticSAMPredictor:
def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100):
"""
thresh: IoU threshold used to filter out low-confidence masks
text_size: the input image's short edge is resized to this value before processing
hole_scale: fill in small holes, as in SAM
island_scale: remove small isolated regions, as in SAM
"""
self.model = model
self.thresh = thresh
self.text_size = text_size
self.hole_scale = hole_scale
self.island_scale = island_scale
self.point = None
def predict(self, image_ori, image, point=None):
"""
produce up to 6 prediction results for each click
"""
height = image_ori.shape[0]  # image_ori is a numpy array laid out as (H, W, C)
width = image_ori.shape[1]
data = {"image": image, "height": height, "width": width}
# import ipdb; ipdb.set_trace()
if point is None:
point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda()
else:
point = torch.tensor(point).cuda()
point_ = point
point = point_.clone()
point[0, 0] = point_[0, 0]
point[0, 1] = point_[0, 1]
# point = point[:, [1, 0]]
point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)
self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point))
data['targets'] = [dict()]
data['targets'][0]['points'] = point
data['targets'][0]['pb'] = point.new_tensor([0.])
batch_inputs = [data]
masks, ious = self.model.model.evaluate_demo(batch_inputs)
return masks, ious
def process_multi_mask(self, masks, ious, image_ori):
pred_masks_poses = masks
reses = []
ious = ious[0, 0]
ids = torch.argsort(ious, descending=True)
text_res = ''
mask_ls = []
ious_res = []
areas = []
for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
iou = round(float(iou), 2)
texts = f'{iou}'
mask = (pred_masks_pos > 0.0).cpu().numpy()
area = mask.sum()
conti = False
if iou < self.thresh:
conti = True
for m in mask_ls:
if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
conti = True
break
if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
conti = False
if conti:
continue
ious_res.append(iou)
mask_ls.append(mask)
areas.append(area)
mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes")
mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands")
mask = mask.astype(float)
out_txt = texts
visual = Visualizer(image_ori, metadata=metadata)
color = [0., 0., 1.0]
demo = visual.draw_binary_mask(mask, color=color, text=texts)
res = demo.get_image()
point_x0 = max(0, int(self.point[0, 0]) - 3)
point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3)
point_y0 = max(0, int(self.point[0, 1]) - 3)
point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3)
res[point_y0:point_y1, point_x0:point_x1, 0] = 255
res[point_y0:point_y1, point_x0:point_x1, 1] = 0
res[point_y0:point_y1, point_x0:point_x1, 2] = 0
reses.append(Image.fromarray(res))
text_res = text_res + ';' + out_txt
ids = list(torch.argsort(torch.tensor(areas), descending=False))
ids = [int(i) for i in ids]
torch.cuda.empty_cache()
return reses, [reses[i] for i in ids]
def predict_masks(self, image_ori, image, point=None):
masks, ious = self.predict(image_ori, image, point)
return self.process_multi_mask(masks, ious, image_ori)
@staticmethod
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of whether the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
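# Illustrative usage sketch (not part of the original file): run the predictor on a
# normalized click at the image center. `semsam_model` is assumed to be a loaded
# Semantic-SAM BaseModel and a CUDA device to be available; the path is a placeholder.
def click_predict_example(semsam_model, image_path):
    pil_im = Image.open(image_path).convert("RGB")
    image_ori = np.asarray(pil_im)
    image = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()
    predictor = SemanticSAMPredictor(semsam_model, thresh=0.5, text_size=640)
    overlays, overlays_by_area = predictor.predict_masks(
        image_ori, image, point=[[0.5, 0.5]]  # normalized (x, y) click
    )
    # overlays holds one annotated PIL image per surviving candidate mask
    return overlays, overlays_by_area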

File diff suppressed because it is too large