update SoM_agent
@@ -3,10 +3,10 @@
import os
import re
import base64
import PIL.Image
from desktop_env.envs.desktop_env import Action, MouseClick
import json
import requests
from mm_agents.gpt_4v_prompt import SYS_PROMPT

import torch
import argparse
@@ -33,9 +33,37 @@ from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m
from scipy.ndimage import label
import numpy as np

SYS_PROMPT = '''
You will act as an agent that follows my instructions and performs desktop computer tasks as instructed. You must have good knowledge of computers and a good internet connection.
For each step, you will receive an observation in the form of an image, which is a screenshot of the computer screen, and you will predict the next action based on that image.

First, predict the class of your action, selecting from the options below:
- **CLICK**: click on the screen with the specified integer label
- **TYPE**: type a string on the keyboard

- For CLICK, you need to predict the correct integer label shown on the screenshot,
for example, format as:
```
{
  "action_type": "CLICK",
  "label": 7
}
```
- For TYPE, you need to specify the text you want to type,
for example, format as:
```
{
  "action_type": "TYPE",
  "text": "hello world"
}
```

For every step, you should return only the action_type and the parameters of your action as a dict, without anything else. You MUST wrap the dict with backticks (\`).
You may plan multiple actions ahead, but you should return only one action for each step.
You MUST choose, and ONLY choose, from the action space above; otherwise your action will be considered invalid and you will get a penalty.
'''
'''
build args
'''

# build args
semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"

@@ -47,9 +75,7 @@ opt_semsam = load_opt_from_config_file(semsam_cfg)
opt_seem = load_opt_from_config_file(seem_cfg)
opt_seem = init_distributed_seem(opt_seem)

'''
build model
'''
# build model
model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()
@@ -65,65 +91,46 @@ def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs
    elif slider > 2.5:
        model_name = 'sam'
    else:
        if mode == 'Automatic':
            model_name = 'semantic-sam'
            if slider < 1.5 + 0.14:
                level = [1]
            elif slider < 1.5 + 0.28:
                level = [2]
            elif slider < 1.5 + 0.42:
                level = [3]
            elif slider < 1.5 + 0.56:
                level = [4]
            elif slider < 1.5 + 0.70:
                level = [5]
            elif slider < 1.5 + 0.84:
                level = [6]
            else:
                level = [6, 1, 2, 3, 4, 5]
        model_name = 'semantic-sam'
        if slider < 1.5 + 0.14:
            level = [1]
        elif slider < 1.5 + 0.28:
            level = [2]
        elif slider < 1.5 + 0.42:
            level = [3]
        elif slider < 1.5 + 0.56:
            level = [4]
        elif slider < 1.5 + 0.70:
            level = [5]
        elif slider < 1.5 + 0.84:
            level = [6]
        else:
            model_name = 'sam'
            level = [6, 1, 2, 3, 4, 5]

    if label_mode == 'Alphabet':
        label_mode = 'a'
    else:
        label_mode = '1'

    text_size, hole_scale, island_scale = 640, 100, 100
    label_mode = 'a' if label_mode == 'Alphabet' else '1'
    text_size, hole_scale, island_scale = 1280, 100, 100
    text, text_part, text_thresh = '', '', '0.0'

    with torch.autocast(device_type='cuda', dtype=torch.float16):
        semantic = False

        if mode == "Interactive":
            labeled_array, num_features = label(np.asarray(image['mask'].convert('L')))
            spatial_masks = torch.stack([torch.from_numpy(labeled_array == i+1) for i in range(num_features)])

        if model_name == 'semantic-sam':
            model = model_semsam
            output, mask = inference_semsam_m2m_auto(model, image['image'], level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
            output, mask = inference_semsam_m2m_auto(model, image, level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)

        elif model_name == 'sam':
            model = model_sam
            if mode == "Automatic":
                output, mask = inference_sam_m2m_auto(model, image['image'], text_size, label_mode, alpha, anno_mode)
            elif mode == "Interactive":
                output, mask = inference_sam_m2m_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
            output, mask = inference_sam_m2m_auto(model, image, text_size, label_mode, alpha, anno_mode)

        elif model_name == 'seem':
            model = model_seem
            if mode == "Automatic":
                output, mask = inference_seem_pano(model, image['image'], text_size, label_mode, alpha, anno_mode)
            elif mode == "Interactive":
                output, mask = inference_seem_interactive(model, image['image'], spatial_masks, text_size, label_mode, alpha, anno_mode)
            output, mask = inference_seem_pano(model, image, text_size, label_mode, alpha, anno_mode)

    return output

# Function to encode the image
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')

def encode_image(image):
    return base64.b64encode(image).decode('utf-8')
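Note that the rewritten encode_image assumes it is handed raw bytes, but predict passes it the annotated image returned by inference (a NumPy array from the visualizer). A minimal sketch of a safer variant, assuming the goal is a base64 PNG payload (the helper name is hypothetical, not from this commit):
```
import io
import base64
import numpy as np
from PIL import Image

def encode_image_png(image_array: np.ndarray) -> str:
    # Serialize the RGB array to PNG bytes first, so the base64 payload
    # decodes as an actual image rather than a raw pixel buffer.
    buffer = io.BytesIO()
    Image.fromarray(image_array).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
```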

def parse_actions_from_string(input_string):
    # Search for a JSON string within the input string
@@ -156,7 +163,6 @@ def parse_actions_from_string(input_string):
    except json.JSONDecodeError as e:
        raise ValueError("Invalid response format: " + input_string)
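Only the error path of the parser survives in this hunk. For reference, a minimal sketch of a parser matching the system prompt's contract (one backtick-wrapped JSON dict per reply; the function name is illustrative, not the repo's):
```
import json
import re

def parse_action_sketch(input_string: str) -> dict:
    # Prefer the first backtick-fenced {...} block; fall back to the raw string.
    match = re.search(r"`+\s*(\{.*?\})\s*`+", input_string, re.DOTALL)
    payload = match.group(1) if match else input_string
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        raise ValueError("Invalid response format: " + input_string)
```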


class GPT4v_Agent:
    def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300):
        self.instruction = instruction
@@ -181,7 +187,7 @@ class GPT4v_Agent:
        ]

    def predict(self, obs):
        obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Alphabet", anno_mode=["Mask", "Mark"])
        obs = inference(obs, slider=2.0, mode="Automatic", alpha=0.1, label_mode="Number", anno_mode=["Mark", "Box"])
        base64_image = encode_image(obs)
        self.trajectory.append({
            "role": "user",
@@ -274,4 +280,5 @@ if __name__ == '__main__':
    api_key = os.environ.get("OPENAI_API_KEY")

    agent = GPT4v_Agent(api_key=api_key, instruction="Open Google Sheet")
    print(agent.predict(obs="stackoverflow.png"))
    obs = PIL.Image.open('stackoverflow.png')
    print(agent.predict(obs=obs))
19
mm_agents/gemini_test.py
Normal file
@@ -0,0 +1,19 @@
import os

import PIL.Image
import google.generativeai as genai

# Read the key from the environment rather than hard-coding a secret in source.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))

# for m in genai.list_models():
#     if 'generateContent' in m.supported_generation_methods:
#         print(m.name)

model = genai.GenerativeModel('gemini-pro-vision')

img = PIL.Image.open('image.jpg')

messages = [
    {'role': 'user',
     'parts': ["Explain this image.", img]}
]

response = model.generate_content(messages)
print(response.text)
0
mm_agents/task_adapter/sam/__init__.py
Normal file
2
mm_agents/task_adapter/sam/tasks/__Init__.py
Normal file
@@ -0,0 +1,2 @@
from .inference_sam_m2m_auto import *
from .inference_sam_m2m_interactive import *
103
mm_agents/task_adapter/sam/tasks/inference_sam_m2m_auto.py
Normal file
@@ -0,0 +1,103 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from segment_anything import SamAutomaticMaskGenerator

metadata = MetadataCatalog.get('coco_2017_train_panoptic')


def inference_sam_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image)
    image_ori = np.asarray(image_ori)

    mask_generator = SamAutomaticMaskGenerator(model)
    outputs = mask_generator.generate(image_ori)

    from task_adapter.utils.visualizer import Visualizer
    visual = Visualizer(image_ori, metadata=metadata)
    sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
    label = 1
    # for ann in sorted_anns:
    #     mask = ann['segmentation']
    #     color_mask = np.random.random((1, 3)).tolist()[0]
    #     # color_mask = [int(c*255) for c in color_mask]
    #     demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
    #     label += 1
    # im = demo.get_image()

    mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
    for i, ann in enumerate(sorted_anns):
        mask = ann['segmentation']
        color_mask = np.random.random((1, 3)).tolist()[0]
        # color_mask = [int(c*255) for c in color_mask]
        demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
        # assign the mask to the mask_map
        mask_map[mask == 1] = label
        label += 1
    im = demo.get_image()
    # fig=plt.figure(figsize=(10, 10))
    # plt.imshow(image_ori)
    # show_anns(outputs)
    # fig.canvas.draw()
    # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    return im, sorted_anns
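As a usage sketch of the Set-of-Mark flow this file implements (checkpoint path and variable names are assumptions, not from the diff):
```
from PIL import Image
from segment_anything import sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth").eval().cuda()
screenshot = Image.open("screenshot.png").convert("RGB")

# `annotated` carries the numbered marks the agent refers to in its CLICK
# actions; `anns` holds the corresponding masks sorted by area.
annotated, anns = inference_sam_m2m_auto(
    sam, screenshot, text_size=1280, label_mode='1', anno_mode=["Mark", "Box"]
)
```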


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True
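The `correct_holes ^ mask` XOR is the core trick: in "holes" mode the mask is inverted so that holes become foreground components that cv2.connectedComponentsWithStats can measure, while in "islands" mode the mask is analyzed as-is. A tiny worked example (the array is made up for illustration):
```
import numpy as np

mask = np.array([[1, 1, 1],
                 [1, 0, 1],   # one-pixel hole in the center
                 [1, 1, 1]], dtype=bool)

# The hole's area (1) is below area_thresh (2), so it gets filled in.
filled, changed = remove_small_regions(mask, area_thresh=2, mode="holes")
assert filled.all() and changed
```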

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))
221
mm_agents/task_adapter/sam/tasks/inference_sam_m2m_interactive.py
Normal file
@@ -0,0 +1,221 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
from kornia.contrib import distance_transform
import matplotlib.pyplot as plt
import cv2
import io
metadata = MetadataCatalog.get('coco_2017_train_panoptic')

from segment_anything import SamAutomaticMaskGenerator
from segment_anything.utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


def sam_interactive_mask(mask_generator, points, in_points, in_labels, mask_input):
    masks, iou_preds, _ = mask_generator.predictor.predict_torch(
        in_points,
        in_labels,
        mask_input=mask_input,
        multimask_output=True,
        return_logits=True,
    )
    nm, _, h, w = masks.shape

    # Serialize predictions and store in MaskData
    data = MaskData(
        masks=masks.flatten(0, 1),
        iou_preds=iou_preds.flatten(0, 1),
        points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
    )
    del masks

    # Calculate stability score
    data["stability_score"] = calculate_stability_score(
        data["masks"], mask_generator.predictor.model.mask_threshold, mask_generator.stability_score_offset
    )

    masks = data["masks"].reshape(nm, -1, h, w)
    scores = (data['iou_preds'] + data['stability_score']).reshape(nm, -1)

    index = torch.stack([torch.arange(nm).cuda(), scores.argmax(dim=1)]).tolist()
    return masks[index]
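For each of the nm prompts SAM returns several candidate masks; the final two lines keep the candidate whose predicted IoU plus stability score is highest. The same indexing pattern in isolation (toy shapes, assumed for illustration):
```
import torch

nm, k = 4, 3                     # 4 prompts, 3 candidates each
scores = torch.rand(nm, k)       # iou_preds + stability_score
masks = torch.rand(nm, k, 8, 8)  # candidate masks at toy resolution

# Pair each prompt index with its best candidate index, then use the
# pair for advanced indexing: one mask per prompt survives.
index = torch.stack([torch.arange(nm), scores.argmax(dim=1)]).tolist()
best = masks[index]
assert best.shape == (nm, 8, 8)
```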


def inference_sam_m2m_interactive(model, image, spatial_masks, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image)

    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()

    orig_size = images.shape[-2:]
    orig_h, orig_w = orig_size
    crop_box = [0,0,orig_w,orig_h]

    spatial_masks = spatial_masks[:, None].float().cuda()
    spatial_masks = F.interpolate(spatial_masks, size=(orig_h, orig_w), mode='bicubic', align_corners=False) > 0

    # generate single center point
    # n,_,h,w = spatial_masks.shape
    # mask_dt = (distance_transform((~F.pad(spatial_masks, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
    # max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
    # next_mask = torch.zeros(spatial_masks.shape, device=torch.cuda.current_device()).bool()
    # next_mask = next_mask.view(n,-1)
    # next_mask[max_xy_idx] = True
    # next_mask = next_mask.reshape((n,1,h,w))
    # points = next_mask.nonzero()[:,2:].flip(dims=[1]).cpu().numpy()

    # stack sampled points
    acc_points = []
    for i in range(len(spatial_masks)):
        points = spatial_masks[i:i+1].nonzero()[:,2:].flip(dims=[1]).cpu().numpy()
        rand_ids = np.random.choice(points.shape[0], size=40, replace=True)
        points = points[rand_ids]
        acc_points.append(points)
    _np = len(acc_points)
    points = np.concatenate(acc_points)

    mask_generator = SamAutomaticMaskGenerator(model)
    mask_generator.predictor.set_image(image_ori)
    im_size = image_ori.shape[:-1]

    transformed_points = mask_generator.predictor.transform.apply_coords(points, im_size)
    in_points = torch.as_tensor(transformed_points, device=mask_generator.predictor.device).reshape(_np,-1,2).transpose(0,1)
    in_labels = torch.ones((in_points.shape[0], _np), dtype=torch.int, device=mask_generator.predictor.device)

    masks = sam_interactive_mask(mask_generator, points, in_points.transpose(0,1), in_labels.transpose(0,1), None)

    masks = masks > 0.0
    iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
    points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)

    mask_data = MaskData(
        masks=masks,
        iou_preds=iou_preds,
        points=points,
    )

    mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
    del masks

    mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
    mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])

    # Compress to RLE
    mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
    mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
    del mask_data["masks"]
    mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]

    # Write mask records
    outputs = []
    for idx in range(len(mask_data["segmentations"])):
        ann = {
            "segmentation": mask_data["segmentations"][idx],
            "area": area_from_rle(mask_data["rles"][idx]),
            "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
            "predicted_iou": mask_data["iou_preds"][idx].item(),
            "point_coords": [mask_data["points"][idx].tolist()],
            "stability_score": mask_data["stability_score"][idx].item(),
            "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
        }
        outputs.append(ann)

    from task_adapter.utils.visualizer import Visualizer
    visual = Visualizer(image_ori, metadata=metadata)
    sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
    label = 1
    # for ann in sorted_anns:
    #     mask = ann['segmentation']
    #     demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
    #     label += 1
    # im = demo.get_image()

    mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
    for i, ann in enumerate(sorted_anns):
        mask = ann['segmentation']
        color_mask = np.random.random((1, 3)).tolist()[0]
        # color_mask = [int(c*255) for c in color_mask]
        demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
        # assign the mask to the mask_map
        mask_map[mask == 1] = label
        label += 1
    im = demo.get_image()
    # fig=plt.figure(figsize=(10, 10))
    # plt.imshow(image_ori)
    # show_anns(outputs)
    # fig.canvas.draw()
    # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    return im, sorted_anns


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))
0
mm_agents/task_adapter/seem/__init__.py
Executable file
3
mm_agents/task_adapter/seem/tasks/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .interactive_seem_m2m_auto import *
from .inference_seem_pano import *
from .inference_seem_interactive import *
382
mm_agents/task_adapter/seem/tasks/automatic_mask_generator.py
Normal file
@@ -0,0 +1,382 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

from typing import Any, Dict, List, Optional, Tuple

from segment_anything.modeling import Sam
from segment_anything.utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


class SeemAutomaticMaskGenerator:
    def __init__(
        self,
        model: Sam,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 64,
        pred_iou_thresh: float = 0.9,
        stability_score_thresh: float = 0.5,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 0,
        output_mode: str = "binary_mask",
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
        """

        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = model
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode

        # dilate conv
        self.dilation = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3, bias=False)
        self.dilation.weight.data.fill_(1.0)
        self.dilation.cuda()

    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
          list(dict(str, any)): A list over records for masks. Each record is
            a dict containing the following keys:
              segmentation (dict(str, any) or np.ndarray): The mask. If
                output_mode='binary_mask', is an array of shape HW. Otherwise,
                is a dictionary containing the RLE.
              bbox (list(float)): The box around the mask, in XYWH format.
              area (int): The area in pixels of the mask.
              predicted_iou (float): The model's own prediction of the mask's
                quality. This is filtered by the pred_iou_thresh parameter.
              point_coords (list(list(float))): The point coordinates input
                to the model to generate this mask.
              stability_score (float): A measure of the mask's quality. This
                is filtered on using the stability_score_thresh parameter.
              crop_box (list(float)): The crop of the image used to generate
                the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )
        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            curr_anns.append(ann)

        return curr_anns
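A usage sketch of the generator, mirroring how the task adapters in this commit drive it (the model and image variables are assumed to be prepared as elsewhere in this diff; note this fork passes the CHW CUDA tensor the adapters build, despite the docstring's np.ndarray wording):
```
# `model_seem` as built in SoM_agent; `images` is the CHW CUDA tensor the
# task adapters construct from the resized screenshot.
mask_generator = SeemAutomaticMaskGenerator(model_seem, pred_iou_thresh=0.9)
anns = mask_generator.generate(images)
anns = sorted(anns, key=lambda a: a["area"], reverse=True)
for a in anns[:3]:
    print(a["bbox"], a["area"], a["predicted_iou"])
```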

    def _generate_masks(self, image: np.ndarray) -> MaskData:
        orig_size = image.shape[-2:]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros_like(data["boxes"][:, 0]),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data

    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        # Crop the image and calculate embeddings
        x0, y0, x1, y1 = crop_box
        cropped_im = image  # [y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[-2:]
        # self.predictor.set_image(cropped_im)

        # Get points for this crop
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx]  # * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        self.enc_features = None

        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(cropped_im, points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data

        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros(len(data["boxes"])),  # categories
            iou_threshold=self.box_nms_thresh,
        )

        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data

    def _process_batch(
        self,
        images,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        orig_h, orig_w = orig_size

        data = {"image": images, "height": orig_h, "width": orig_w}
        points = torch.tensor(points, dtype=torch.float).to(images.device)

        # prepare interactive mask for seem
        abs_points = (points * torch.tensor(orig_size)[None,:].to(points.device)).long()
        abs_masks = torch.zeros((len(points), orig_h, orig_w), dtype=torch.bool).to(device=points.device)
        abs_masks[torch.arange(0, abs_points.size(0))[:,None], abs_points[:,0:1], abs_points[:,1:2]] = True
        abs_masks = self.dilation(abs_masks[:,None].float())[:,0] > 0
        data['spatial_query'] = {'rand_shape': abs_masks[:,None]}

        batch_inputs = [data]
        if self.enc_features is None:
            masks, iou_preds, mask_features, transformer_encoder_features, multi_scale_features = self.predictor.model.evaluate_demo(batch_inputs, None, None, return_features=True)
            self.enc_features = (mask_features, transformer_encoder_features, multi_scale_features)
        else:
            masks, iou_preds = self.predictor.model.evaluate_demo(batch_inputs, self.enc_features[0], self.enc_features[1], self.enc_features[2])

        data = MaskData(
            masks=masks,
            iou_preds=iou_preds,
            points=points,
        )
        del masks
        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], 0.0, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > 0.0
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data

    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros_like(boxes[:, 0]),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data
169
mm_agents/task_adapter/seem/tasks/inference_seem_interactive.py
Normal file
@@ -0,0 +1,169 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SeemAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')

from segment_anything.utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


def inference_seem_interactive(model, image, spatial_masks, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image)

    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()

    orig_size = images.shape[-2:]
    orig_h, orig_w = orig_size
    crop_box = [0,0,orig_w,orig_h]

    data = {"image": images, "height": orig_h, "width": orig_w}

    spatial_masks = spatial_masks[:, None].float().cuda()
    spatial_masks = F.interpolate(spatial_masks, size=(orig_h, orig_w), mode='bicubic', align_corners=False) > 0
    data['spatial_query'] = {'rand_shape': spatial_masks}

    model.model.metadata = metadata
    masks, _ = model.model.evaluate_demo([data])
    masks = masks > 0.0
    iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
    points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)

    mask_data = MaskData(
        masks=masks,
        iou_preds=iou_preds,
        points=points,
    )

    mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
    del masks

    mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
    mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])

    # Compress to RLE
    mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
    mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
    del mask_data["masks"]
    mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]

    # Write mask records
    outputs = []
    for idx in range(len(mask_data["segmentations"])):
        ann = {
            "segmentation": mask_data["segmentations"][idx],
            "area": area_from_rle(mask_data["rles"][idx]),
            "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
            "predicted_iou": mask_data["iou_preds"][idx].item(),
            "point_coords": [mask_data["points"][idx].tolist()],
            "stability_score": mask_data["stability_score"][idx].item(),
            "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
        }
        outputs.append(ann)

    from task_adapter.utils.visualizer import Visualizer
    visual = Visualizer(image_ori, metadata=metadata)
    sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
    label = 1
    # for ann in sorted_anns:
    #     mask = ann['segmentation']
    #     color_mask = np.random.random((1, 3)).tolist()[0]
    #     # color_mask = [int(c*255) for c in color_mask]
    #     demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
    #     label += 1
    # im = demo.get_image()

    mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
    for i, ann in enumerate(sorted_anns):
        mask = ann['segmentation']
        color_mask = np.random.random((1, 3)).tolist()[0]
        # color_mask = [int(c*255) for c in color_mask]
        demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
        # assign the mask to the mask_map
        mask_map[mask == 1] = label
        label += 1
    im = demo.get_image()
    # fig=plt.figure(figsize=(10, 10))
    # plt.imshow(image_ori)
    # show_anns(outputs)
    # fig.canvas.draw()
    # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    return im, sorted_anns


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))
164
mm_agents/task_adapter/seem/tasks/inference_seem_pano.py
Normal file
@@ -0,0 +1,164 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SeemAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')

from segment_anything.utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


def inference_seem_pano(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image)

    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()

    orig_size = images.shape[-2:]
    orig_h, orig_w = orig_size
    crop_box = [0,0,orig_w,orig_h]

    data = {"image": images, "height": orig_h, "width": orig_w}
    batch_inputs = [data]

    model.model.metadata = metadata
    outputs = model.model.evaluate(batch_inputs)

    pano_mask = outputs[0]['panoptic_seg'][0]
    pano_info = outputs[0]['panoptic_seg'][1]

    masks = []
    for seg_info in pano_info:
        masks += [pano_mask == seg_info['id']]
    masks = torch.stack(masks, dim=0)
    iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
    points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)

    mask_data = MaskData(
        masks=masks,
        iou_preds=iou_preds,
        points=points,
    )
    mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
    del masks

    mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
    mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])

    # Compress to RLE
    mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
    mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
    del mask_data["masks"]
    mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]

    # Write mask records
    outputs = []
    for idx in range(len(mask_data["segmentations"])):
        ann = {
            "segmentation": mask_data["segmentations"][idx],
            "area": area_from_rle(mask_data["rles"][idx]),
            "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
            "predicted_iou": mask_data["iou_preds"][idx].item(),
            "point_coords": [mask_data["points"][idx].tolist()],
            "stability_score": mask_data["stability_score"][idx].item(),
            "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
        }
        outputs.append(ann)

    from task_adapter.utils.visualizer import Visualizer
    visual = Visualizer(image_ori, metadata=metadata)
    # create a full zero image as the image_orig
    sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
    label = 1
    mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
    for i, ann in enumerate(sorted_anns):
        mask = ann['segmentation']
        color_mask = np.random.random((1, 3)).tolist()[0]
        # color_mask = [int(c*255) for c in color_mask]
        demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
        # assign the mask to the mask_map
        mask_map[mask == 1] = label
        label += 1
    im = demo.get_image()
    # fig=plt.figure(figsize=(10, 10))
    # plt.imshow(image_ori)
    # show_anns(outputs)
    # fig.canvas.draw()
    # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    return im, sorted_anns


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))
93
mm_agents/task_adapter/seem/tasks/interactive_seem_m2m_auto.py
Normal file
@@ -0,0 +1,93 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SeemAutomaticMaskGenerator
metadata = MetadataCatalog.get('coco_2017_train_panoptic')


def interactive_seem_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image)

    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()

    mask_generator = SeemAutomaticMaskGenerator(model)
    outputs = mask_generator.generate(images)

    from task_adapter.utils.visualizer import Visualizer
    visual = Visualizer(image_ori, metadata=metadata)
    sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
    label = 1
    for ann in sorted_anns:
        mask = ann['segmentation']
        color_mask = np.random.random((1, 3)).tolist()[0]
        # color_mask = [int(c*255) for c in color_mask]
        demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
        label += 1
    im = demo.get_image()

    # fig=plt.figure(figsize=(10, 10))
    # plt.imshow(image_ori)
    # show_anns(outputs)
    # fig.canvas.draw()
    # im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    return im


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of if the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True


def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))
6
mm_agents/task_adapter/semantic_sam/tasks/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .interactive_idino_m2m import interactive_infer_image as interactive_infer_image_idino_m2m
from .interactive_idino_m2m import interactive_infer_image_semantic, interactive_infer_image_3l
from .inference_semsam_m2m_auto import inference_semsam_m2m_auto
from .interactive_idino_1o1_box import interactive_infer_image_box as interactive_infer_image_idino_m2m_box
from .automatic_mask_generator import prompt_switch
from .interactive_predictor import SemanticSAMPredictor
393
mm_agents/task_adapter/semantic_sam/tasks/automatic_mask_generator.py
Normal file
@@ -0,0 +1,393 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area  # type: ignore

from typing import Any, Dict, List, Optional, Tuple
# from
# from .modeling import Sam
# from .predictor import SamPredictor
from semantic_sam.utils.sam_utils.amg import (
    MaskData,
    area_from_rle,
    batch_iterator,
    batched_mask_to_box,
    box_xyxy_to_xywh,
    build_all_layer_point_grids,
    calculate_stability_score,
    coco_encode_rle,
    generate_crop_boxes,
    is_box_near_crop_edge,
    mask_to_rle_pytorch,
    remove_small_regions,
    rle_to_mask,
    uncrop_boxes_xyxy,
    uncrop_masks,
    uncrop_points,
)


def prompt_switch(p):
    p = int(p)
    if p == 1:
        return 3
    if p == 2:
        return 2
    if p == 3:
        return 0
    if p == 4:
        return 4
    if p == 5:
        return 1
    if p == 6:
        return 5
    else:
        raise NotImplementedError
|
||||
|
||||
|
||||
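
# Editor's note (not in the original diff): prompt_switch translates the 1-6
# granularity levels exposed to the caller into the internal level indices
# that Semantic-SAM's evaluate_demo expects, e.g. prompt_switch(1) == 3 and
# prompt_switch(6) == 5; anything outside 1-6 raises NotImplementedError.
# SemanticSamAutomaticMaskGenerator below applies it to every entry of its
# `level` argument, so level=[1, 2, 3] becomes self.level == [3, 2, 0].
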
class SemanticSamAutomaticMaskGenerator:
    def __init__(
        self,
        model,
        points_per_side: Optional[int] = 32,
        points_per_batch: int = 200,
        pred_iou_thresh: float = 0.88,
        stability_score_thresh: float = 0.92,
        stability_score_offset: float = 1.0,
        box_nms_thresh: float = 0.7,
        crop_n_layers: int = 0,
        crop_nms_thresh: float = 0.7,
        crop_overlap_ratio: float = 512 / 1500,
        crop_n_points_downscale_factor: int = 1,
        point_grids: Optional[List[np.ndarray]] = None,
        min_mask_region_area: int = 10,
        output_mode: str = "binary_mask",
        level: list = [1, 2, 3, 4, 5, 6],
    ) -> None:
        """
        Using a SAM model, generates masks for the entire image.
        Generates a grid of point prompts over the image, then filters
        low quality and duplicate masks. The default settings are chosen
        for SAM with a ViT-H backbone.

        Arguments:
          model (Sam): The SAM model to use for mask prediction.
          points_per_side (int or None): The number of points to be sampled
            along one side of the image. The total number of points is
            points_per_side**2. If None, 'point_grids' must provide explicit
            point sampling.
          points_per_batch (int): Sets the number of points run simultaneously
            by the model. Higher numbers may be faster but use more GPU memory.
          pred_iou_thresh (float): A filtering threshold in [0,1], using the
            model's predicted mask quality.
          stability_score_thresh (float): A filtering threshold in [0,1], using
            the stability of the mask under changes to the cutoff used to binarize
            the model's mask predictions.
          stability_score_offset (float): The amount to shift the cutoff when
            calculating the stability score.
          box_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks.
          crop_n_layers (int): If >0, mask prediction will be run again on
            crops of the image. Sets the number of layers to run, where each
            layer has 2**i_layer number of image crops.
          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
            suppression to filter duplicate masks between different crops.
          crop_overlap_ratio (float): Sets the degree to which crops overlap.
            In the first crop layer, crops will overlap by this fraction of
            the image length. Later layers with more crops scale down this overlap.
          crop_n_points_downscale_factor (int): The number of points-per-side
            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
          point_grids (list(np.ndarray) or None): A list over explicit grids
            of points used for sampling, normalized to [0,1]. The nth grid in the
            list is used in the nth crop layer. Exclusive with points_per_side.
          min_mask_region_area (int): If >0, postprocessing will be applied
            to remove disconnected regions and holes in masks with area smaller
            than min_mask_region_area. Requires opencv.
          output_mode (str): The form masks are returned in. Can be 'binary_mask',
            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
            For large resolutions, 'binary_mask' may consume large amounts of
            memory.
          level (list): The 1-6 granularity levels to prompt for; each entry
            is mapped to an internal level index via prompt_switch.
        """
        self.level = [prompt_switch(l) for l in level]
        assert (points_per_side is None) != (
            point_grids is None
        ), "Exactly one of points_per_side or point_grid must be provided."
        if points_per_side is not None:
            self.point_grids = build_all_layer_point_grids(
                points_per_side,
                crop_n_layers,
                crop_n_points_downscale_factor,
            )
        elif point_grids is not None:
            self.point_grids = point_grids
        else:
            raise ValueError("Can't have both points_per_side and point_grid be None.")

        assert output_mode in [
            "binary_mask",
            "uncompressed_rle",
            "coco_rle",
        ], f"Unknown output_mode {output_mode}."
        if output_mode == "coco_rle":
            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401

        if min_mask_region_area > 0:
            import cv2  # type: ignore # noqa: F401

        self.predictor = model
        self.points_per_batch = points_per_batch
        self.pred_iou_thresh = pred_iou_thresh
        self.stability_score_thresh = stability_score_thresh
        self.stability_score_offset = stability_score_offset
        self.box_nms_thresh = box_nms_thresh
        self.crop_n_layers = crop_n_layers
        self.crop_nms_thresh = crop_nms_thresh
        self.crop_overlap_ratio = crop_overlap_ratio
        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        self.min_mask_region_area = min_mask_region_area
        self.output_mode = output_mode
    @torch.no_grad()
    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Generates masks for the given image.

        Arguments:
          image (np.ndarray): The image to generate masks for, in HWC uint8 format.

        Returns:
          list(dict(str, any)): A list over records for masks. Each record is
            a dict containing the following keys:
              segmentation (dict(str, any) or np.ndarray): The mask. If
                output_mode='binary_mask', is an array of shape HW. Otherwise,
                is a dictionary containing the RLE.
              bbox (list(float)): The box around the mask, in XYWH format.
              area (int): The area in pixels of the mask.
              predicted_iou (float): The model's own prediction of the mask's
                quality. This is filtered by the pred_iou_thresh parameter.
              point_coords (list(list(float))): The point coordinates input
                to the model to generate this mask.
              stability_score (float): A measure of the mask's quality. This
                is filtered using the stability_score_thresh parameter.
              crop_box (list(float)): The crop of the image used to generate
                the mask, given in XYWH format.
        """

        # Generate masks
        mask_data = self._generate_masks(image)

        # Filter small disconnected regions and holes in masks
        if self.min_mask_region_area > 0:
            mask_data = self.postprocess_small_regions(
                mask_data,
                self.min_mask_region_area,
                max(self.box_nms_thresh, self.crop_nms_thresh),
            )

        # Encode masks
        if self.output_mode == "coco_rle":
            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
        elif self.output_mode == "binary_mask":
            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
        else:
            mask_data["segmentations"] = mask_data["rles"]

        # Write mask records
        curr_anns = []
        for idx in range(len(mask_data["segmentations"])):
            ann = {
                "segmentation": mask_data["segmentations"][idx],
                "area": area_from_rle(mask_data["rles"][idx]),
                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
                "predicted_iou": mask_data["iou_preds"][idx].item(),
                "point_coords": [mask_data["points"][idx].tolist()],
                "stability_score": mask_data["stability_score"][idx].item(),
                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
            }
            curr_anns.append(ann)

        return curr_anns
    def _generate_masks(self, image: np.ndarray) -> MaskData:
        orig_size = image.shape[-2:]
        crop_boxes, layer_idxs = generate_crop_boxes(
            orig_size, self.crop_n_layers, self.crop_overlap_ratio
        )

        # Iterate over image crops (with crop_n_layers=0 there is exactly one)
        assert len(crop_boxes) == 1
        data = MaskData()
        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
            data.cat(crop_data)

        # Remove duplicate masks between crops
        if len(crop_boxes) > 1:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)
            keep_by_nms = batched_nms(
                data["boxes"].float(),
                scores,
                torch.zeros(len(data["boxes"])),  # categories
                iou_threshold=self.crop_nms_thresh,
            )
            data.filter(keep_by_nms)

        data.to_numpy()
        return data
    def _process_crop(
        self,
        image: np.ndarray,
        crop_box: List[int],
        crop_layer_idx: int,
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        # Crop the image and calculate embeddings. Unlike the original SAM
        # generator, the crop here is always the full image.
        x0, y0, x1, y1 = crop_box
        cropped_im = image  # [y0:y1, x0:x1, :]
        cropped_im_size = cropped_im.shape[-2:]
        # self.predictor.set_image(cropped_im)

        # Get points for this crop; the grid is kept normalized to [0, 1]
        points_scale = np.array(cropped_im_size)[None, ::-1]
        points_for_image = self.point_grids[crop_layer_idx]  # * points_scale

        # Generate masks for this crop in batches
        data = MaskData()
        self.enc_features = None
        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
            batch_data = self._process_batch(cropped_im, points, cropped_im_size, crop_box, orig_size)
            data.cat(batch_data)
            del batch_data

        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros(len(data["boxes"])),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

        # Return to the original image frame
        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])

        return data
    def _process_batch(
        self,
        images,
        points: np.ndarray,
        im_size: Tuple[int, ...],
        crop_box: List[int],
        orig_size: Tuple[int, ...],
    ) -> MaskData:
        orig_h, orig_w = orig_size

        data = {"image": images, "height": orig_h, "width": orig_w}
        points = torch.tensor(points, dtype=torch.float).to(images.device)
        # Append a small default box size to every point prompt
        points = torch.cat([points, points.new_tensor([[0.005, 0.005]]).repeat(len(points), 1)], dim=-1)
        data['targets'] = [dict()]
        data['targets'][0]['points'] = points
        data['targets'][0]['pb'] = points.new_tensor([0.] * len(points))
        batch_inputs = [data]
        # Encoder features are computed once per image and cached for later batches
        if self.enc_features is None:
            masks, iou_preds, mask_features, multi_scale_features = self.predictor.model.evaluate_demo(
                batch_inputs, None, None, return_features=True, level=self.level)
            self.enc_features = (mask_features, multi_scale_features)
        else:
            masks, iou_preds = self.predictor.model.evaluate_demo(
                batch_inputs, None, None, self.enc_features[0], self.enc_features[1], level=self.level)

        data = MaskData(
            masks=masks,
            iou_preds=iou_preds.flatten(),
            points=torch.as_tensor(points[:, None].repeat(1, len(self.level), 1).view(-1, 4)),
        )
        del masks
        # Filter by predicted IoU
        keep_mask = data["iou_preds"] > self.pred_iou_thresh
        data.filter(keep_mask)

        # Calculate stability score
        data["stability_score"] = calculate_stability_score(
            data["masks"], 0.0, self.stability_score_offset
        )
        # if self.stability_score_thresh > 0.0:
        keep_mask = data["stability_score"] >= self.stability_score_thresh
        data.filter(keep_mask)

        # Threshold masks and calculate boxes
        data["masks"] = data["masks"] > 0.0
        data["boxes"] = batched_mask_to_box(data["masks"])

        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

        # Compress to RLE
        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
        data["rles"] = mask_to_rle_pytorch(data["masks"])
        del data["masks"]

        return data
    @staticmethod
    def postprocess_small_regions(
        mask_data: MaskData, min_area: int, nms_thresh: float
    ) -> MaskData:
        """
        Removes small disconnected regions and holes in masks, then reruns
        box NMS to remove any new duplicates.

        Edits mask_data in place.

        Requires open-cv as a dependency.
        """
        if len(mask_data["rles"]) == 0:
            return mask_data

        # Filter small disconnected regions and holes
        new_masks = []
        scores = []
        for rle in mask_data["rles"]:
            mask = rle_to_mask(rle)

            mask, changed = remove_small_regions(mask, min_area, mode="holes")
            unchanged = not changed
            mask, changed = remove_small_regions(mask, min_area, mode="islands")
            unchanged = unchanged and not changed

            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
            # Give score=0 to changed masks and score=1 to unchanged masks
            # so NMS will prefer ones that didn't need postprocessing
            scores.append(float(unchanged))

        # Recalculate boxes and remove any new duplicates
        masks = torch.cat(new_masks, dim=0)
        boxes = batched_mask_to_box(masks)
        keep_by_nms = batched_nms(
            boxes.float(),
            torch.as_tensor(scores),
            torch.zeros(len(boxes)),  # categories
            iou_threshold=nms_thresh,
        )

        # Only recalculate RLEs for masks that have changed
        for i_mask in keep_by_nms:
            if scores[i_mask] == 0.0:
                mask_torch = masks[i_mask].unsqueeze(0)
                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
        mask_data.filter(keep_by_nms)

        return mask_data
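
# Editor's sketch (not part of the original diff): a minimal driver for the
# generator above, mirroring how inference_semsam_m2m_auto in this commit
# uses it. `model` is assumed to be a loaded Semantic-SAM BaseModel; note
# that despite the np.ndarray annotation, generate() is fed a CHW CUDA
# tensor elsewhere in this commit, so the same is done here.
def _example_generate(model, pil_image):
    image = torch.from_numpy(np.asarray(pil_image).copy()).permute(2, 0, 1).cuda()
    generator = SemanticSamAutomaticMaskGenerator(model, level=[1, 2, 3, 4, 5, 6])
    anns = generator.generate(image)
    # Each record carries 'segmentation', 'area', 'bbox', 'predicted_iou', ...;
    # sort by area so larger masks can be drawn first.
    return sorted(anns, key=lambda a: a["area"], reverse=True)
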
108
mm_agents/task_adapter/semantic_sam/tasks/inference_semsam_m2m_auto.py
Normal file
@@ -0,0 +1,108 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import matplotlib.pyplot as plt
import cv2
import io
from .automatic_mask_generator import SemanticSamAutomaticMaskGenerator

metadata = MetadataCatalog.get('coco_2017_train_panoptic')


def inference_semsam_m2m_auto(model, image, level, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image)

    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()

    mask_generator = SemanticSamAutomaticMaskGenerator(
        model,
        points_per_side=32,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.92,
        min_mask_region_area=10,
        level=level,
    )
    outputs = mask_generator.generate(images)

    visual = Visualizer(image_ori, metadata=metadata)
    sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
    label = 1

    # Draw each mask with its numeric label and record the label per pixel
    mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
    for i, ann in enumerate(sorted_anns):
        mask = ann['segmentation']
        color_mask = np.random.random((1, 3)).tolist()[0]
        # color_mask = [int(c*255) for c in color_mask]
        demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
        # assign the mask to the mask_map
        mask_map[mask == 1] = label
        label += 1
    im = demo.get_image()
    # fig = plt.figure(figsize=(10, 10))
    # plt.imshow(image_ori)
    # show_anns(outputs)
    # fig.canvas.draw()
    # im = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    return im, sorted_anns
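
# Editor's sketch (not part of the original diff): how an agent-side caller
# would invoke the function above on a screenshot to obtain the numbered
# Set-of-Mark overlay. All argument values are illustrative; text_size=640
# matches the resize default used by SemanticSAMPredictor in this commit.
def _example_som_overlay(model, screenshot_pil):
    im, anns = inference_semsam_m2m_auto(
        model, screenshot_pil, level=[1, 2, 3, 4, 5, 6],
        all_classes=None, all_parts=None, thresh=0.5, text_size=640,
        hole_scale=100, island_scale=100, semantic=False,
        label_mode='1', alpha=0.1, anno_mode=['Mask'])
    # im is the labeled screenshot; anns maps each numeric label (1-based,
    # in drawing order) back to its mask record.
    return im, anns
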
def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of whether the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:, :, i] = color_mask[i]
        ax.imshow(np.dstack((img, m * 0.35)))
144
mm_agents/task_adapter/semantic_sam/tasks/interactive_idino_1o1_box.py
Normal file
@@ -0,0 +1,144 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
from detectron2.structures import BitMasks
from semantic_sam.utils import box_ops

metadata = MetadataCatalog.get('coco_2017_train_panoptic')


def interactive_infer_image_box(model, image, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image['image'])
    mask_ori = transform1(image['mask'])
    width = image_ori.size[0]
    height = image_ori.size[1]
    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()
    all_classes, all_parts = all_classes.strip().strip("\"[]").split(':'), all_parts.strip().strip("\"[]").split(':')

    data = {"image": images, "height": height, "width": width}

    mask_ori = np.asarray(mask_ori)[:, :, 0:1].copy()
    mask_ori = torch.from_numpy(mask_ori).permute(2, 0, 1)[0]
    flaten_mask = mask_ori.unsqueeze(0)
    points = mask_ori.nonzero().float().to(images.device)
    if len(points) == 0:
        # note: box_xyxy below is only defined when the scribble mask is
        # non-empty, so draw_box cannot be reached through this fallback
        point_ = point = points.new_tensor([[0.5, 0.5, 0.5, 0.5]])
    else:
        mean_point = points.mean(0)[None]
        box_xyxy = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor
        h = mask_ori.shape[0]
        w = mask_ori.shape[1]
        # normalize the box to cxcywh in [0, 1]
        box_xywh = (box_ops.box_xyxy_to_cxcywh(box_xyxy) / torch.as_tensor([w, h, w, h])).cuda()

        # point_ = points.mean(0)[None]
        # point = point_.clone()
        # point[0, 0] = point_[0, 0] / mask_ori.shape[0]
        # point[0, 1] = point_[0, 1] / mask_ori.shape[1]
        # point = point[:, [1, 0]]
        point = box_xywh
    data['targets'] = [dict()]
    data['targets'][0]['points'] = point
    data['targets'][0]['pb'] = point.new_tensor([1.])

    batch_inputs = [data]
    masks, ious = model.model.evaluate_demo(batch_inputs, all_classes, all_parts, task='demo_box')

    pred_masks_poses = masks
    reses = []
    ious = ious[0, 0]
    ids = torch.argsort(ious, descending=True)

    text_res = ''
    try:
        thresh = float(thresh)
    except Exception:
        thresh = 0.0
    mask_ls = []
    ious_res = []
    areas = []
    for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
        iou = round(float(iou), 2)
        texts = f'{iou}'
        mask = (pred_masks_pos > 0.0).cpu().numpy()
        area = mask.sum()
        conti = False
        if iou < thresh:
            conti = True
        for m in mask_ls:
            if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
                conti = True
                break
        if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
            conti = False
        if conti:
            continue
        ious_res.append(iou)
        mask_ls.append(mask)
        areas.append(area)
        mask, _ = remove_small_regions(mask, int(hole_scale), mode="holes")
        mask, _ = remove_small_regions(mask, int(island_scale), mode="islands")
        mask = mask.astype(float)  # was np.float, which is removed in NumPy >= 1.24
        out_txt = texts
        visual = Visualizer(image_ori, metadata=metadata)
        color = [0., 0., 1.0]
        demo = visual.draw_binary_mask(mask, color=color, text=texts)
        demo = visual.draw_box(box_xyxy[0])
        res = demo.get_image()
        # point_x0 = max(0, int(point_[0, 1]) - 3)
        # point_x1 = min(mask_ori.shape[1], int(point_[0, 1]) + 3)
        # point_y0 = max(0, int(point_[0, 0]) - 3)
        # point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
        # res[point_y0:point_y1, point_x0:point_x1, 0] = 255
        # res[point_y0:point_y1, point_x0:point_x1, 1] = 0
        # res[point_y0:point_y1, point_x0:point_x1, 2] = 0
        reses.append(Image.fromarray(res))
        text_res = text_res + ';' + out_txt
    ids = list(torch.argsort(torch.tensor(areas), descending=False))
    ids = [int(i) for i in ids]

    torch.cuda.empty_cache()

    return reses, [reses[i] for i in ids]


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of whether the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True
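
# Editor's sketch (not part of the original diff): the expected input format
# for interactive_infer_image_box. `image` is a dict holding the RGB image
# and a same-size scribble mask whose nonzero pixels define the box prompt;
# the class/part strings follow the colon-separated convention parsed above.
# All argument values are illustrative.
def _example_box_prompt(model, rgb_pil, scribble_pil):
    image = {'image': rgb_pil, 'mask': scribble_pil}
    reses, reses_by_area = interactive_infer_image_box(
        model, image, all_classes="object", all_parts="part",
        thresh="0.0", text_size=640, hole_scale=100, island_scale=100,
        semantic=False)
    return reses
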
322
mm_agents/task_adapter/semantic_sam/tasks/interactive_idino_m2m.py
Normal file
@@ -0,0 +1,322 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------

import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog

metadata = MetadataCatalog.get('coco_2017_train_panoptic')


def interactive_infer_image(model, image, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image['image'])
    mask_ori = transform1(image['mask'])
    width = image_ori.size[0]
    height = image_ori.size[1]
    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()
    all_classes, all_parts = all_classes.strip().strip("\"[]").split(':'), all_parts.strip().strip("\"[]").split(':')

    data = {"image": images, "height": height, "width": width}

    mask_ori = np.asarray(mask_ori)[:, :, 0:1].copy()
    mask_ori = torch.from_numpy(mask_ori).permute(2, 0, 1)[0]
    points = mask_ori.nonzero().float().to(images.device)
    if len(points) == 0:
        point_ = point = points.new_tensor([[0.5, 0.5, 0.006, 0.006]])
    else:
        point_ = points.mean(0)[None]
        point = point_.clone()
        point[0, 0] = point_[0, 0] / mask_ori.shape[0]
        point[0, 1] = point_[0, 1] / mask_ori.shape[1]
        point = point[:, [1, 0]]
        point = torch.cat([point, points.new_tensor([[0.005, 0.005]])], dim=-1)
    data['targets'] = [dict()]
    data['targets'][0]['points'] = point
    data['targets'][0]['pb'] = point.new_tensor([0.])

    batch_inputs = [data]
    masks, ious = model.model.evaluate_demo(batch_inputs, all_classes, all_parts)

    pred_masks_poses = masks
    reses = []
    ious = ious[0, 0]
    ids = torch.argsort(ious, descending=True)

    text_res = ''
    try:
        thresh = float(thresh)
    except Exception:
        thresh = 0.0
    mask_ls = []
    ious_res = []
    areas = []
    label = 1  # note: `label` was undefined in the diff; number kept masks as in inference_semsam_m2m_auto
    for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
        iou = round(float(iou), 2)
        texts = f'{iou}'
        mask = (pred_masks_pos > 0.0).cpu().numpy()
        area = mask.sum()
        conti = False
        if iou < thresh:
            conti = True
        for m in mask_ls:
            if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
                conti = True
                break
        if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
            conti = False
        if conti:
            continue
        ious_res.append(iou)
        mask_ls.append(mask)
        areas.append(area)
        mask, _ = remove_small_regions(mask, int(hole_scale), mode="holes")
        mask, _ = remove_small_regions(mask, int(island_scale), mode="islands")
        mask = mask.astype(float)  # was np.float, which is removed in NumPy >= 1.24
        out_txt = texts
        visual = Visualizer(image_ori, metadata=metadata)
        color = [0., 0., 1.0]
        # demo = visual.draw_binary_mask(mask, color=color, text=texts)
        demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
        label += 1
        res = demo.get_image()
        point_x0 = max(0, int(point_[0, 1]) - 3)
        point_x1 = min(mask_ori.shape[1], int(point_[0, 1]) + 3)
        point_y0 = max(0, int(point_[0, 0]) - 3)
        point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
        # res[point_y0:point_y1, point_x0:point_x1, 0] = 255
        # res[point_y0:point_y1, point_x0:point_x1, 1] = 0
        # res[point_y0:point_y1, point_x0:point_x1, 2] = 0
        reses.append(Image.fromarray(res))
        text_res = text_res + ';' + out_txt
    ids = list(torch.argsort(torch.tensor(areas), descending=False))
    ids = [int(i) for i in ids]

    torch.cuda.empty_cache()

    return reses, [reses[i] for i in ids]
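
# Editor's note (not in the original diff): the > 0.95 test above is a plain
# mask IoU used to drop near-duplicate predictions; the equivalent standalone
# helper would be:
def _mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    # intersection-over-union of two boolean masks
    return np.logical_and(a, b).sum() / np.logical_or(a, b).sum()
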
def interactive_infer_image_3l(model, image, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image['image'])
    mask_ori = transform1(image['mask'])
    width = image_ori.size[0]
    height = image_ori.size[1]
    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()
    all_classes, all_parts = all_classes.strip().strip("\"[]").split(':'), all_parts.strip().strip("\"[]").split(':')

    data = {"image": images, "height": height, "width": width}

    mask_ori = np.asarray(mask_ori)[:, :, 0:1].copy()
    mask_ori = torch.from_numpy(mask_ori).permute(2, 0, 1)[0]
    points = mask_ori.nonzero().float().to(images.device)
    if len(points) == 0:
        point_ = point = points.new_tensor([[0.5, 0.5, 0.006, 0.006]])
    else:
        point_ = points.mean(0)[None]
        point = point_.clone()
        point[0, 0] = point_[0, 0] / mask_ori.shape[0]
        point[0, 1] = point_[0, 1] / mask_ori.shape[1]
        point = point[:, [1, 0]]
        point = torch.cat([point, points.new_tensor([[0.005, 0.005]])], dim=-1)
    data['targets'] = [dict()]
    data['targets'][0]['points'] = point
    data['targets'][0]['pb'] = point.new_tensor([0.])

    batch_inputs = [data]
    masks, ious, pred_class, pred_class_score = model.model.evaluate_demo(batch_inputs, all_classes, all_parts, level=[0, 1, 2])

    pred_masks_poses = masks
    reses = []
    ious = ious[0, 0]
    ids = torch.argsort(ious, descending=True)

    text_res = ''
    try:
        thresh = float(thresh)
    except Exception:
        thresh = 0.0
    mask_ls = []
    ious_res = []
    areas = []
    new_pred_class = []
    new_pred_class_score = []
    # Reorder the class predictions to match the IoU-sorted masks
    for i in ids:
        new_pred_class_score.append(pred_class_score[i])
        new_pred_class.append(pred_class[i])
    for i, (pred_masks_pos, iou, cls_name, cls_score) in enumerate(zip(pred_masks_poses[ids], ious[ids], new_pred_class, new_pred_class_score)):
        iou = round(float(iou), 2)
        texts = f'{iou}_{cls_name}_{cls_score}'
        mask = (pred_masks_pos > 0.0).cpu().numpy()
        area = mask.sum()
        conti = False
        if iou < thresh:
            conti = True
        for m in mask_ls:
            if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
                conti = True
                break
        if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
            conti = False
        if conti:
            continue
        ious_res.append(iou)
        mask_ls.append(mask)
        areas.append(area)
        mask, _ = remove_small_regions(mask, int(hole_scale), mode="holes")
        mask, _ = remove_small_regions(mask, int(island_scale), mode="islands")
        mask = mask.astype(float)  # was np.float, which is removed in NumPy >= 1.24
        out_txt = texts
        visual = Visualizer(image_ori, metadata=metadata)
        color = [0., 0., 1.0]
        demo = visual.draw_binary_mask(mask, color=color, text=texts)
        res = demo.get_image()
        # mark the click point in red
        point_x0 = max(0, int(point_[0, 1]) - 3)
        point_x1 = min(mask_ori.shape[1], int(point_[0, 1]) + 3)
        point_y0 = max(0, int(point_[0, 0]) - 3)
        point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
        res[point_y0:point_y1, point_x0:point_x1, 0] = 255
        res[point_y0:point_y1, point_x0:point_x1, 1] = 0
        res[point_y0:point_y1, point_x0:point_x1, 2] = 0
        reses.append(Image.fromarray(res))
        text_res = text_res + ';' + out_txt
    ids = list(torch.argsort(torch.tensor(areas), descending=False))
    ids = [int(i) for i in ids]

    torch.cuda.empty_cache()

    return reses, [reses[i] for i in ids]
def interactive_infer_image_semantic(model, image, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
    transform1 = transforms.Compose(t)
    image_ori = transform1(image['image'])
    mask_ori = transform1(image['mask'])
    width = image_ori.size[0]
    height = image_ori.size[1]
    image_ori = np.asarray(image_ori)
    images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()
    all_classes, all_parts = all_classes.strip().strip("\"[]").split(':'), all_parts.strip().strip("\"[]").split(':')

    data = {"image": images, "height": height, "width": width}

    mask_ori = np.asarray(mask_ori)[:, :, 0:1].copy()
    mask_ori = torch.from_numpy(mask_ori).permute(2, 0, 1)[0]
    points = mask_ori.nonzero().float().to(images.device)
    if len(points) == 0:
        point_ = point = points.new_tensor([[0.5, 0.5, 0.006, 0.006]])
    else:
        point_ = points.mean(0)[None]
        point = point_.clone()
        point[0, 0] = point_[0, 0] / mask_ori.shape[0]
        point[0, 1] = point_[0, 1] / mask_ori.shape[1]
        point = point[:, [1, 0]]
        point = torch.cat([point, points.new_tensor([[0.005, 0.005]])], dim=-1)
    data['targets'] = [dict()]
    data['targets'][0]['points'] = point
    # the diff assigned pb = 0. and immediately overwrote it with 1.;
    # only the final value takes effect
    data['targets'][0]['pb'] = point.new_tensor([1.])

    batch_inputs = [data]
    masks, ious = model.model.evaluate_demo(batch_inputs, all_classes, all_parts)

    pred_masks_poses = masks
    reses = []
    ious = ious[0, 0]
    ids = torch.argsort(ious, descending=True)

    text_res = ''
    try:
        thresh = float(thresh)
    except Exception:
        thresh = 0.0
    mask_ls = []
    ious_res = []
    areas = []
    for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
        iou = round(float(iou), 2)
        texts = f'{iou}'
        mask = (pred_masks_pos > 0.0).cpu().numpy()
        area = mask.sum()
        conti = False
        if iou < thresh:
            conti = True
        for m in mask_ls:
            if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
                conti = True
                break
        if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
            conti = False
        if conti:
            continue
        ious_res.append(iou)
        mask_ls.append(mask)
        areas.append(area)
        mask, _ = remove_small_regions(mask, int(hole_scale), mode="holes")
        mask, _ = remove_small_regions(mask, int(island_scale), mode="islands")
        mask = mask.astype(float)  # was np.float, which is removed in NumPy >= 1.24
        out_txt = texts
        visual = Visualizer(image_ori, metadata=metadata)
        color = [0., 0., 1.0]
        demo = visual.draw_binary_mask(mask, color=color, text=texts)
        res = demo.get_image()
        # mark the click point in red
        point_x0 = max(0, int(point_[0, 1]) - 3)
        point_x1 = min(mask_ori.shape[1], int(point_[0, 1]) + 3)
        point_y0 = max(0, int(point_[0, 0]) - 3)
        point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
        res[point_y0:point_y1, point_x0:point_x1, 0] = 255
        res[point_y0:point_y1, point_x0:point_x1, 1] = 0
        res[point_y0:point_y1, point_x0:point_x1, 2] = 0
        reses.append(Image.fromarray(res))
        text_res = text_res + ';' + out_txt
    ids = list(torch.argsort(torch.tensor(areas), descending=False))
    ids = [int(i) for i in ids]

    torch.cuda.empty_cache()

    return reses, [reses[i] for i in ids]


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of whether the mask has been modified.
    """
    import cv2  # type: ignore

    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True
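
# Editor's sketch (not part of the original diff): driving the click-prompt
# functions in this file. The mask channel of the input dict marks the
# clicked pixels; their mean coordinate, normalized to [0, 1], becomes the
# point prompt built above. All argument values are illustrative.
def _example_click_prompt(model, rgb_pil, click_mask_pil):
    image = {'image': rgb_pil, 'mask': click_mask_pil}
    reses, reses_by_area = interactive_infer_image(
        model, image, all_classes="object", all_parts="part",
        thresh="0.0", text_size=640, hole_scale=100, island_scale=100,
        semantic=False, label_mode='1', alpha=0.1, anno_mode=['Mask'])
    return reses
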
139
mm_agents/task_adapter/semantic_sam/tasks/interactive_predictor.py
Normal file
@@ -0,0 +1,139 @@
import torch
import numpy as np
from torchvision import transforms
from task_adapter.utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog

metadata = MetadataCatalog.get('coco_2017_train_panoptic')


class SemanticSAMPredictor:
    def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100):
        """
        thresh: IoU threshold used to filter low-confidence objects
        text_size: resize the input image's short edge for the model to process
        hole_scale: fill in small holes as in SAM
        island_scale: remove small regions as in SAM
        """
        self.model = model
        self.thresh = thresh
        self.text_size = text_size  # the diff assigned hole_scale here, an apparent copy-paste slip
        self.hole_scale = hole_scale
        self.island_scale = island_scale
        self.point = None

    def predict(self, image_ori, image, point=None):
        """
        Produce up to 6 prediction results for each click.
        """
        width = image_ori.shape[0]
        height = image_ori.shape[1]

        data = {"image": image, "height": height, "width": width}
        if point is None:
            point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda()
        else:
            point = torch.tensor(point).cuda()
            point_ = point
            point = point_.clone()
            point[0, 0] = point_[0, 0]
            point[0, 1] = point_[0, 1]
            # point = point[:, [1, 0]]
            point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)

        self.point = point[:, :2].clone() * (torch.tensor([width, height]).to(point))

        data['targets'] = [dict()]
        data['targets'][0]['points'] = point
        data['targets'][0]['pb'] = point.new_tensor([0.])

        batch_inputs = [data]
        masks, ious = self.model.model.evaluate_demo(batch_inputs)

        return masks, ious

    def process_multi_mask(self, masks, ious, image_ori):
        pred_masks_poses = masks
        reses = []
        ious = ious[0, 0]
        ids = torch.argsort(ious, descending=True)

        text_res = ''
        mask_ls = []
        ious_res = []
        areas = []
        for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
            iou = round(float(iou), 2)
            texts = f'{iou}'
            mask = (pred_masks_pos > 0.0).cpu().numpy()
            area = mask.sum()
            conti = False
            if iou < self.thresh:
                conti = True
            for m in mask_ls:
                if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
                    conti = True
                    break
            if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
                conti = False
            if conti:
                continue
            ious_res.append(iou)
            mask_ls.append(mask)
            areas.append(area)
            mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes")
            mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands")
            mask = mask.astype(float)  # was np.float, which is removed in NumPy >= 1.24
            out_txt = texts
            visual = Visualizer(image_ori, metadata=metadata)
            color = [0., 0., 1.0]
            demo = visual.draw_binary_mask(mask, color=color, text=texts)
            res = demo.get_image()
            # mark the click point in red
            point_x0 = max(0, int(self.point[0, 0]) - 3)
            point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3)
            point_y0 = max(0, int(self.point[0, 1]) - 3)
            point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3)
            res[point_y0:point_y1, point_x0:point_x1, 0] = 255
            res[point_y0:point_y1, point_x0:point_x1, 1] = 0
            res[point_y0:point_y1, point_x0:point_x1, 2] = 0
            reses.append(Image.fromarray(res))
            text_res = text_res + ';' + out_txt
        ids = list(torch.argsort(torch.tensor(areas), descending=False))
        ids = [int(i) for i in ids]

        torch.cuda.empty_cache()

        return reses, [reses[i] for i in ids]

    def predict_masks(self, image_ori, image, point=None):
        masks, ious = self.predict(image_ori, image, point)
        return self.process_multi_mask(masks, ious, image_ori)

    @staticmethod
    def remove_small_regions(
        mask: np.ndarray, area_thresh: float, mode: str
    ) -> Tuple[np.ndarray, bool]:
        """
        Removes small disconnected regions and holes in a mask. Returns the
        mask and an indicator of whether the mask has been modified.
        """
        import cv2  # type: ignore

        assert mode in ["holes", "islands"]
        correct_holes = mode == "holes"
        working_mask = (correct_holes ^ mask).astype(np.uint8)
        n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
        sizes = stats[:, -1][1:]  # Row 0 is background label
        small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
        if len(small_regions) == 0:
            return mask, False
        fill_labels = [0] + small_regions
        if not correct_holes:
            fill_labels = [i for i in range(n_labels) if i not in fill_labels]
            # If every region is below threshold, keep largest
            if len(fill_labels) == 0:
                fill_labels = [int(np.argmax(sizes)) + 1]
        mask = np.isin(regions, fill_labels)
        return mask, True
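
# Editor's sketch (not part of the original diff): using the predictor with a
# normalized click. `image_tensor` is assumed to be the CHW CUDA tensor the
# model expects, `image_ori_np` the matching HWC uint8 array, and the point
# an (x, y) pair in [0, 1] image coordinates.
def _example_predictor(model, image_ori_np, image_tensor):
    predictor = SemanticSAMPredictor(model, thresh=0.5)
    reses, reses_by_area = predictor.predict_masks(
        image_ori_np, image_tensor, point=[[0.5, 0.5]])
    # reses holds up to six visualized candidate masks for the click,
    # reses_by_area the same images sorted by ascending mask area.
    return reses
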
1405
mm_agents/task_adapter/utils/visualizer.py
Normal file
File diff suppressed because it is too large