Code clean
This commit is contained in:
@@ -1,401 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
|
|
||||||
# Copyright (c) 2022 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Xueyan Zou (xueyan@cs.wisc.edu)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
# Define Test/Trainer/Saving
|
|
||||||
PIPELINE: XDecoderPipeline
|
|
||||||
TRAINER: xdecoder
|
|
||||||
SAVE_DIR: '../../data/output/test'
|
|
||||||
base_path: "./"
|
|
||||||
|
|
||||||
# Resume Logistic
|
|
||||||
RESUME: false
|
|
||||||
WEIGHT: false
|
|
||||||
RESUME_FROM: ''
|
|
||||||
EVAL_AT_START: False
|
|
||||||
|
|
||||||
# Logging and Debug
|
|
||||||
WANDB: False
|
|
||||||
LOG_EVERY: 100
|
|
||||||
FIND_UNUSED_PARAMETERS: false
|
|
||||||
|
|
||||||
# Speed up training
|
|
||||||
FP16: false
|
|
||||||
PORT: '36873'
|
|
||||||
|
|
||||||
# misc
|
|
||||||
LOADER:
|
|
||||||
JOINT: False
|
|
||||||
KEY_DATASET: 'coco'
|
|
||||||
|
|
||||||
##################
|
|
||||||
# Task settings
|
|
||||||
##################
|
|
||||||
VERBOSE: true
|
|
||||||
MODEL:
|
|
||||||
NAME: seem_model_v1
|
|
||||||
HEAD: xdecoder_head
|
|
||||||
MASK_ON: false
|
|
||||||
KEYPOINT_ON: false
|
|
||||||
LOAD_PROPOSALS: false
|
|
||||||
DIM_PROJ: 512
|
|
||||||
TEXT:
|
|
||||||
ARCH: vlpencoder
|
|
||||||
NAME: transformer
|
|
||||||
TOKENIZER: clip
|
|
||||||
CONTEXT_LENGTH: 77 # 77
|
|
||||||
WIDTH: 512
|
|
||||||
HEADS: 8
|
|
||||||
LAYERS: 12 # 6
|
|
||||||
AUTOGRESSIVE: True
|
|
||||||
BACKBONE:
|
|
||||||
NAME: focal
|
|
||||||
PRETRAINED: ''
|
|
||||||
LOAD_PRETRAINED: false
|
|
||||||
FOCAL:
|
|
||||||
PRETRAIN_IMG_SIZE: 224
|
|
||||||
PATCH_SIZE: 4
|
|
||||||
EMBED_DIM: 192
|
|
||||||
DEPTHS: [2, 2, 18, 2]
|
|
||||||
FOCAL_LEVELS: [4, 4, 4, 4]
|
|
||||||
FOCAL_WINDOWS: [3, 3, 3, 3]
|
|
||||||
DROP_PATH_RATE: 0.3
|
|
||||||
MLP_RATIO: 4.0
|
|
||||||
DROP_RATE: 0.0
|
|
||||||
PATCH_NORM: True
|
|
||||||
USE_CONV_EMBED: True
|
|
||||||
SCALING_MODULATOR: True
|
|
||||||
USE_CHECKPOINT: False
|
|
||||||
USE_POSTLN: true
|
|
||||||
USE_POSTLN_IN_MODULATION: false
|
|
||||||
USE_LAYERSCALE: True
|
|
||||||
OUT_FEATURES: ["res2", "res3", "res4", "res5"]
|
|
||||||
OUT_INDICES: [0, 1, 2, 3]
|
|
||||||
ENCODER:
|
|
||||||
NAME: transformer_encoder_fpn
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
NUM_CLASSES: 133
|
|
||||||
LOSS_WEIGHT: 1.0
|
|
||||||
CONVS_DIM: 512
|
|
||||||
MASK_DIM: 512
|
|
||||||
NORM: "GN"
|
|
||||||
IN_FEATURES: ["res2", "res3", "res4", "res5"]
|
|
||||||
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"]
|
|
||||||
COMMON_STRIDE: 4
|
|
||||||
TRANSFORMER_ENC_LAYERS: 6
|
|
||||||
DECODER:
|
|
||||||
NAME: seem_v1
|
|
||||||
TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
|
|
||||||
MASK:
|
|
||||||
ENABLED: True
|
|
||||||
DETECTION: False
|
|
||||||
SPATIAL:
|
|
||||||
ENABLED: True
|
|
||||||
MAX_ITER: 1
|
|
||||||
GROUNDING:
|
|
||||||
ENABLED: True
|
|
||||||
MAX_LEN: 5
|
|
||||||
TEXT_WEIGHT: 2.0
|
|
||||||
CLASS_WEIGHT: 0.5
|
|
||||||
RETRIEVAL:
|
|
||||||
ENABLED: False
|
|
||||||
LVIS:
|
|
||||||
ENABLED: True
|
|
||||||
THRES: 0.7
|
|
||||||
OPENIMAGE:
|
|
||||||
ENABLED: False
|
|
||||||
NEGATIVE_SAMPLES: 5
|
|
||||||
GROUNDING:
|
|
||||||
ENABLED: False
|
|
||||||
MAX_LEN: 5
|
|
||||||
CAPTION:
|
|
||||||
ENABLED: False
|
|
||||||
PHRASE_PROB: 0.5
|
|
||||||
SIM_THRES: 0.95
|
|
||||||
DEEP_SUPERVISION: True
|
|
||||||
NO_OBJECT_WEIGHT: 0.1
|
|
||||||
GCLASS_WEIGHT: 0.4
|
|
||||||
GMASK_WEIGHT: 1.0
|
|
||||||
GDICE_WEIGHT: 1.0
|
|
||||||
SCLASS_WEIGHT: 0.4
|
|
||||||
SMASK_WEIGHT: 1.0
|
|
||||||
SDICE_WEIGHT: 1.0
|
|
||||||
OCLASS_WEIGHT: 0.4
|
|
||||||
OMASK_WEIGHT: 1.0
|
|
||||||
ODICE_WEIGHT: 1.0
|
|
||||||
CLASS_WEIGHT: 2.0
|
|
||||||
MASK_WEIGHT: 5.0
|
|
||||||
DICE_WEIGHT: 5.0
|
|
||||||
BBOX_WEIGHT: 5.0
|
|
||||||
GIOU_WEIGHT: 2.0
|
|
||||||
CAPTION_WEIGHT: 2.0
|
|
||||||
COST_SPATIAL:
|
|
||||||
CLASS_WEIGHT: 5.0
|
|
||||||
MASK_WEIGHT: 2.0
|
|
||||||
DICE_WEIGHT: 2.0
|
|
||||||
HIDDEN_DIM: 512
|
|
||||||
NUM_OBJECT_QUERIES: 101
|
|
||||||
NHEADS: 8
|
|
||||||
DROPOUT: 0.0
|
|
||||||
DIM_FEEDFORWARD: 2048
|
|
||||||
MAX_SPATIAL_LEN: [512, 512, 512, 512]
|
|
||||||
# ENC_LAYERS: 0
|
|
||||||
PRE_NORM: False
|
|
||||||
ENFORCE_INPUT_PROJ: False
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
TRAIN_NUM_POINTS: 12544
|
|
||||||
OVERSAMPLE_RATIO: 3.0
|
|
||||||
IMPORTANCE_SAMPLE_RATIO: 0.75
|
|
||||||
DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query
|
|
||||||
TOP_GROUNDING_LAYERS: 10
|
|
||||||
TOP_CAPTION_LAYERS: 10
|
|
||||||
TOP_SPATIAL_LAYERS: 10
|
|
||||||
TOP_OPENIMAGE_LAYERS: 10
|
|
||||||
TEST:
|
|
||||||
SEMANTIC_ON: True
|
|
||||||
INSTANCE_ON: True
|
|
||||||
PANOPTIC_ON: True
|
|
||||||
OVERLAP_THRESHOLD: 0.8
|
|
||||||
OBJECT_MASK_THRESHOLD: 0.8
|
|
||||||
SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
|
|
||||||
|
|
||||||
# Spatial sampler
|
|
||||||
STROKE_SAMPLER:
|
|
||||||
MAX_CANDIDATE: 1
|
|
||||||
CANDIDATE_PROBS: [0.25, 0.25, 0.25, 0.25] # for training only
|
|
||||||
CANDIDATE_NAMES: ["Point", "Polygon", "Scribble", "Circle"]
|
|
||||||
DILATION: 3
|
|
||||||
CIRCLE:
|
|
||||||
NUM_STROKES: 5
|
|
||||||
STROKE_PRESET: ['object_like', 'object_like_middle', 'object_like_small']
|
|
||||||
STROKE_PROB: [0.33, 0.33, 0.33]
|
|
||||||
SCRIBBLE:
|
|
||||||
NUM_STROKES: 5
|
|
||||||
STROKE_PRESET: ['rand_curve', 'rand_curve_small']
|
|
||||||
STROKE_PROB: [0.5, 0.5]
|
|
||||||
POINT:
|
|
||||||
NUM_POINTS: 20
|
|
||||||
POLYGON:
|
|
||||||
MAX_POINTS: 9
|
|
||||||
EVAL:
|
|
||||||
MODE: 'best' # best/random/best_random
|
|
||||||
NEGATIVE: False
|
|
||||||
MAX_ITER: 20
|
|
||||||
IOU_ITER: 1
|
|
||||||
GROUNDING: False
|
|
||||||
|
|
||||||
# Multi-modal Architecture, order matters
|
|
||||||
ATTENTION_ARCH:
|
|
||||||
VARIABLE:
|
|
||||||
queries: ['object', 'grounding', 'spatial']
|
|
||||||
tokens: ['grounding', 'spatial']
|
|
||||||
memories: ['spatial']
|
|
||||||
SELF_ATTENTION:
|
|
||||||
queries:
|
|
||||||
object: ['queries_object']
|
|
||||||
grounding: ['queries_grounding', 'tokens_grounding']
|
|
||||||
spatial: ['queries_spatial', 'tokens_spatial', 'memories_spatial']
|
|
||||||
tokens:
|
|
||||||
grounding: ['queries_grounding', 'tokens_grounding']
|
|
||||||
spatial: ['tokens_spatial']
|
|
||||||
memories:
|
|
||||||
spatial: ['memories_spatial']
|
|
||||||
CROSS_ATTENTION:
|
|
||||||
queries:
|
|
||||||
object: True
|
|
||||||
grounding: True
|
|
||||||
spatial: True
|
|
||||||
memories:
|
|
||||||
spatial: True
|
|
||||||
tokens:
|
|
||||||
grounding: False
|
|
||||||
spatial: False
|
|
||||||
MASKING: ['tokens_spatial', 'tokens_grounding']
|
|
||||||
DUPLICATION:
|
|
||||||
queries:
|
|
||||||
grounding: 'queries_object'
|
|
||||||
spatial: 'queries_object'
|
|
||||||
SPATIAL_MEMORIES: 32
|
|
||||||
QUERY_NUMBER: 3
|
|
||||||
|
|
||||||
DATASETS:
|
|
||||||
TRAIN: ["coco_2017_train_panoptic_filtrefgumdval_with_sem_seg_caption_grounding_lvis",]
|
|
||||||
# TRAIN: ["coco_2017_train_panoptic_with_sem_seg_caption_grounding",]
|
|
||||||
TEST: ["coco_2017_val_panoptic_with_sem_seg", "pascalvoc_val_Point", "refcocog_val_umd"] # to evaluate instance and semantic performance as well
|
|
||||||
# TEST: ["pascalvoc_val_Point"] # [pascalvoc, openimage600, ade600, davis, cocomini], [Point, Scribble, Polygon, Circle, Box]
|
|
||||||
# TEST: ["cocomini_val_Point", "cocomini_val_Circle", "cocomini_val_Scribble", "cocomini_val_Polygon", "cocomini_val_Box"] # [pascalvoc, openimage600, ade600, davis, cocomini], [Point, Scribble, Polygon, Circle, Box]
|
|
||||||
# TEST: ["ade600_val_Point", "ade600_val_Circle", "ade600_val_Scribble", "ade600_val_Polygon", "ade600_val_Box"] # [pascalvoc, openimage600, ade600, davis, cocomini], [Point, Scribble, Polygon, Circle, Box]
|
|
||||||
# TEST: ["openimage600_val_Point", "openimage600_val_Circle", "openimage600_val_Scribble", "openimage600_val_Polygon", "openimage600_val_Box"] # [pascalvoc, openimage600, ade600, davis, cocomini], [Point, Scribble, Polygon, Circle, Box]
|
|
||||||
CLASS_CONCAT: false
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
PROPOSAL_FILES_TRAIN: []
|
|
||||||
|
|
||||||
INPUT:
|
|
||||||
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
|
||||||
PIXEL_STD: [58.395, 57.120, 57.375]
|
|
||||||
|
|
||||||
TRAIN:
|
|
||||||
ASPECT_RATIO_GROUPING: true
|
|
||||||
BATCH_SIZE_TOTAL: 4
|
|
||||||
BATCH_SIZE_PER_GPU: 4
|
|
||||||
SHUFFLE: true
|
|
||||||
|
|
||||||
TEST:
|
|
||||||
DETECTIONS_PER_IMAGE: 100
|
|
||||||
NAME: coco_eval
|
|
||||||
IOU_TYPE: ['bbox', 'segm']
|
|
||||||
USE_MULTISCALE: false
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
MODEL_FILE: ''
|
|
||||||
AUG:
|
|
||||||
ENABLED: False
|
|
||||||
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 8
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
|
|
||||||
COCO:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TRAIN: 800
|
|
||||||
MAX_SIZE_TRAIN: 1333
|
|
||||||
MIN_SIZE_TRAIN_SAMPLING: 'choice'
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
IMAGE_SIZE: 1024
|
|
||||||
MIN_SCALE: 0.1
|
|
||||||
MAX_SCALE: 2.0
|
|
||||||
DATASET_MAPPER_NAME: "coco_interactive"
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
COLOR_AUG_SSD: False
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
RANDOM_FLIP: "horizontal"
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
FORMAT: "RGB"
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
DATASET:
|
|
||||||
DATASET: 'coco'
|
|
||||||
|
|
||||||
# Validation dataset
|
|
||||||
ADE20K:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TRAIN: 640
|
|
||||||
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
|
||||||
MIN_SIZE_TEST: 640
|
|
||||||
MAX_SIZE_TRAIN: 2560
|
|
||||||
MAX_SIZE_TEST: 2560
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
TYPE: "absolute"
|
|
||||||
SIZE: (640, 640)
|
|
||||||
SINGLE_CATEGORY_MAX_AREA: 1.0
|
|
||||||
COLOR_AUG_SSD: True
|
|
||||||
SIZE_DIVISIBILITY: 640 # used in dataset mapper
|
|
||||||
DATASET_MAPPER_NAME: "mask_former_panoptic"
|
|
||||||
FORMAT: "RGB"
|
|
||||||
DATASET:
|
|
||||||
DATASET: 'ade'
|
|
||||||
|
|
||||||
SBD:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 1
|
|
||||||
|
|
||||||
VOC:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
DAVIS:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
VOS:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 1
|
|
||||||
|
|
||||||
REF:
|
|
||||||
INPUT:
|
|
||||||
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
|
||||||
PIXEL_STD: [58.395, 57.120, 57.375]
|
|
||||||
MIN_SIZE_TEST: 512
|
|
||||||
MAX_SIZE_TEST: 1024
|
|
||||||
FORMAT: "RGB"
|
|
||||||
SPATIAL: False
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 4
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
# Detectron2 training config for optimizer and lr scheduler
|
|
||||||
SOLVER:
|
|
||||||
BASE_LR: 0.0001
|
|
||||||
STEPS: [0.88889, 0.96296]
|
|
||||||
MAX_ITER: 1
|
|
||||||
GAMMA: 0.1
|
|
||||||
WARMUP_FACTOR: 1.0
|
|
||||||
WARMUP_ITERS: 10
|
|
||||||
WARMUP_METHOD: "linear"
|
|
||||||
WEIGHT_DECAY: 0.05
|
|
||||||
OPTIMIZER: "ADAMW"
|
|
||||||
LR_SCHEDULER_NAME: "WarmupMultiStepLR"
|
|
||||||
LR_MULTIPLIER:
|
|
||||||
backbone: 0.1
|
|
||||||
lang_encoder: 0.1
|
|
||||||
FIX_PARAM:
|
|
||||||
backbone: True
|
|
||||||
lang_encoder: True
|
|
||||||
pixel_decoder: True
|
|
||||||
WEIGHT_DECAY_NORM: 0.0
|
|
||||||
WEIGHT_DECAY_EMBED: 0.0
|
|
||||||
CLIP_GRADIENTS:
|
|
||||||
ENABLED: True
|
|
||||||
CLIP_TYPE: "full_model"
|
|
||||||
CLIP_VALUE: 5.0 # 0.01
|
|
||||||
NORM_TYPE: 2.0
|
|
||||||
MAX_NUM_EPOCHS: 50
|
|
||||||
@@ -1,524 +0,0 @@
|
|||||||
# ------------------------------------------------------------------------
|
|
||||||
# Semantic SAM
|
|
||||||
# Copyright (c) MicroSoft, Inc. and its affiliates.
|
|
||||||
# Modified from OpenSeed https://github.com/IDEA-Research/OpenSeed by Feng Li.
|
|
||||||
# ------------------------------------------------------------------------
|
|
||||||
|
|
||||||
##################
|
|
||||||
# Task settings
|
|
||||||
##################
|
|
||||||
WEIGHT: ''
|
|
||||||
PORT: 53711
|
|
||||||
VERBOSE: true
|
|
||||||
|
|
||||||
OUTPUT_DIR: '../../data/output/test'
|
|
||||||
# misc
|
|
||||||
LOADER:
|
|
||||||
JOINT: True
|
|
||||||
KEY_DATASET: 'coco'
|
|
||||||
# model
|
|
||||||
MODEL:
|
|
||||||
NAME: interactive_mask_dino
|
|
||||||
HEAD: general_head
|
|
||||||
MASK_ON: false
|
|
||||||
KEYPOINT_ON: false
|
|
||||||
LOAD_PROPOSALS: false
|
|
||||||
DIM_PROJ: 512
|
|
||||||
BACKBONE_DIM: 768
|
|
||||||
BACKGROUND: False
|
|
||||||
WEIGHTS: ''
|
|
||||||
TEXT:
|
|
||||||
ARCH: noencoder # no language encoder for training only sa-1b data
|
|
||||||
NAME: transformer
|
|
||||||
TOKENIZER: clip
|
|
||||||
CONTEXT_LENGTH: 18 # 77
|
|
||||||
WIDTH: 512
|
|
||||||
HEADS: 8
|
|
||||||
LAYERS: 12 # 6
|
|
||||||
AUTOGRESSIVE: True
|
|
||||||
BACKBONE:
|
|
||||||
NAME: swin
|
|
||||||
PRETRAINED: 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'
|
|
||||||
LOAD_PRETRAINED: true
|
|
||||||
SWIN:
|
|
||||||
PRETRAIN_IMG_SIZE: 384
|
|
||||||
PATCH_SIZE: 4
|
|
||||||
EMBED_DIM: 192
|
|
||||||
DEPTHS: [ 2, 2, 18, 2 ]
|
|
||||||
NUM_HEADS: [ 6, 12, 24, 48 ]
|
|
||||||
WINDOW_SIZE: 12
|
|
||||||
MLP_RATIO: 4.0
|
|
||||||
QKV_BIAS: true
|
|
||||||
QK_SCALE: ~
|
|
||||||
DROP_RATE: 0.0
|
|
||||||
ATTN_DROP_RATE: 0.0
|
|
||||||
DROP_PATH_RATE: 0.3
|
|
||||||
APE: false
|
|
||||||
PATCH_NORM: true
|
|
||||||
USE_CHECKPOINT: false
|
|
||||||
OUT_FEATURES: [ 'res2', 'res3', 'res4', 'res5' ]
|
|
||||||
ENCODER:
|
|
||||||
NAME: encoder_deform
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
NUM_CLASSES: 1
|
|
||||||
LOSS_WEIGHT: 1.0
|
|
||||||
CONVS_DIM: 256
|
|
||||||
MASK_DIM: 256
|
|
||||||
NORM: "GN"
|
|
||||||
IN_FEATURES: [ "res2", "res3", "res4", "res5" ]
|
|
||||||
DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: [ "res3", "res4", "res5" ]
|
|
||||||
COMMON_STRIDE: 4
|
|
||||||
TRANSFORMER_ENC_LAYERS: 6
|
|
||||||
TOTAL_NUM_FEATURE_LEVELS: 4
|
|
||||||
NUM_FEATURE_LEVELS: 3
|
|
||||||
FEATURE_ORDER: "low2high"
|
|
||||||
DECODER:
|
|
||||||
NAME: interactive_mask_dino
|
|
||||||
TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder"
|
|
||||||
MASK: True
|
|
||||||
BOX: True
|
|
||||||
PART: True
|
|
||||||
GROUNDING:
|
|
||||||
ENABLED: False
|
|
||||||
MAX_LEN: 5
|
|
||||||
TEXT_WEIGHT: 2.0
|
|
||||||
CLASS_WEIGHT: 0.5
|
|
||||||
CAPTION:
|
|
||||||
ENABLED: False
|
|
||||||
PHRASE_PROB: 0.0
|
|
||||||
SIM_THRES: 0.95
|
|
||||||
CAPTIONING:
|
|
||||||
ENABLED: False
|
|
||||||
STEP: 50
|
|
||||||
RETRIEVAL:
|
|
||||||
ENABLED: False
|
|
||||||
DIM_IMG: 768
|
|
||||||
ENSEMBLE: True
|
|
||||||
OPENIMAGE:
|
|
||||||
ENABLED: False
|
|
||||||
NEGATIVE_SAMPLES: 5
|
|
||||||
GROUNDING:
|
|
||||||
ENABLED: False
|
|
||||||
MAX_LEN: 5
|
|
||||||
DEEP_SUPERVISION: True
|
|
||||||
NO_OBJECT_WEIGHT: 0.1
|
|
||||||
CLASS_WEIGHT: 4.0
|
|
||||||
MASK_WEIGHT: 5.0
|
|
||||||
DICE_WEIGHT: 5.0
|
|
||||||
BOX_WEIGHT: 5.0
|
|
||||||
GIOU_WEIGHT: 2.0
|
|
||||||
IOU_WEIGHT: 1.0
|
|
||||||
COST_CLASS_WEIGHT: 4.0
|
|
||||||
COST_DICE_WEIGHT: 5.0
|
|
||||||
COST_MASK_WEIGHT: 5.0
|
|
||||||
COST_BOX_WEIGHT: 5.0
|
|
||||||
COST_GIOU_WEIGHT: 2.0
|
|
||||||
HIDDEN_DIM: 256
|
|
||||||
NUM_OBJECT_QUERIES: 0
|
|
||||||
NHEADS: 8
|
|
||||||
DROPOUT: 0.0
|
|
||||||
DIM_FEEDFORWARD: 2048
|
|
||||||
ENC_LAYERS: 0
|
|
||||||
PRE_NORM: False
|
|
||||||
ENFORCE_INPUT_PROJ: False
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
DEC_LAYERS: 9 # 9 decoder layers, add one for the loss on learnable query
|
|
||||||
TRAIN_NUM_POINTS: 12544
|
|
||||||
OVERSAMPLE_RATIO: 3.0
|
|
||||||
IMPORTANCE_SAMPLE_RATIO: 0.75
|
|
||||||
TWO_STAGE: False
|
|
||||||
INITIALIZE_BOX_TYPE: 'no'
|
|
||||||
DN: seg
|
|
||||||
DN_NOISE_SCALE: 0.4
|
|
||||||
DN_NUM: 100
|
|
||||||
INITIAL_PRED: False
|
|
||||||
LEARN_TGT: False
|
|
||||||
TOTAL_NUM_FEATURE_LEVELS: 4
|
|
||||||
SEMANTIC_CE_LOSS: False
|
|
||||||
PANO_BOX_LOSS: False
|
|
||||||
COCO: False
|
|
||||||
O365: False
|
|
||||||
SAM: True
|
|
||||||
PASCAL: False
|
|
||||||
RE_POINT: True
|
|
||||||
NUM_INTERACTIVE_TOKENS: 6
|
|
||||||
MAX_NUM_INSTANCE: 60
|
|
||||||
TEST:
|
|
||||||
SEMANTIC_ON: True
|
|
||||||
INSTANCE_ON: True
|
|
||||||
PANOPTIC_ON: True
|
|
||||||
BOX_INTERACTIVE: False
|
|
||||||
CLASSIFICATION_ON: False
|
|
||||||
OVERLAP_THRESHOLD: 0.8
|
|
||||||
OBJECT_MASK_THRESHOLD: 0.25
|
|
||||||
SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE: false
|
|
||||||
TEST_FOUCUS_ON_BOX: False
|
|
||||||
PANO_TRANSFORM_EVAL: True
|
|
||||||
PANO_TEMPERATURE: 0.06
|
|
||||||
|
|
||||||
TEST:
|
|
||||||
EVAL_PERIOD: 500000
|
|
||||||
PRECISE_BN:
|
|
||||||
NUM_ITER: 1
|
|
||||||
ENABLED: False
|
|
||||||
AUG:
|
|
||||||
ENABLED: False
|
|
||||||
|
|
||||||
SAM:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
IMAGE_SIZE: 1024
|
|
||||||
MIN_SCALE: 0.99
|
|
||||||
MAX_SCALE: 1.01
|
|
||||||
DATASET_MAPPER_NAME: "sam"
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
COLOR_AUG_SSD: False
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
RANDOM_FLIP: "horizontal"
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
FORMAT: "RGB"
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
DATASET:
|
|
||||||
DATASET: 'sam'
|
|
||||||
TEST:
|
|
||||||
DETECTIONS_PER_IMAGE: 100
|
|
||||||
NAME: coco_eval
|
|
||||||
IOU_TYPE: ['bbox', 'segm']
|
|
||||||
USE_MULTISCALE: false
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
MODEL_FILE: ''
|
|
||||||
AUG:
|
|
||||||
ENABLED: False
|
|
||||||
TRAIN:
|
|
||||||
BATCH_SIZE_TOTAL: 1
|
|
||||||
BATCH_SIZE_PER_GPU: 1
|
|
||||||
SHUFFLE: true
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 4
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
|
|
||||||
COCO:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
IMAGE_SIZE: 1024
|
|
||||||
MIN_SCALE: 0.1
|
|
||||||
MAX_SCALE: 2.0
|
|
||||||
DATASET_MAPPER_NAME: "coco_interactive_panoptic_lsj"
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
COLOR_AUG_SSD: False
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
RANDOM_FLIP: "horizontal"
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
FORMAT: "RGB"
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
DATASET:
|
|
||||||
DATASET: 'coco'
|
|
||||||
TEST:
|
|
||||||
DETECTIONS_PER_IMAGE: 100
|
|
||||||
NAME: coco_eval
|
|
||||||
IOU_TYPE: ['bbox', 'segm']
|
|
||||||
USE_MULTISCALE: false
|
|
||||||
BATCH_SIZE_TOTAL: 1
|
|
||||||
MODEL_FILE: ''
|
|
||||||
AUG:
|
|
||||||
ENABLED: False
|
|
||||||
TRAIN:
|
|
||||||
BATCH_SIZE_TOTAL: 1
|
|
||||||
BATCH_SIZE_PER_GPU: 1
|
|
||||||
SHUFFLE: true
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 2
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
|
|
||||||
VLP:
|
|
||||||
INPUT:
|
|
||||||
IMAGE_SIZE: 224
|
|
||||||
DATASET_MAPPER_NAME: "vlpretrain"
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
COLOR_AUG_SSD: False
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
FORMAT: "RGB"
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
TRAIN:
|
|
||||||
BATCH_SIZE_TOTAL: 2
|
|
||||||
BATCH_SIZE_PER_GPU: 2
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 256
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 16
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
|
|
||||||
INPUT:
|
|
||||||
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
|
||||||
PIXEL_STD: [58.395, 57.120, 57.375]
|
|
||||||
|
|
||||||
DATASETS:
|
|
||||||
TRAIN: ["sam_train"]
|
|
||||||
# interactive segmentation evaluation.
|
|
||||||
TEST: ["coco_2017_val_panoptic_with_sem_seg_interactive_jointboxpoint"]
|
|
||||||
# TEST: ["sam_minival"]
|
|
||||||
|
|
||||||
CLASS_CONCAT: false
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
PROPOSAL_FILES_TRAIN: []
|
|
||||||
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 16
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
|
|
||||||
# Detectron2 training config for optimizer and lr scheduler
|
|
||||||
SOLVER:
|
|
||||||
BASE_LR_END: 0.0
|
|
||||||
MOMENTUM: 0.9
|
|
||||||
NESTEROV: False
|
|
||||||
CHECKPOINT_PERIOD: 5000
|
|
||||||
IMS_PER_BATCH: 1
|
|
||||||
REFERENCE_WORLD_SIZE: 0
|
|
||||||
BIAS_LR_FACTOR: 1.0
|
|
||||||
WEIGHT_DECAY_BIAS: None
|
|
||||||
# original
|
|
||||||
BASE_LR: 0.0001
|
|
||||||
STEPS: [327778, 355092]
|
|
||||||
MAX_ITER: 368750
|
|
||||||
GAMMA: 0.1
|
|
||||||
WARMUP_FACTOR: 1.0
|
|
||||||
WARMUP_ITERS: 10
|
|
||||||
WARMUP_METHOD: "linear"
|
|
||||||
WEIGHT_DECAY: 0.05
|
|
||||||
OPTIMIZER: "ADAMW"
|
|
||||||
LR_SCHEDULER_NAME: "WarmupMultiStepLR"
|
|
||||||
LR_MULTIPLIER:
|
|
||||||
backbone: 0.1
|
|
||||||
lang_encoder: 0.1
|
|
||||||
WEIGHT_DECAY_NORM: 0.0
|
|
||||||
WEIGHT_DECAY_EMBED: 0.0
|
|
||||||
CLIP_GRADIENTS:
|
|
||||||
ENABLED: True
|
|
||||||
CLIP_TYPE: "full_model"
|
|
||||||
CLIP_VALUE: 0.01
|
|
||||||
NORM_TYPE: 2.0
|
|
||||||
AMP:
|
|
||||||
ENABLED: True
|
|
||||||
|
|
||||||
# Evaluation Dataset
|
|
||||||
ADE20K:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TRAIN: [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152, 1216, 1280]
|
|
||||||
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
|
||||||
MIN_SIZE_TEST: 640
|
|
||||||
MAX_SIZE_TRAIN: 2560
|
|
||||||
MAX_SIZE_TEST: 2560
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
TYPE: "absolute"
|
|
||||||
SIZE: [640, 640]
|
|
||||||
SINGLE_CATEGORY_MAX_AREA: 1.0
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
COLOR_AUG_SSD: True
|
|
||||||
SIZE_DIVISIBILITY: 640 # used in dataset mapper
|
|
||||||
DATASET_MAPPER_NAME: "mask_former_panoptic"
|
|
||||||
FORMAT: "RGB"
|
|
||||||
DATASET:
|
|
||||||
DATASET: 'ade'
|
|
||||||
TRAIN:
|
|
||||||
ASPECT_RATIO_GROUPING: true
|
|
||||||
BATCH_SIZE_TOTAL: 16
|
|
||||||
BATCH_SIZE_PER_GPU: 2
|
|
||||||
SHUFFLE: true
|
|
||||||
TEST:
|
|
||||||
DETECTIONS_PER_IMAGE: 100
|
|
||||||
NAME: coco_eval
|
|
||||||
IOU_TYPE: ['bbox', 'segm']
|
|
||||||
USE_MULTISCALE: false
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
MODEL_FILE: ''
|
|
||||||
AUG:
|
|
||||||
ENABLED: False
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 8
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
#ADE20K:
|
|
||||||
# INPUT:
|
|
||||||
# MIN_SIZE_TRAIN: 640
|
|
||||||
# MIN_SIZE_TRAIN_SAMPLING: "choice"
|
|
||||||
# MIN_SIZE_TEST: 640
|
|
||||||
# MAX_SIZE_TRAIN: 2560
|
|
||||||
# MAX_SIZE_TEST: 2560
|
|
||||||
# MASK_FORMAT: "polygon"
|
|
||||||
# CROP:
|
|
||||||
# ENABLED: True
|
|
||||||
# TYPE: "absolute"
|
|
||||||
# SIZE: (640, 640)
|
|
||||||
# SINGLE_CATEGORY_MAX_AREA: 1.0
|
|
||||||
# COLOR_AUG_SSD: True
|
|
||||||
# SIZE_DIVISIBILITY: 640 # used in dataset mapper
|
|
||||||
# DATASET_MAPPER_NAME: "mask_former_panoptic"
|
|
||||||
# FORMAT: "RGB"
|
|
||||||
# DATASET:
|
|
||||||
# DATASET: 'ade'
|
|
||||||
# TEST:
|
|
||||||
# BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
|
|
||||||
REF:
|
|
||||||
INPUT:
|
|
||||||
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
|
||||||
PIXEL_STD: [58.395, 57.120, 57.375]
|
|
||||||
MIN_SIZE_TEST: 512
|
|
||||||
MAX_SIZE_TEST: 1024
|
|
||||||
FORMAT: "RGB"
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
SUN:
|
|
||||||
INPUT:
|
|
||||||
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
|
||||||
PIXEL_STD: [58.395, 57.120, 57.375]
|
|
||||||
MIN_SIZE_TEST: 512
|
|
||||||
MAX_SIZE_TEST: 1024
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
SCAN:
|
|
||||||
INPUT:
|
|
||||||
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
|
||||||
PIXEL_STD: [58.395, 57.120, 57.375]
|
|
||||||
MIN_SIZE_TEST: 512
|
|
||||||
MAX_SIZE_TEST: 1024
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
BDD:
|
|
||||||
INPUT:
|
|
||||||
PIXEL_MEAN: [123.675, 116.280, 103.530]
|
|
||||||
PIXEL_STD: [58.395, 57.120, 57.375]
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 0
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: False
|
|
||||||
TEST:
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
|
|
||||||
CITY:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TRAIN: [ 512, 614, 716, 819, 921, 1024, 1126, 1228, 1331, 1433, 1536, 1638, 1740, 1843, 1945, 2048 ]
|
|
||||||
MIN_SIZE_TRAIN_SAMPLING: "choice"
|
|
||||||
MIN_SIZE_TEST: 1024
|
|
||||||
MAX_SIZE_TRAIN: 4096
|
|
||||||
MAX_SIZE_TEST: 2048
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
TYPE: "absolute"
|
|
||||||
SIZE: [ 512, 1024 ]
|
|
||||||
SINGLE_CATEGORY_MAX_AREA: 1.0
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
COLOR_AUG_SSD: True
|
|
||||||
SIZE_DIVISIBILITY: -1
|
|
||||||
FORMAT: "RGB"
|
|
||||||
DATASET_MAPPER_NAME: "mask_former_panoptic"
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
TEST:
|
|
||||||
EVAL_PERIOD: 5000
|
|
||||||
BATCH_SIZE_TOTAL: 1
|
|
||||||
AUG:
|
|
||||||
ENABLED: False
|
|
||||||
MIN_SIZES: [ 512, 768, 1024, 1280, 1536, 1792 ]
|
|
||||||
MAX_SIZE: 4096
|
|
||||||
FLIP: True
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: True
|
|
||||||
NUM_WORKERS: 2
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
TRAIN:
|
|
||||||
ASPECT_RATIO_GROUPING: true
|
|
||||||
BATCH_SIZE_TOTAL: 2
|
|
||||||
BATCH_SIZE_PER_GPU: 2
|
|
||||||
SHUFFLE: true
|
|
||||||
|
|
||||||
PSACAL_PART:
|
|
||||||
INPUT:
|
|
||||||
MIN_SIZE_TEST: 800
|
|
||||||
MAX_SIZE_TEST: 1333
|
|
||||||
IMAGE_SIZE: 1024
|
|
||||||
MIN_SCALE: 0.1
|
|
||||||
MAX_SCALE: 2.0
|
|
||||||
DATASET_MAPPER_NAME: "pascal_part_lsj"
|
|
||||||
IGNORE_VALUE: 255
|
|
||||||
COLOR_AUG_SSD: False
|
|
||||||
SIZE_DIVISIBILITY: 32
|
|
||||||
RANDOM_FLIP: "horizontal"
|
|
||||||
MASK_FORMAT: "polygon"
|
|
||||||
FORMAT: "RGB"
|
|
||||||
CROP:
|
|
||||||
ENABLED: True
|
|
||||||
MODEL:
|
|
||||||
MASK_ON: True
|
|
||||||
KEYPOINT_ON: False
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
# DATASET:
|
|
||||||
# DATASET: 'coco'
|
|
||||||
TEST:
|
|
||||||
DETECTIONS_PER_IMAGE: 100
|
|
||||||
NAME: coco_eval
|
|
||||||
IOU_TYPE: ['bbox', 'segm']
|
|
||||||
USE_MULTISCALE: false
|
|
||||||
BATCH_SIZE_TOTAL: 8
|
|
||||||
MODEL_FILE: ''
|
|
||||||
AUG:
|
|
||||||
ENABLED: False
|
|
||||||
TRAIN:
|
|
||||||
BATCH_SIZE_TOTAL: 1
|
|
||||||
BATCH_SIZE_PER_GPU: 1
|
|
||||||
SHUFFLE: true
|
|
||||||
DATALOADER:
|
|
||||||
FILTER_EMPTY_ANNOTATIONS: False
|
|
||||||
NUM_WORKERS: 2
|
|
||||||
LOAD_PROPOSALS: False
|
|
||||||
SAMPLER_TRAIN: "TrainingSampler"
|
|
||||||
ASPECT_RATIO_GROUPING: True
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Deformable DETR
|
|
||||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
|
|
||||||
from .ms_deform_attn_func import MSDeformAttnFunction
|
|
||||||
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Deformable DETR
|
|
||||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import print_function
|
|
||||||
from __future__ import division
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.autograd import Function
|
|
||||||
from torch.autograd.function import once_differentiable
|
|
||||||
|
|
||||||
try:
|
|
||||||
import MultiScaleDeformableAttention as MSDA
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
info_string = (
|
|
||||||
"\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n"
|
|
||||||
"\t`cd mask2former/modeling/pixel_decoder/ops`\n"
|
|
||||||
"\t`sh make.sh`\n"
|
|
||||||
)
|
|
||||||
raise ModuleNotFoundError(info_string)
|
|
||||||
|
|
||||||
|
|
||||||
class MSDeformAttnFunction(Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
|
|
||||||
ctx.im2col_step = im2col_step
|
|
||||||
output = MSDA.ms_deform_attn_forward(
|
|
||||||
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
|
|
||||||
ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
|
|
||||||
return output
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@once_differentiable
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
|
|
||||||
grad_value, grad_sampling_loc, grad_attn_weight = \
|
|
||||||
MSDA.ms_deform_attn_backward(
|
|
||||||
value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
|
|
||||||
|
|
||||||
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
|
|
||||||
|
|
||||||
|
|
||||||
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
|
|
||||||
# for debug and test only,
|
|
||||||
# need to use cuda version instead
|
|
||||||
N_, S_, M_, D_ = value.shape
|
|
||||||
_, Lq_, M_, L_, P_, _ = sampling_locations.shape
|
|
||||||
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
|
||||||
sampling_grids = 2 * sampling_locations - 1
|
|
||||||
sampling_value_list = []
|
|
||||||
for lid_, (H_, W_) in enumerate(value_spatial_shapes):
|
|
||||||
# N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
|
|
||||||
value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
|
|
||||||
# N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
|
|
||||||
sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
|
|
||||||
# N_*M_, D_, Lq_, P_
|
|
||||||
sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
|
|
||||||
mode='bilinear', padding_mode='zeros', align_corners=False)
|
|
||||||
sampling_value_list.append(sampling_value_l_)
|
|
||||||
# (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
|
|
||||||
attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
|
|
||||||
output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
|
|
||||||
return output.transpose(1, 2).contiguous()
|
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
#!/usr/bin/env bash
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Deformable DETR
|
|
||||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
|
|
||||||
python setup.py build install
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Deformable DETR
|
|
||||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
|
|
||||||
from .ms_deform_attn import MSDeformAttn
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Deformable DETR
|
|
||||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import print_function
|
|
||||||
from __future__ import division
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
import math
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.init import xavier_uniform_, constant_
|
|
||||||
|
|
||||||
from ..functions import MSDeformAttnFunction
|
|
||||||
from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch
|
|
||||||
|
|
||||||
|
|
||||||
def _is_power_of_2(n):
|
|
||||||
if (not isinstance(n, int)) or (n < 0):
|
|
||||||
raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
|
|
||||||
return (n & (n-1) == 0) and n != 0
|
|
||||||
|
|
||||||
|
|
||||||
class MSDeformAttn(nn.Module):
|
|
||||||
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
|
|
||||||
"""
|
|
||||||
Multi-Scale Deformable Attention Module
|
|
||||||
:param d_model hidden dimension
|
|
||||||
:param n_levels number of feature levels
|
|
||||||
:param n_heads number of attention heads
|
|
||||||
:param n_points number of sampling points per attention head per feature level
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
if d_model % n_heads != 0:
|
|
||||||
raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
|
|
||||||
_d_per_head = d_model // n_heads
|
|
||||||
# you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
|
|
||||||
if not _is_power_of_2(_d_per_head):
|
|
||||||
warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
|
|
||||||
"which is more efficient in our CUDA implementation.")
|
|
||||||
|
|
||||||
self.im2col_step = 128
|
|
||||||
|
|
||||||
self.d_model = d_model
|
|
||||||
self.n_levels = n_levels
|
|
||||||
self.n_heads = n_heads
|
|
||||||
self.n_points = n_points
|
|
||||||
|
|
||||||
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
|
|
||||||
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
|
|
||||||
self.value_proj = nn.Linear(d_model, d_model)
|
|
||||||
self.output_proj = nn.Linear(d_model, d_model)
|
|
||||||
|
|
||||||
self._reset_parameters()
|
|
||||||
|
|
||||||
def _reset_parameters(self):
|
|
||||||
constant_(self.sampling_offsets.weight.data, 0.)
|
|
||||||
thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
|
|
||||||
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
|
||||||
grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
|
|
||||||
for i in range(self.n_points):
|
|
||||||
grid_init[:, :, i, :] *= i + 1
|
|
||||||
with torch.no_grad():
|
|
||||||
self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
|
|
||||||
constant_(self.attention_weights.weight.data, 0.)
|
|
||||||
constant_(self.attention_weights.bias.data, 0.)
|
|
||||||
xavier_uniform_(self.value_proj.weight.data)
|
|
||||||
constant_(self.value_proj.bias.data, 0.)
|
|
||||||
xavier_uniform_(self.output_proj.weight.data)
|
|
||||||
constant_(self.output_proj.bias.data, 0.)
|
|
||||||
|
|
||||||
def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
|
|
||||||
"""
|
|
||||||
:param query (N, Length_{query}, C)
|
|
||||||
:param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
|
|
||||||
or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
|
|
||||||
:param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
|
|
||||||
:param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
|
|
||||||
:param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
|
|
||||||
:param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
|
|
||||||
|
|
||||||
:return output (N, Length_{query}, C)
|
|
||||||
"""
|
|
||||||
N, Len_q, _ = query.shape
|
|
||||||
N, Len_in, _ = input_flatten.shape
|
|
||||||
assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
|
|
||||||
|
|
||||||
value = self.value_proj(input_flatten)
|
|
||||||
if input_padding_mask is not None:
|
|
||||||
value = value.masked_fill(input_padding_mask[..., None], float(0))
|
|
||||||
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
|
|
||||||
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
|
|
||||||
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
|
|
||||||
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
|
|
||||||
# N, Len_q, n_heads, n_levels, n_points, 2
|
|
||||||
if reference_points.shape[-1] == 2:
|
|
||||||
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
|
|
||||||
sampling_locations = reference_points[:, :, None, :, None, :] \
|
|
||||||
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
|
|
||||||
elif reference_points.shape[-1] == 4:
|
|
||||||
sampling_locations = reference_points[:, :, None, :, None, :2] \
|
|
||||||
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
|
|
||||||
try:
|
|
||||||
output = MSDeformAttnFunction.apply(
|
|
||||||
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
|
|
||||||
except:
|
|
||||||
# CPU
|
|
||||||
output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
|
|
||||||
# # For FLOPs calculation only
|
|
||||||
# output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)
|
|
||||||
output = self.output_proj(output)
|
|
||||||
return output
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Deformable DETR
|
|
||||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
|
|
||||||
import os
|
|
||||||
import glob
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from torch.utils.cpp_extension import CUDA_HOME
|
|
||||||
from torch.utils.cpp_extension import CppExtension
|
|
||||||
from torch.utils.cpp_extension import CUDAExtension
|
|
||||||
|
|
||||||
from setuptools import find_packages
|
|
||||||
from setuptools import setup
|
|
||||||
|
|
||||||
requirements = ["torch", "torchvision"]
|
|
||||||
|
|
||||||
def get_extensions():
|
|
||||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
extensions_dir = os.path.join(this_dir, "src")
|
|
||||||
|
|
||||||
main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
|
|
||||||
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
|
|
||||||
source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
|
|
||||||
|
|
||||||
sources = main_file + source_cpu
|
|
||||||
extension = CppExtension
|
|
||||||
extra_compile_args = {"cxx": []}
|
|
||||||
define_macros = []
|
|
||||||
|
|
||||||
# Force cuda since torch ask for a device, not if cuda is in fact available.
|
|
||||||
if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None:
|
|
||||||
extension = CUDAExtension
|
|
||||||
sources += source_cuda
|
|
||||||
define_macros += [("WITH_CUDA", None)]
|
|
||||||
extra_compile_args["nvcc"] = [
|
|
||||||
"-DCUDA_HAS_FP16=1",
|
|
||||||
"-D__CUDA_NO_HALF_OPERATORS__",
|
|
||||||
"-D__CUDA_NO_HALF_CONVERSIONS__",
|
|
||||||
"-D__CUDA_NO_HALF2_OPERATORS__",
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
if CUDA_HOME is None:
|
|
||||||
raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.')
|
|
||||||
else:
|
|
||||||
raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().')
|
|
||||||
|
|
||||||
sources = [os.path.join(extensions_dir, s) for s in sources]
|
|
||||||
include_dirs = [extensions_dir]
|
|
||||||
ext_modules = [
|
|
||||||
extension(
|
|
||||||
"MultiScaleDeformableAttention",
|
|
||||||
sources,
|
|
||||||
include_dirs=include_dirs,
|
|
||||||
define_macros=define_macros,
|
|
||||||
extra_compile_args=extra_compile_args,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
return ext_modules
|
|
||||||
|
|
||||||
setup(
|
|
||||||
name="MultiScaleDeformableAttention",
|
|
||||||
version="1.0",
|
|
||||||
author="Weijie Su",
|
|
||||||
url="https://github.com/fundamentalvision/Deformable-DETR",
|
|
||||||
description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
|
|
||||||
packages=find_packages(exclude=("configs", "tests",)),
|
|
||||||
ext_modules=get_extensions(),
|
|
||||||
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
|
|
||||||
)
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
/*!
|
|
||||||
**************************************************************************************************
|
|
||||||
* Deformable DETR
|
|
||||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
**************************************************************************************************
|
|
||||||
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
**************************************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
|
|
||||||
|
|
||||||
at::Tensor
|
|
||||||
ms_deform_attn_cpu_forward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const int im2col_step)
|
|
||||||
{
|
|
||||||
AT_ERROR("Not implement on cpu");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<at::Tensor>
|
|
||||||
ms_deform_attn_cpu_backward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const at::Tensor &grad_output,
|
|
||||||
const int im2col_step)
|
|
||||||
{
|
|
||||||
AT_ERROR("Not implement on cpu");
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
/*!
|
|
||||||
**************************************************************************************************
|
|
||||||
* Deformable DETR
|
|
||||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
**************************************************************************************************
|
|
||||||
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
**************************************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
at::Tensor
|
|
||||||
ms_deform_attn_cpu_forward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const int im2col_step);
|
|
||||||
|
|
||||||
std::vector<at::Tensor>
|
|
||||||
ms_deform_attn_cpu_backward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const at::Tensor &grad_output,
|
|
||||||
const int im2col_step);
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
/*!
|
|
||||||
**************************************************************************************************
|
|
||||||
* Deformable DETR
|
|
||||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
**************************************************************************************************
|
|
||||||
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
**************************************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
#include "cuda/ms_deform_im2col_cuda.cuh"
|
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <cuda.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
|
|
||||||
|
|
||||||
at::Tensor ms_deform_attn_cuda_forward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const int im2col_step)
|
|
||||||
{
|
|
||||||
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
|
||||||
|
|
||||||
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
|
||||||
|
|
||||||
const int batch = value.size(0);
|
|
||||||
const int spatial_size = value.size(1);
|
|
||||||
const int num_heads = value.size(2);
|
|
||||||
const int channels = value.size(3);
|
|
||||||
|
|
||||||
const int num_levels = spatial_shapes.size(0);
|
|
||||||
|
|
||||||
const int num_query = sampling_loc.size(1);
|
|
||||||
const int num_point = sampling_loc.size(4);
|
|
||||||
|
|
||||||
const int im2col_step_ = std::min(batch, im2col_step);
|
|
||||||
|
|
||||||
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
|
||||||
|
|
||||||
auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
|
|
||||||
|
|
||||||
const int batch_n = im2col_step_;
|
|
||||||
auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
|
||||||
auto per_value_size = spatial_size * num_heads * channels;
|
|
||||||
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
|
||||||
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
|
||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
|
||||||
{
|
|
||||||
auto columns = output_n.select(0, n);
|
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
|
|
||||||
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
|
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
|
||||||
spatial_shapes.data<int64_t>(),
|
|
||||||
level_start_index.data<int64_t>(),
|
|
||||||
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
|
||||||
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
|
||||||
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
|
||||||
columns.data<scalar_t>());
|
|
||||||
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
output = output.view({batch, num_query, num_heads*channels});
|
|
||||||
|
|
||||||
return output;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const at::Tensor &grad_output,
|
|
||||||
const int im2col_step)
|
|
||||||
{
|
|
||||||
|
|
||||||
AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
|
|
||||||
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
|
|
||||||
|
|
||||||
AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
|
|
||||||
AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
|
|
||||||
|
|
||||||
const int batch = value.size(0);
|
|
||||||
const int spatial_size = value.size(1);
|
|
||||||
const int num_heads = value.size(2);
|
|
||||||
const int channels = value.size(3);
|
|
||||||
|
|
||||||
const int num_levels = spatial_shapes.size(0);
|
|
||||||
|
|
||||||
const int num_query = sampling_loc.size(1);
|
|
||||||
const int num_point = sampling_loc.size(4);
|
|
||||||
|
|
||||||
const int im2col_step_ = std::min(batch, im2col_step);
|
|
||||||
|
|
||||||
AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
|
|
||||||
|
|
||||||
auto grad_value = at::zeros_like(value);
|
|
||||||
auto grad_sampling_loc = at::zeros_like(sampling_loc);
|
|
||||||
auto grad_attn_weight = at::zeros_like(attn_weight);
|
|
||||||
|
|
||||||
const int batch_n = im2col_step_;
|
|
||||||
auto per_value_size = spatial_size * num_heads * channels;
|
|
||||||
auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
|
|
||||||
auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
|
|
||||||
auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
|
|
||||||
|
|
||||||
for (int n = 0; n < batch/im2col_step_; ++n)
|
|
||||||
{
|
|
||||||
auto grad_output_g = grad_output_n.select(0, n);
|
|
||||||
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
|
|
||||||
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
|
|
||||||
grad_output_g.data<scalar_t>(),
|
|
||||||
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
|
||||||
spatial_shapes.data<int64_t>(),
|
|
||||||
level_start_index.data<int64_t>(),
|
|
||||||
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
|
||||||
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
|
|
||||||
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
|
|
||||||
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
|
|
||||||
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
|
|
||||||
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
|
|
||||||
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
|
||||||
grad_value, grad_sampling_loc, grad_attn_weight
|
|
||||||
};
|
|
||||||
}
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
/*!
|
|
||||||
**************************************************************************************************
|
|
||||||
* Deformable DETR
|
|
||||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
**************************************************************************************************
|
|
||||||
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
**************************************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
at::Tensor ms_deform_attn_cuda_forward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const int im2col_step);
|
|
||||||
|
|
||||||
std::vector<at::Tensor> ms_deform_attn_cuda_backward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const at::Tensor &grad_output,
|
|
||||||
const int im2col_step);
|
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,67 +0,0 @@
|
|||||||
/*!
|
|
||||||
**************************************************************************************************
|
|
||||||
* Deformable DETR
|
|
||||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
**************************************************************************************************
|
|
||||||
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
**************************************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
*/
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "cpu/ms_deform_attn_cpu.h"
|
|
||||||
|
|
||||||
#ifdef WITH_CUDA
|
|
||||||
#include "cuda/ms_deform_attn_cuda.h"
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
|
||||||
at::Tensor
|
|
||||||
ms_deform_attn_forward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const int im2col_step)
|
|
||||||
{
|
|
||||||
if (value.type().is_cuda())
|
|
||||||
{
|
|
||||||
#ifdef WITH_CUDA
|
|
||||||
return ms_deform_attn_cuda_forward(
|
|
||||||
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
|
|
||||||
#else
|
|
||||||
AT_ERROR("Not compiled with GPU support");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
AT_ERROR("Not implemented on the CPU");
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<at::Tensor>
|
|
||||||
ms_deform_attn_backward(
|
|
||||||
const at::Tensor &value,
|
|
||||||
const at::Tensor &spatial_shapes,
|
|
||||||
const at::Tensor &level_start_index,
|
|
||||||
const at::Tensor &sampling_loc,
|
|
||||||
const at::Tensor &attn_weight,
|
|
||||||
const at::Tensor &grad_output,
|
|
||||||
const int im2col_step)
|
|
||||||
{
|
|
||||||
if (value.type().is_cuda())
|
|
||||||
{
|
|
||||||
#ifdef WITH_CUDA
|
|
||||||
return ms_deform_attn_cuda_backward(
|
|
||||||
value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
|
|
||||||
#else
|
|
||||||
AT_ERROR("Not compiled with GPU support");
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
AT_ERROR("Not implemented on the CPU");
|
|
||||||
}
|
|
||||||
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
/*!
|
|
||||||
**************************************************************************************************
|
|
||||||
* Deformable DETR
|
|
||||||
* Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
**************************************************************************************************
|
|
||||||
* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
**************************************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
* Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
*/
|
|
||||||
|
|
||||||
#include "ms_deform_attn.h"
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
|
|
||||||
m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
|
|
||||||
}
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Deformable DETR
|
|
||||||
# Copyright (c) 2020 SenseTime. All Rights Reserved.
|
|
||||||
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
||||||
# Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import print_function
|
|
||||||
from __future__ import division
|
|
||||||
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.autograd import gradcheck
|
|
||||||
|
|
||||||
from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
|
|
||||||
|
|
||||||
|
|
||||||
N, M, D = 1, 2, 2
|
|
||||||
Lq, L, P = 2, 2, 2
|
|
||||||
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
|
|
||||||
level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
|
|
||||||
S = sum([(H*W).item() for H, W in shapes])
|
|
||||||
|
|
||||||
|
|
||||||
torch.manual_seed(3)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def check_forward_equal_with_pytorch_double():
|
|
||||||
value = torch.rand(N, S, M, D).cuda() * 0.01
|
|
||||||
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
|
||||||
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
|
||||||
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
|
||||||
im2col_step = 2
|
|
||||||
output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
|
|
||||||
output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
|
|
||||||
fwdok = torch.allclose(output_cuda, output_pytorch)
|
|
||||||
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
|
||||||
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
|
||||||
|
|
||||||
print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def check_forward_equal_with_pytorch_float():
|
|
||||||
value = torch.rand(N, S, M, D).cuda() * 0.01
|
|
||||||
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
|
||||||
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
|
||||||
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
|
||||||
im2col_step = 2
|
|
||||||
output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
|
|
||||||
output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
|
|
||||||
fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
|
|
||||||
max_abs_err = (output_cuda - output_pytorch).abs().max()
|
|
||||||
max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
|
|
||||||
|
|
||||||
print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
|
|
||||||
|
|
||||||
|
|
||||||
def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
|
|
||||||
|
|
||||||
value = torch.rand(N, S, M, channels).cuda() * 0.01
|
|
||||||
sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
|
|
||||||
attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
|
|
||||||
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
|
|
||||||
im2col_step = 2
|
|
||||||
func = MSDeformAttnFunction.apply
|
|
||||||
|
|
||||||
value.requires_grad = grad_value
|
|
||||||
sampling_locations.requires_grad = grad_sampling_loc
|
|
||||||
attention_weights.requires_grad = grad_attn_weight
|
|
||||||
|
|
||||||
gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
|
|
||||||
|
|
||||||
print(f'* {gradok} check_gradient_numerical(D={channels})')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
check_forward_equal_with_pytorch_double()
|
|
||||||
check_forward_equal_with_pytorch_float()
|
|
||||||
|
|
||||||
for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
|
|
||||||
check_gradient_numerical(channels, True, True, True)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
from .inference_sam_m2m_auto import *
|
|
||||||
from .inference_sam_m2m_interactive import *
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import cv2
|
|
||||||
import io
|
|
||||||
from segment_anything import SamAutomaticMaskGenerator
|
|
||||||
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
|
|
||||||
def inference_sam_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image)
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
|
|
||||||
mask_generator = SamAutomaticMaskGenerator(model)
|
|
||||||
outputs = mask_generator.generate(image_ori)
|
|
||||||
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
|
||||||
label = 1
|
|
||||||
# for ann in sorted_anns:
|
|
||||||
# mask = ann['segmentation']
|
|
||||||
# color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# # color_mask = [int(c*255) for c in color_mask]
|
|
||||||
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# label += 1
|
|
||||||
# im = demo.get_image()
|
|
||||||
|
|
||||||
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
|
|
||||||
for i, ann in enumerate(sorted_anns):
|
|
||||||
mask = ann['segmentation']
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# color_mask = [int(c*255) for c in color_mask]
|
|
||||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# assign the mask to the mask_map
|
|
||||||
mask_map[mask == 1] = label
|
|
||||||
label += 1
|
|
||||||
im = demo.get_image()
|
|
||||||
# fig=plt.figure(figsize=(10, 10))
|
|
||||||
# plt.imshow(image_ori)
|
|
||||||
# show_anns(outputs)
|
|
||||||
# fig.canvas.draw()
|
|
||||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
|
||||||
return im, sorted_anns
|
|
||||||
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
|
|
||||||
def show_anns(anns):
|
|
||||||
if len(anns) == 0:
|
|
||||||
return
|
|
||||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
|
||||||
ax = plt.gca()
|
|
||||||
ax.set_autoscale_on(False)
|
|
||||||
polygons = []
|
|
||||||
color = []
|
|
||||||
for ann in sorted_anns:
|
|
||||||
m = ann['segmentation']
|
|
||||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
for i in range(3):
|
|
||||||
img[:,:,i] = color_mask[i]
|
|
||||||
ax.imshow(np.dstack((img, m*0.35)))
|
|
||||||
@@ -1,221 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
from kornia.contrib import distance_transform
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import cv2
|
|
||||||
import io
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
from segment_anything import SamAutomaticMaskGenerator
|
|
||||||
from segment_anything.utils.amg import (
|
|
||||||
MaskData,
|
|
||||||
area_from_rle,
|
|
||||||
batch_iterator,
|
|
||||||
batched_mask_to_box,
|
|
||||||
box_xyxy_to_xywh,
|
|
||||||
build_all_layer_point_grids,
|
|
||||||
calculate_stability_score,
|
|
||||||
coco_encode_rle,
|
|
||||||
generate_crop_boxes,
|
|
||||||
is_box_near_crop_edge,
|
|
||||||
mask_to_rle_pytorch,
|
|
||||||
remove_small_regions,
|
|
||||||
rle_to_mask,
|
|
||||||
uncrop_boxes_xyxy,
|
|
||||||
uncrop_masks,
|
|
||||||
uncrop_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
def sam_interactive_mask(mask_generator, points, in_points, in_labels, mask_input):
|
|
||||||
masks, iou_preds, _ = mask_generator.predictor.predict_torch(
|
|
||||||
in_points,
|
|
||||||
in_labels,
|
|
||||||
mask_input=mask_input,
|
|
||||||
multimask_output=True,
|
|
||||||
return_logits=True,
|
|
||||||
)
|
|
||||||
nm,_,h,w = masks.shape
|
|
||||||
|
|
||||||
# Serialize predictions and store in MaskData
|
|
||||||
data = MaskData(
|
|
||||||
masks=masks.flatten(0, 1),
|
|
||||||
iou_preds=iou_preds.flatten(0, 1),
|
|
||||||
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
|
||||||
)
|
|
||||||
del masks
|
|
||||||
|
|
||||||
# Calculate stability score
|
|
||||||
data["stability_score"] = calculate_stability_score(
|
|
||||||
data["masks"], mask_generator.predictor.model.mask_threshold, mask_generator.stability_score_offset
|
|
||||||
)
|
|
||||||
|
|
||||||
masks = data["masks"].reshape(nm, -1, h, w)
|
|
||||||
scores = (data['iou_preds'] + data['stability_score']).reshape(nm, -1)
|
|
||||||
|
|
||||||
index = torch.stack([torch.arange(nm).cuda(), scores.argmax(dim=1)]).tolist()
|
|
||||||
return masks[index]
|
|
||||||
|
|
||||||
def inference_sam_m2m_interactive(model, image, spatial_masks, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image)
|
|
||||||
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
|
|
||||||
orig_size = images.shape[-2:]
|
|
||||||
orig_h, orig_w = orig_size
|
|
||||||
crop_box = [0,0,orig_w,orig_h]
|
|
||||||
|
|
||||||
spatial_masks = spatial_masks[:, None].float().cuda()
|
|
||||||
spatial_masks = F.interpolate(spatial_masks, size=(orig_h, orig_w), mode='bicubic', align_corners=False) > 0
|
|
||||||
|
|
||||||
# generate single center point
|
|
||||||
# n,_,h,w = spatial_masks.shape
|
|
||||||
# mask_dt = (distance_transform((~F.pad(spatial_masks, pad=(1, 1, 1, 1), mode='constant', value=0)).float())[:,:,1:-1,1:-1]).reshape(n,-1)
|
|
||||||
# max_xy_idx = torch.stack([torch.arange(n), mask_dt.max(dim=-1)[1].cpu()]).tolist()
|
|
||||||
# next_mask = torch.zeros(spatial_masks.shape, device=torch.cuda.current_device()).bool()
|
|
||||||
# next_mask = next_mask.view(n,-1)
|
|
||||||
# next_mask[max_xy_idx] = True
|
|
||||||
# next_mask = next_mask.reshape((n,1,h,w))
|
|
||||||
# points = next_mask.nonzero()[:,2:].flip(dims=[1]).cpu().numpy()
|
|
||||||
|
|
||||||
# stack sampled points
|
|
||||||
acc_points = []
|
|
||||||
for i in range(len(spatial_masks)):
|
|
||||||
points = spatial_masks[i:i+1].nonzero()[:,2:].flip(dims=[1]).cpu().numpy()
|
|
||||||
rand_ids = np.random.choice(points.shape[0], size=40, replace=True)
|
|
||||||
points = points[rand_ids]
|
|
||||||
acc_points.append(points)
|
|
||||||
_np = len(acc_points)
|
|
||||||
points = np.concatenate(acc_points)
|
|
||||||
|
|
||||||
mask_generator = SamAutomaticMaskGenerator(model)
|
|
||||||
mask_generator.predictor.set_image(image_ori)
|
|
||||||
im_size = image_ori.shape[:-1]
|
|
||||||
|
|
||||||
transformed_points = mask_generator.predictor.transform.apply_coords(points, im_size)
|
|
||||||
in_points = torch.as_tensor(transformed_points, device=mask_generator.predictor.device).reshape(_np,-1,2).transpose(0,1)
|
|
||||||
in_labels = torch.ones((in_points.shape[0], _np), dtype=torch.int, device=mask_generator.predictor.device)
|
|
||||||
|
|
||||||
masks = sam_interactive_mask(mask_generator, points, in_points.transpose(0,1), in_labels.transpose(0,1), None)
|
|
||||||
|
|
||||||
masks = masks > 0.0
|
|
||||||
iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
|
|
||||||
points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)
|
|
||||||
|
|
||||||
mask_data = MaskData(
|
|
||||||
masks=masks,
|
|
||||||
iou_preds=iou_preds,
|
|
||||||
points=points,
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
|
|
||||||
del masks
|
|
||||||
|
|
||||||
mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
|
|
||||||
mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])
|
|
||||||
|
|
||||||
# Compress to RLE
|
|
||||||
mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
|
|
||||||
mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
|
|
||||||
del mask_data["masks"]
|
|
||||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
|
||||||
|
|
||||||
# Write mask records
|
|
||||||
outputs = []
|
|
||||||
for idx in range(len(mask_data["segmentations"])):
|
|
||||||
ann = {
|
|
||||||
"segmentation": mask_data["segmentations"][idx],
|
|
||||||
"area": area_from_rle(mask_data["rles"][idx]),
|
|
||||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
|
||||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
|
||||||
"point_coords": [mask_data["points"][idx].tolist()],
|
|
||||||
"stability_score": mask_data["stability_score"][idx].item(),
|
|
||||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
|
||||||
}
|
|
||||||
outputs.append(ann)
|
|
||||||
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
|
||||||
label = 1
|
|
||||||
# for ann in sorted_anns:
|
|
||||||
# mask = ann['segmentation']
|
|
||||||
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# label += 1
|
|
||||||
# im = demo.get_image()
|
|
||||||
|
|
||||||
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
|
|
||||||
for i, ann in enumerate(sorted_anns):
|
|
||||||
mask = ann['segmentation']
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# color_mask = [int(c*255) for c in color_mask]
|
|
||||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# assign the mask to the mask_map
|
|
||||||
mask_map[mask == 1] = label
|
|
||||||
label += 1
|
|
||||||
im = demo.get_image()
|
|
||||||
# fig=plt.figure(figsize=(10, 10))
|
|
||||||
# plt.imshow(image_ori)
|
|
||||||
# show_anns(outputs)
|
|
||||||
# fig.canvas.draw()
|
|
||||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
|
||||||
return im, sorted_anns
|
|
||||||
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
|
|
||||||
def show_anns(anns):
|
|
||||||
if len(anns) == 0:
|
|
||||||
return
|
|
||||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
|
||||||
ax = plt.gca()
|
|
||||||
ax.set_autoscale_on(False)
|
|
||||||
polygons = []
|
|
||||||
color = []
|
|
||||||
for ann in sorted_anns:
|
|
||||||
m = ann['segmentation']
|
|
||||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
for i in range(3):
|
|
||||||
img[:,:,i] = color_mask[i]
|
|
||||||
ax.imshow(np.dstack((img, m*0.35)))
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
from .interactive_seem_m2m_auto import *
|
|
||||||
from .inference_seem_pano import *
|
|
||||||
from .inference_seem_interactive import *
|
|
||||||
@@ -1,382 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from segment_anything.modeling import Sam
|
|
||||||
from segment_anything.utils.amg import (
|
|
||||||
MaskData,
|
|
||||||
area_from_rle,
|
|
||||||
batch_iterator,
|
|
||||||
batched_mask_to_box,
|
|
||||||
box_xyxy_to_xywh,
|
|
||||||
build_all_layer_point_grids,
|
|
||||||
calculate_stability_score,
|
|
||||||
coco_encode_rle,
|
|
||||||
generate_crop_boxes,
|
|
||||||
is_box_near_crop_edge,
|
|
||||||
mask_to_rle_pytorch,
|
|
||||||
remove_small_regions,
|
|
||||||
rle_to_mask,
|
|
||||||
uncrop_boxes_xyxy,
|
|
||||||
uncrop_masks,
|
|
||||||
uncrop_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SeemAutomaticMaskGenerator:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: Sam,
|
|
||||||
points_per_side: Optional[int] = 32,
|
|
||||||
points_per_batch: int = 64,
|
|
||||||
pred_iou_thresh: float = 0.9,
|
|
||||||
stability_score_thresh: float = 0.5,
|
|
||||||
stability_score_offset: float = 1.0,
|
|
||||||
box_nms_thresh: float = 0.7,
|
|
||||||
crop_n_layers: int = 0,
|
|
||||||
crop_nms_thresh: float = 0.7,
|
|
||||||
crop_overlap_ratio: float = 512 / 1500,
|
|
||||||
crop_n_points_downscale_factor: int = 1,
|
|
||||||
point_grids: Optional[List[np.ndarray]] = None,
|
|
||||||
min_mask_region_area: int = 0,
|
|
||||||
output_mode: str = "binary_mask",
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Using a SAM model, generates masks for the entire image.
|
|
||||||
Generates a grid of point prompts over the image, then filters
|
|
||||||
low quality and duplicate masks. The default settings are chosen
|
|
||||||
for SAM with a ViT-H backbone.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
model (Sam): The SAM model to use for mask prediction.
|
|
||||||
points_per_side (int or None): The number of points to be sampled
|
|
||||||
along one side of the image. The total number of points is
|
|
||||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
|
||||||
point sampling.
|
|
||||||
points_per_batch (int): Sets the number of points run simultaneously
|
|
||||||
by the model. Higher numbers may be faster but use more GPU memory.
|
|
||||||
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
|
||||||
model's predicted mask quality.
|
|
||||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
|
||||||
the stability of the mask under changes to the cutoff used to binarize
|
|
||||||
the model's mask predictions.
|
|
||||||
stability_score_offset (float): The amount to shift the cutoff when
|
|
||||||
calculated the stability score.
|
|
||||||
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
|
||||||
suppression to filter duplicate masks.
|
|
||||||
crop_n_layers (int): If >0, mask prediction will be run again on
|
|
||||||
crops of the image. Sets the number of layers to run, where each
|
|
||||||
layer has 2**i_layer number of image crops.
|
|
||||||
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
|
||||||
suppression to filter duplicate masks between different crops.
|
|
||||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
|
||||||
In the first crop layer, crops will overlap by this fraction of
|
|
||||||
the image length. Later layers with more crops scale down this overlap.
|
|
||||||
crop_n_points_downscale_factor (int): The number of points-per-side
|
|
||||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
|
||||||
point_grids (list(np.ndarray) or None): A list over explicit grids
|
|
||||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
|
||||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
|
||||||
min_mask_region_area (int): If >0, postprocessing will be applied
|
|
||||||
to remove disconnected regions and holes in masks with area smaller
|
|
||||||
than min_mask_region_area. Requires opencv.
|
|
||||||
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
|
||||||
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
|
||||||
For large resolutions, 'binary_mask' may consume large amounts of
|
|
||||||
memory.
|
|
||||||
"""
|
|
||||||
|
|
||||||
assert (points_per_side is None) != (
|
|
||||||
point_grids is None
|
|
||||||
), "Exactly one of points_per_side or point_grid must be provided."
|
|
||||||
if points_per_side is not None:
|
|
||||||
self.point_grids = build_all_layer_point_grids(
|
|
||||||
points_per_side,
|
|
||||||
crop_n_layers,
|
|
||||||
crop_n_points_downscale_factor,
|
|
||||||
)
|
|
||||||
elif point_grids is not None:
|
|
||||||
self.point_grids = point_grids
|
|
||||||
else:
|
|
||||||
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
|
||||||
|
|
||||||
assert output_mode in [
|
|
||||||
"binary_mask",
|
|
||||||
"uncompressed_rle",
|
|
||||||
"coco_rle",
|
|
||||||
], f"Unknown output_mode {output_mode}."
|
|
||||||
if output_mode == "coco_rle":
|
|
||||||
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
|
||||||
|
|
||||||
if min_mask_region_area > 0:
|
|
||||||
import cv2 # type: ignore # noqa: F401
|
|
||||||
|
|
||||||
self.predictor = model
|
|
||||||
self.points_per_batch = points_per_batch
|
|
||||||
self.pred_iou_thresh = pred_iou_thresh
|
|
||||||
self.stability_score_thresh = stability_score_thresh
|
|
||||||
self.stability_score_offset = stability_score_offset
|
|
||||||
self.box_nms_thresh = box_nms_thresh
|
|
||||||
self.crop_n_layers = crop_n_layers
|
|
||||||
self.crop_nms_thresh = crop_nms_thresh
|
|
||||||
self.crop_overlap_ratio = crop_overlap_ratio
|
|
||||||
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
|
||||||
self.min_mask_region_area = min_mask_region_area
|
|
||||||
self.output_mode = output_mode
|
|
||||||
|
|
||||||
# dilate conv
|
|
||||||
self.dilation = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3, bias=False)
|
|
||||||
self.dilation.weight.data.fill_(1.0)
|
|
||||||
self.dilation.cuda()
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Generates masks for the given image.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list(dict(str, any)): A list over records for masks. Each record is
|
|
||||||
a dict containing the following keys:
|
|
||||||
segmentation (dict(str, any) or np.ndarray): The mask. If
|
|
||||||
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
|
||||||
is a dictionary containing the RLE.
|
|
||||||
bbox (list(float)): The box around the mask, in XYWH format.
|
|
||||||
area (int): The area in pixels of the mask.
|
|
||||||
predicted_iou (float): The model's own prediction of the mask's
|
|
||||||
quality. This is filtered by the pred_iou_thresh parameter.
|
|
||||||
point_coords (list(list(float))): The point coordinates input
|
|
||||||
to the model to generate this mask.
|
|
||||||
stability_score (float): A measure of the mask's quality. This
|
|
||||||
is filtered on using the stability_score_thresh parameter.
|
|
||||||
crop_box (list(float)): The crop of the image used to generate
|
|
||||||
the mask, given in XYWH format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Generate masks
|
|
||||||
mask_data = self._generate_masks(image)
|
|
||||||
|
|
||||||
# Filter small disconnected regions and holes in masks
|
|
||||||
if self.min_mask_region_area > 0:
|
|
||||||
mask_data = self.postprocess_small_regions(
|
|
||||||
mask_data,
|
|
||||||
self.min_mask_region_area,
|
|
||||||
max(self.box_nms_thresh, self.crop_nms_thresh),
|
|
||||||
)
|
|
||||||
# Encode masks
|
|
||||||
if self.output_mode == "coco_rle":
|
|
||||||
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
|
||||||
elif self.output_mode == "binary_mask":
|
|
||||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
|
||||||
else:
|
|
||||||
mask_data["segmentations"] = mask_data["rles"]
|
|
||||||
|
|
||||||
# Write mask records
|
|
||||||
curr_anns = []
|
|
||||||
for idx in range(len(mask_data["segmentations"])):
|
|
||||||
ann = {
|
|
||||||
"segmentation": mask_data["segmentations"][idx],
|
|
||||||
"area": area_from_rle(mask_data["rles"][idx]),
|
|
||||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
|
||||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
|
||||||
"point_coords": [mask_data["points"][idx].tolist()],
|
|
||||||
"stability_score": mask_data["stability_score"][idx].item(),
|
|
||||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
|
||||||
}
|
|
||||||
curr_anns.append(ann)
|
|
||||||
|
|
||||||
return curr_anns
|
|
||||||
|
|
||||||
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
|
||||||
orig_size = image.shape[-2:]
|
|
||||||
crop_boxes, layer_idxs = generate_crop_boxes(
|
|
||||||
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
|
||||||
)
|
|
||||||
|
|
||||||
# Iterate over image crops
|
|
||||||
data = MaskData()
|
|
||||||
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
|
||||||
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
|
||||||
data.cat(crop_data)
|
|
||||||
|
|
||||||
# Remove duplicate masks between crops
|
|
||||||
if len(crop_boxes) > 1:
|
|
||||||
# Prefer masks from smaller crops
|
|
||||||
scores = 1 / box_area(data["crop_boxes"])
|
|
||||||
scores = scores.to(data["boxes"].device)
|
|
||||||
keep_by_nms = batched_nms(
|
|
||||||
data["boxes"].float(),
|
|
||||||
scores,
|
|
||||||
torch.zeros_like(data["boxes"][:, 0]), # categories
|
|
||||||
iou_threshold=self.crop_nms_thresh,
|
|
||||||
)
|
|
||||||
data.filter(keep_by_nms)
|
|
||||||
|
|
||||||
data.to_numpy()
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _process_crop(
|
|
||||||
self,
|
|
||||||
image: np.ndarray,
|
|
||||||
crop_box: List[int],
|
|
||||||
crop_layer_idx: int,
|
|
||||||
orig_size: Tuple[int, ...],
|
|
||||||
) -> MaskData:
|
|
||||||
# Crop the image and calculate embeddings
|
|
||||||
x0, y0, x1, y1 = crop_box
|
|
||||||
cropped_im = image#[y0:y1, x0:x1, :]
|
|
||||||
cropped_im_size = cropped_im.shape[-2:]
|
|
||||||
# self.predictor.set_image(cropped_im)
|
|
||||||
|
|
||||||
# Get points for this crop
|
|
||||||
points_scale = np.array(cropped_im_size)[None, ::-1]
|
|
||||||
points_for_image = self.point_grids[crop_layer_idx] #* points_scale
|
|
||||||
|
|
||||||
# Generate masks for this crop in batches
|
|
||||||
data = MaskData()
|
|
||||||
self.enc_features=None
|
|
||||||
|
|
||||||
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
|
||||||
batch_data = self._process_batch(cropped_im, points, cropped_im_size, crop_box, orig_size)
|
|
||||||
data.cat(batch_data)
|
|
||||||
del batch_data
|
|
||||||
|
|
||||||
# Remove duplicates within this crop.
|
|
||||||
keep_by_nms = batched_nms(
|
|
||||||
data["boxes"].float(),
|
|
||||||
data["iou_preds"],
|
|
||||||
torch.zeros(len(data["boxes"])), # categories
|
|
||||||
iou_threshold=self.box_nms_thresh,
|
|
||||||
)
|
|
||||||
|
|
||||||
data.filter(keep_by_nms)
|
|
||||||
|
|
||||||
# Return to the original image frame
|
|
||||||
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
|
||||||
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _process_batch(
|
|
||||||
self,
|
|
||||||
images,
|
|
||||||
points: np.ndarray,
|
|
||||||
im_size: Tuple[int, ...],
|
|
||||||
crop_box: List[int],
|
|
||||||
orig_size: Tuple[int, ...],
|
|
||||||
) -> MaskData:
|
|
||||||
orig_h, orig_w = orig_size
|
|
||||||
|
|
||||||
data = {"image": images, "height": orig_h, "width": orig_w}
|
|
||||||
points = torch.tensor(points,dtype=torch.float).to(images.device)
|
|
||||||
|
|
||||||
# prepare interactive mask for seem
|
|
||||||
abs_points = (points * torch.tensor(orig_size)[None,:].to(points.device)).long()
|
|
||||||
abs_masks = torch.zeros((len(points), orig_h, orig_w), dtype=torch.bool).to(device=points.device)
|
|
||||||
abs_masks[torch.arange(0, abs_points.size(0))[:,None], abs_points[:,0:1], abs_points[:,1:2]] = True
|
|
||||||
abs_masks = self.dilation(abs_masks[:,None].float())[:,0] > 0
|
|
||||||
data['spatial_query'] = {'rand_shape': abs_masks[:,None]}
|
|
||||||
|
|
||||||
batch_inputs = [data]
|
|
||||||
if self.enc_features is None:
|
|
||||||
masks, iou_preds, mask_features, transformer_encoder_features, multi_scale_features = self.predictor.model.evaluate_demo(batch_inputs, None, None, return_features=True)
|
|
||||||
self.enc_features = (mask_features, transformer_encoder_features, multi_scale_features)
|
|
||||||
else:
|
|
||||||
masks, iou_preds = self.predictor.model.evaluate_demo(batch_inputs, self.enc_features[0], self.enc_features[1], self.enc_features[2])
|
|
||||||
|
|
||||||
data = MaskData(
|
|
||||||
masks=masks,
|
|
||||||
iou_preds=iou_preds,
|
|
||||||
points=points,
|
|
||||||
)
|
|
||||||
del masks
|
|
||||||
# Filter by predicted IoU
|
|
||||||
if self.pred_iou_thresh > 0.0:
|
|
||||||
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
|
||||||
data.filter(keep_mask)
|
|
||||||
|
|
||||||
# Calculate stability score
|
|
||||||
data["stability_score"] = calculate_stability_score(
|
|
||||||
data["masks"], 0.0, self.stability_score_offset
|
|
||||||
)
|
|
||||||
if self.stability_score_thresh > 0.0:
|
|
||||||
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
|
||||||
data.filter(keep_mask)
|
|
||||||
|
|
||||||
# Threshold masks and calculate boxes
|
|
||||||
data["masks"] = data["masks"] > 0.0
|
|
||||||
data["boxes"] = batched_mask_to_box(data["masks"])
|
|
||||||
|
|
||||||
# Filter boxes that touch crop boundaries
|
|
||||||
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
|
||||||
if not torch.all(keep_mask):
|
|
||||||
data.filter(keep_mask)
|
|
||||||
|
|
||||||
# Compress to RLE
|
|
||||||
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
|
||||||
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
|
||||||
del data["masks"]
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def postprocess_small_regions(
|
|
||||||
mask_data: MaskData, min_area: int, nms_thresh: float
|
|
||||||
) -> MaskData:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in masks, then reruns
|
|
||||||
box NMS to remove any new duplicates.
|
|
||||||
|
|
||||||
Edits mask_data in place.
|
|
||||||
|
|
||||||
Requires open-cv as a dependency.
|
|
||||||
"""
|
|
||||||
if len(mask_data["rles"]) == 0:
|
|
||||||
return mask_data
|
|
||||||
|
|
||||||
# Filter small disconnected regions and holes
|
|
||||||
new_masks = []
|
|
||||||
scores = []
|
|
||||||
for rle in mask_data["rles"]:
|
|
||||||
mask = rle_to_mask(rle)
|
|
||||||
|
|
||||||
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
|
||||||
unchanged = not changed
|
|
||||||
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
|
||||||
unchanged = unchanged and not changed
|
|
||||||
|
|
||||||
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
|
||||||
# Give score=0 to changed masks and score=1 to unchanged masks
|
|
||||||
# so NMS will prefer ones that didn't need postprocessing
|
|
||||||
scores.append(float(unchanged))
|
|
||||||
|
|
||||||
# Recalculate boxes and remove any new duplicates
|
|
||||||
masks = torch.cat(new_masks, dim=0)
|
|
||||||
boxes = batched_mask_to_box(masks)
|
|
||||||
keep_by_nms = batched_nms(
|
|
||||||
boxes.float(),
|
|
||||||
torch.as_tensor(scores),
|
|
||||||
torch.zeros_like(boxes[:, 0]), # categories
|
|
||||||
iou_threshold=nms_thresh,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only recalculate RLEs for masks that have changed
|
|
||||||
for i_mask in keep_by_nms:
|
|
||||||
if scores[i_mask] == 0.0:
|
|
||||||
mask_torch = masks[i_mask].unsqueeze(0)
|
|
||||||
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
|
||||||
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
|
||||||
mask_data.filter(keep_by_nms)
|
|
||||||
|
|
||||||
return mask_data
|
|
||||||
@@ -1,169 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import cv2
|
|
||||||
import io
|
|
||||||
from .automatic_mask_generator import SeemAutomaticMaskGenerator
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
from segment_anything.utils.amg import (
|
|
||||||
MaskData,
|
|
||||||
area_from_rle,
|
|
||||||
batch_iterator,
|
|
||||||
batched_mask_to_box,
|
|
||||||
box_xyxy_to_xywh,
|
|
||||||
build_all_layer_point_grids,
|
|
||||||
calculate_stability_score,
|
|
||||||
coco_encode_rle,
|
|
||||||
generate_crop_boxes,
|
|
||||||
is_box_near_crop_edge,
|
|
||||||
mask_to_rle_pytorch,
|
|
||||||
remove_small_regions,
|
|
||||||
rle_to_mask,
|
|
||||||
uncrop_boxes_xyxy,
|
|
||||||
uncrop_masks,
|
|
||||||
uncrop_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def inference_seem_interactive(model, image, spatial_masks, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image)
|
|
||||||
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
|
|
||||||
orig_size = images.shape[-2:]
|
|
||||||
orig_h, orig_w = orig_size
|
|
||||||
crop_box = [0,0,orig_w,orig_h]
|
|
||||||
|
|
||||||
data = {"image": images, "height": orig_h, "width": orig_w}
|
|
||||||
|
|
||||||
spatial_masks = spatial_masks[:, None].float().cuda()
|
|
||||||
spatial_masks = F.interpolate(spatial_masks, size=(orig_h, orig_w), mode='bicubic', align_corners=False) > 0
|
|
||||||
data['spatial_query'] = {'rand_shape': spatial_masks}
|
|
||||||
|
|
||||||
model.model.metadata = metadata
|
|
||||||
masks, _ = model.model.evaluate_demo([data])
|
|
||||||
masks = masks > 0.0
|
|
||||||
iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
|
|
||||||
points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)
|
|
||||||
|
|
||||||
mask_data = MaskData(
|
|
||||||
masks=masks,
|
|
||||||
iou_preds=iou_preds,
|
|
||||||
points=points,
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
|
|
||||||
del masks
|
|
||||||
|
|
||||||
mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
|
|
||||||
mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])
|
|
||||||
|
|
||||||
# Compress to RLE
|
|
||||||
mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
|
|
||||||
mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
|
|
||||||
del mask_data["masks"]
|
|
||||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
|
||||||
|
|
||||||
# Write mask records
|
|
||||||
outputs = []
|
|
||||||
for idx in range(len(mask_data["segmentations"])):
|
|
||||||
ann = {
|
|
||||||
"segmentation": mask_data["segmentations"][idx],
|
|
||||||
"area": area_from_rle(mask_data["rles"][idx]),
|
|
||||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
|
||||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
|
||||||
"point_coords": [mask_data["points"][idx].tolist()],
|
|
||||||
"stability_score": mask_data["stability_score"][idx].item(),
|
|
||||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
|
||||||
}
|
|
||||||
outputs.append(ann)
|
|
||||||
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
|
||||||
label = 1
|
|
||||||
# for ann in sorted_anns:
|
|
||||||
# mask = ann['segmentation']
|
|
||||||
# color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# # color_mask = [int(c*255) for c in color_mask]
|
|
||||||
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# label += 1
|
|
||||||
# im = demo.get_image()
|
|
||||||
|
|
||||||
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
|
|
||||||
for i, ann in enumerate(sorted_anns):
|
|
||||||
mask = ann['segmentation']
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# color_mask = [int(c*255) for c in color_mask]
|
|
||||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# assign the mask to the mask_map
|
|
||||||
mask_map[mask == 1] = label
|
|
||||||
label += 1
|
|
||||||
im = demo.get_image()
|
|
||||||
# fig=plt.figure(figsize=(10, 10))
|
|
||||||
# plt.imshow(image_ori)
|
|
||||||
# show_anns(outputs)
|
|
||||||
# fig.canvas.draw()
|
|
||||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
|
||||||
return im, sorted_anns
|
|
||||||
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
|
|
||||||
def show_anns(anns):
|
|
||||||
if len(anns) == 0:
|
|
||||||
return
|
|
||||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
|
||||||
ax = plt.gca()
|
|
||||||
ax.set_autoscale_on(False)
|
|
||||||
polygons = []
|
|
||||||
color = []
|
|
||||||
for ann in sorted_anns:
|
|
||||||
m = ann['segmentation']
|
|
||||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
for i in range(3):
|
|
||||||
img[:,:,i] = color_mask[i]
|
|
||||||
ax.imshow(np.dstack((img, m*0.35)))
|
|
||||||
@@ -1,164 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import cv2
|
|
||||||
import io
|
|
||||||
from .automatic_mask_generator import SeemAutomaticMaskGenerator
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
from segment_anything.utils.amg import (
|
|
||||||
MaskData,
|
|
||||||
area_from_rle,
|
|
||||||
batch_iterator,
|
|
||||||
batched_mask_to_box,
|
|
||||||
box_xyxy_to_xywh,
|
|
||||||
build_all_layer_point_grids,
|
|
||||||
calculate_stability_score,
|
|
||||||
coco_encode_rle,
|
|
||||||
generate_crop_boxes,
|
|
||||||
is_box_near_crop_edge,
|
|
||||||
mask_to_rle_pytorch,
|
|
||||||
remove_small_regions,
|
|
||||||
rle_to_mask,
|
|
||||||
uncrop_boxes_xyxy,
|
|
||||||
uncrop_masks,
|
|
||||||
uncrop_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def inference_seem_pano(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image)
|
|
||||||
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
|
|
||||||
orig_size = images.shape[-2:]
|
|
||||||
orig_h, orig_w = orig_size
|
|
||||||
crop_box = [0,0,orig_w,orig_h]
|
|
||||||
|
|
||||||
data = {"image": images, "height": orig_h, "width": orig_w}
|
|
||||||
batch_inputs = [data]
|
|
||||||
|
|
||||||
model.model.metadata = metadata
|
|
||||||
outputs = model.model.evaluate(batch_inputs)
|
|
||||||
|
|
||||||
pano_mask = outputs[0]['panoptic_seg'][0]
|
|
||||||
pano_info = outputs[0]['panoptic_seg'][1]
|
|
||||||
|
|
||||||
masks = []
|
|
||||||
for seg_info in pano_info:
|
|
||||||
masks += [pano_mask == seg_info['id']]
|
|
||||||
masks = torch.stack(masks, dim=0)
|
|
||||||
iou_preds = torch.ones(masks.shape[0], dtype=torch.float32)
|
|
||||||
points = torch.zeros((masks.shape[0], 2), dtype=torch.float32)
|
|
||||||
|
|
||||||
mask_data = MaskData(
|
|
||||||
masks=masks,
|
|
||||||
iou_preds=iou_preds,
|
|
||||||
points=points,
|
|
||||||
)
|
|
||||||
mask_data["stability_score"] = torch.ones(masks.shape[0], dtype=torch.float32)
|
|
||||||
del masks
|
|
||||||
|
|
||||||
mask_data["boxes"] = batched_mask_to_box(mask_data["masks"])
|
|
||||||
mask_data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(mask_data["boxes"]))])
|
|
||||||
|
|
||||||
# Compress to RLE
|
|
||||||
mask_data["masks"] = uncrop_masks(mask_data["masks"], crop_box, orig_h, orig_w)
|
|
||||||
mask_data["rles"] = mask_to_rle_pytorch(mask_data["masks"])
|
|
||||||
del mask_data["masks"]
|
|
||||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
|
||||||
|
|
||||||
# Write mask records
|
|
||||||
outputs = []
|
|
||||||
for idx in range(len(mask_data["segmentations"])):
|
|
||||||
ann = {
|
|
||||||
"segmentation": mask_data["segmentations"][idx],
|
|
||||||
"area": area_from_rle(mask_data["rles"][idx]),
|
|
||||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
|
||||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
|
||||||
"point_coords": [mask_data["points"][idx].tolist()],
|
|
||||||
"stability_score": mask_data["stability_score"][idx].item(),
|
|
||||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
|
||||||
}
|
|
||||||
outputs.append(ann)
|
|
||||||
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
# create a full zero image as the image_orig
|
|
||||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
|
||||||
label = 1
|
|
||||||
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
|
|
||||||
for i, ann in enumerate(sorted_anns):
|
|
||||||
mask = ann['segmentation']
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# color_mask = [int(c*255) for c in color_mask]
|
|
||||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# assign the mask to the mask_map
|
|
||||||
mask_map[mask == 1] = label
|
|
||||||
label += 1
|
|
||||||
im = demo.get_image()
|
|
||||||
# fig=plt.figure(figsize=(10, 10))
|
|
||||||
# plt.imshow(image_ori)
|
|
||||||
# show_anns(outputs)
|
|
||||||
# fig.canvas.draw()
|
|
||||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
|
||||||
return im, sorted_anns
|
|
||||||
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
|
|
||||||
def show_anns(anns):
|
|
||||||
if len(anns) == 0:
|
|
||||||
return
|
|
||||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
|
||||||
ax = plt.gca()
|
|
||||||
ax.set_autoscale_on(False)
|
|
||||||
polygons = []
|
|
||||||
color = []
|
|
||||||
for ann in sorted_anns:
|
|
||||||
m = ann['segmentation']
|
|
||||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
for i in range(3):
|
|
||||||
img[:,:,i] = color_mask[i]
|
|
||||||
ax.imshow(np.dstack((img, m*0.35)))
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import cv2
|
|
||||||
import io
|
|
||||||
from .automatic_mask_generator import SeemAutomaticMaskGenerator
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
def interactive_seem_m2m_auto(model, image, text_size, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image)
|
|
||||||
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
|
|
||||||
mask_generator = SeemAutomaticMaskGenerator(model)
|
|
||||||
outputs = mask_generator.generate(images)
|
|
||||||
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
|
||||||
label = 1
|
|
||||||
for ann in sorted_anns:
|
|
||||||
mask = ann['segmentation']
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# color_mask = [int(c*255) for c in color_mask]
|
|
||||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
label += 1
|
|
||||||
im = demo.get_image()
|
|
||||||
|
|
||||||
# fig=plt.figure(figsize=(10, 10))
|
|
||||||
# plt.imshow(image_ori)
|
|
||||||
# show_anns(outputs)
|
|
||||||
# fig.canvas.draw()
|
|
||||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
|
||||||
return im
|
|
||||||
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
|
|
||||||
def show_anns(anns):
|
|
||||||
if len(anns) == 0:
|
|
||||||
return
|
|
||||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
|
||||||
ax = plt.gca()
|
|
||||||
ax.set_autoscale_on(False)
|
|
||||||
polygons = []
|
|
||||||
color = []
|
|
||||||
for ann in sorted_anns:
|
|
||||||
m = ann['segmentation']
|
|
||||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
for i in range(3):
|
|
||||||
img[:,:,i] = color_mask[i]
|
|
||||||
ax.imshow(np.dstack((img, m*0.35)))
|
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
from .interactive_idino_m2m import interactive_infer_image as interactive_infer_image_idino_m2m
|
|
||||||
from .interactive_idino_m2m import interactive_infer_image_semantic, interactive_infer_image_3l
|
|
||||||
from .inference_semsam_m2m_auto import inference_semsam_m2m_auto
|
|
||||||
from .interactive_idino_1o1_box import interactive_infer_image_box as interactive_infer_image_idino_m2m_box
|
|
||||||
from .automatic_mask_generator import prompt_switch
|
|
||||||
from .interactive_predictor import SemanticSAMPredictor
|
|
||||||
@@ -1,393 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
|
|
||||||
# This source code is licensed under the license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
|
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
# from
|
|
||||||
# from .modeling import Sam
|
|
||||||
# from .predictor import SamPredictor
|
|
||||||
from semantic_sam.utils.sam_utils.amg import (
|
|
||||||
MaskData,
|
|
||||||
area_from_rle,
|
|
||||||
batch_iterator,
|
|
||||||
batched_mask_to_box,
|
|
||||||
box_xyxy_to_xywh,
|
|
||||||
build_all_layer_point_grids,
|
|
||||||
calculate_stability_score,
|
|
||||||
coco_encode_rle,
|
|
||||||
generate_crop_boxes,
|
|
||||||
is_box_near_crop_edge,
|
|
||||||
mask_to_rle_pytorch,
|
|
||||||
remove_small_regions,
|
|
||||||
rle_to_mask,
|
|
||||||
uncrop_boxes_xyxy,
|
|
||||||
uncrop_masks,
|
|
||||||
uncrop_points,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def prompt_switch(p):
|
|
||||||
p = int(p)
|
|
||||||
if p == 1:
|
|
||||||
return 3
|
|
||||||
if p == 2:
|
|
||||||
return 2
|
|
||||||
if p == 3:
|
|
||||||
return 0
|
|
||||||
if p == 4:
|
|
||||||
return 4
|
|
||||||
if p == 5:
|
|
||||||
return 1
|
|
||||||
if p == 6:
|
|
||||||
return 5
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class SemanticSamAutomaticMaskGenerator:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model,
|
|
||||||
points_per_side: Optional[int] = 32,
|
|
||||||
points_per_batch: int = 200,
|
|
||||||
pred_iou_thresh: float = 0.88,
|
|
||||||
stability_score_thresh: float = 0.92,
|
|
||||||
stability_score_offset: float = 1.0,
|
|
||||||
box_nms_thresh: float = 0.7,
|
|
||||||
crop_n_layers: int = 0,
|
|
||||||
crop_nms_thresh: float = 0.7,
|
|
||||||
crop_overlap_ratio: float = 512 / 1500,
|
|
||||||
crop_n_points_downscale_factor: int = 1,
|
|
||||||
point_grids: Optional[List[np.ndarray]] = None,
|
|
||||||
min_mask_region_area: int = 10,
|
|
||||||
output_mode: str = "binary_mask",
|
|
||||||
level: list = [1, 2, 3, 4, 5, 6],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Using a SAM model, generates masks for the entire image.
|
|
||||||
Generates a grid of point prompts over the image, then filters
|
|
||||||
low quality and duplicate masks. The default settings are chosen
|
|
||||||
for SAM with a ViT-H backbone.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
model (Sam): The SAM model to use for mask prediction.
|
|
||||||
points_per_side (int or None): The number of points to be sampled
|
|
||||||
along one side of the image. The total number of points is
|
|
||||||
points_per_side**2. If None, 'point_grids' must provide explicit
|
|
||||||
point sampling.
|
|
||||||
points_per_batch (int): Sets the number of points run simultaneously
|
|
||||||
by the model. Higher numbers may be faster but use more GPU memory.
|
|
||||||
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
|
||||||
model's predicted mask quality.
|
|
||||||
stability_score_thresh (float): A filtering threshold in [0,1], using
|
|
||||||
the stability of the mask under changes to the cutoff used to binarize
|
|
||||||
the model's mask predictions.
|
|
||||||
stability_score_offset (float): The amount to shift the cutoff when
|
|
||||||
calculated the stability score.
|
|
||||||
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
|
||||||
suppression to filter duplicate masks.
|
|
||||||
crops_n_layers (int): If >0, mask prediction will be run again on
|
|
||||||
crops of the image. Sets the number of layers to run, where each
|
|
||||||
layer has 2**i_layer number of image crops.
|
|
||||||
crops_nms_thresh (float): The box IoU cutoff used by non-maximal
|
|
||||||
suppression to filter duplicate masks between different crops.
|
|
||||||
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
|
||||||
In the first crop layer, crops will overlap by this fraction of
|
|
||||||
the image length. Later layers with more crops scale down this overlap.
|
|
||||||
crop_n_points_downscale_factor (int): The number of points-per-side
|
|
||||||
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
|
||||||
point_grids (list(np.ndarray) or None): A list over explicit grids
|
|
||||||
of points used for sampling, normalized to [0,1]. The nth grid in the
|
|
||||||
list is used in the nth crop layer. Exclusive with points_per_side.
|
|
||||||
min_mask_region_area (int): If >0, postprocessing will be applied
|
|
||||||
to remove disconnected regions and holes in masks with area smaller
|
|
||||||
than min_mask_region_area. Requires opencv.
|
|
||||||
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
|
||||||
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
|
||||||
For large resolutions, 'binary_mask' may consume large amounts of
|
|
||||||
memory.
|
|
||||||
"""
|
|
||||||
self.level = [prompt_switch(l) for l in level]
|
|
||||||
assert (points_per_side is None) != (
|
|
||||||
point_grids is None
|
|
||||||
), "Exactly one of points_per_side or point_grid must be provided."
|
|
||||||
if points_per_side is not None:
|
|
||||||
self.point_grids = build_all_layer_point_grids(
|
|
||||||
points_per_side,
|
|
||||||
crop_n_layers,
|
|
||||||
crop_n_points_downscale_factor,
|
|
||||||
)
|
|
||||||
elif point_grids is not None:
|
|
||||||
self.point_grids = point_grids
|
|
||||||
else:
|
|
||||||
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
|
||||||
|
|
||||||
assert output_mode in [
|
|
||||||
"binary_mask",
|
|
||||||
"uncompressed_rle",
|
|
||||||
"coco_rle",
|
|
||||||
], f"Unknown output_mode {output_mode}."
|
|
||||||
if output_mode == "coco_rle":
|
|
||||||
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
|
||||||
|
|
||||||
if min_mask_region_area > 0:
|
|
||||||
import cv2 # type: ignore # noqa: F401
|
|
||||||
|
|
||||||
self.predictor = model
|
|
||||||
self.points_per_batch = points_per_batch
|
|
||||||
self.pred_iou_thresh = pred_iou_thresh
|
|
||||||
self.stability_score_thresh = stability_score_thresh
|
|
||||||
self.stability_score_offset = stability_score_offset
|
|
||||||
self.box_nms_thresh = box_nms_thresh
|
|
||||||
self.crop_n_layers = crop_n_layers
|
|
||||||
self.crop_nms_thresh = crop_nms_thresh
|
|
||||||
self.crop_overlap_ratio = crop_overlap_ratio
|
|
||||||
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
|
||||||
self.min_mask_region_area = min_mask_region_area
|
|
||||||
self.output_mode = output_mode
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Generates masks for the given image.
|
|
||||||
|
|
||||||
Arguments:
|
|
||||||
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list(dict(str, any)): A list over records for masks. Each record is
|
|
||||||
a dict containing the following keys:
|
|
||||||
segmentation (dict(str, any) or np.ndarray): The mask. If
|
|
||||||
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
|
||||||
is a dictionary containing the RLE.
|
|
||||||
bbox (list(float)): The box around the mask, in XYWH format.
|
|
||||||
area (int): The area in pixels of the mask.
|
|
||||||
predicted_iou (float): The model's own prediction of the mask's
|
|
||||||
quality. This is filtered by the pred_iou_thresh parameter.
|
|
||||||
point_coords (list(list(float))): The point coordinates input
|
|
||||||
to the model to generate this mask.
|
|
||||||
stability_score (float): A measure of the mask's quality. This
|
|
||||||
is filtered on using the stability_score_thresh parameter.
|
|
||||||
crop_box (list(float)): The crop of the image used to generate
|
|
||||||
the mask, given in XYWH format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Generate masks
|
|
||||||
mask_data = self._generate_masks(image)
|
|
||||||
|
|
||||||
# Filter small disconnected regions and holes in masks
|
|
||||||
if self.min_mask_region_area > 0:
|
|
||||||
mask_data = self.postprocess_small_regions(
|
|
||||||
mask_data,
|
|
||||||
self.min_mask_region_area,
|
|
||||||
max(self.box_nms_thresh, self.crop_nms_thresh),
|
|
||||||
)
|
|
||||||
# Encode masks
|
|
||||||
if self.output_mode == "coco_rle":
|
|
||||||
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
|
||||||
elif self.output_mode == "binary_mask":
|
|
||||||
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
|
||||||
else:
|
|
||||||
mask_data["segmentations"] = mask_data["rles"]
|
|
||||||
|
|
||||||
# Write mask records
|
|
||||||
curr_anns = []
|
|
||||||
for idx in range(len(mask_data["segmentations"])):
|
|
||||||
ann = {
|
|
||||||
"segmentation": mask_data["segmentations"][idx],
|
|
||||||
"area": area_from_rle(mask_data["rles"][idx]),
|
|
||||||
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
|
||||||
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
|
||||||
"point_coords": [mask_data["points"][idx].tolist()],
|
|
||||||
"stability_score": mask_data["stability_score"][idx].item(),
|
|
||||||
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
|
||||||
}
|
|
||||||
curr_anns.append(ann)
|
|
||||||
|
|
||||||
return curr_anns
|
|
||||||
|
|
||||||
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
|
||||||
orig_size = image.shape[-2:]
|
|
||||||
crop_boxes, layer_idxs = generate_crop_boxes(
|
|
||||||
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
|
||||||
)
|
|
||||||
|
|
||||||
# Iterate over image crops
|
|
||||||
assert len(crop_boxes)==1
|
|
||||||
data = MaskData()
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
|
||||||
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
|
||||||
|
|
||||||
data.cat(crop_data)
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
# Remove duplicate masks between crops
|
|
||||||
if len(crop_boxes) > 1:
|
|
||||||
# Prefer masks from smaller crops
|
|
||||||
scores = 1 / box_area(data["crop_boxes"])
|
|
||||||
scores = scores.to(data["boxes"].device)
|
|
||||||
keep_by_nms = batched_nms(
|
|
||||||
data["boxes"].float(),
|
|
||||||
scores,
|
|
||||||
torch.zeros(len(data["boxes"])), # categories
|
|
||||||
iou_threshold=self.crop_nms_thresh,
|
|
||||||
)
|
|
||||||
data.filter(keep_by_nms)
|
|
||||||
|
|
||||||
data.to_numpy()
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _process_crop(
|
|
||||||
self,
|
|
||||||
image: np.ndarray,
|
|
||||||
crop_box: List[int],
|
|
||||||
crop_layer_idx: int,
|
|
||||||
orig_size: Tuple[int, ...],
|
|
||||||
) -> MaskData:
|
|
||||||
# Crop the image and calculate embeddings
|
|
||||||
x0, y0, x1, y1 = crop_box
|
|
||||||
cropped_im = image#[y0:y1, x0:x1, :]
|
|
||||||
cropped_im_size = cropped_im.shape[-2:]
|
|
||||||
# self.predictor.set_image(cropped_im)
|
|
||||||
|
|
||||||
# Get points for this crop
|
|
||||||
points_scale = np.array(cropped_im_size)[None, ::-1]
|
|
||||||
points_for_image = self.point_grids[crop_layer_idx] #* points_scale
|
|
||||||
|
|
||||||
# Generate masks for this crop in batches
|
|
||||||
data = MaskData()
|
|
||||||
self.enc_features=None
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
|
||||||
batch_data = self._process_batch(cropped_im,points, cropped_im_size, crop_box, orig_size)
|
|
||||||
data.cat(batch_data)
|
|
||||||
del batch_data
|
|
||||||
|
|
||||||
keep_by_nms = batched_nms(
|
|
||||||
data["boxes"].float(),
|
|
||||||
data["iou_preds"],
|
|
||||||
torch.zeros(len(data["boxes"])), # categories
|
|
||||||
iou_threshold=self.box_nms_thresh,
|
|
||||||
)
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
data.filter(keep_by_nms)
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
# Return to the original image frame
|
|
||||||
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
|
||||||
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _process_batch(
|
|
||||||
self,
|
|
||||||
images,
|
|
||||||
points: np.ndarray,
|
|
||||||
im_size: Tuple[int, ...],
|
|
||||||
crop_box: List[int],
|
|
||||||
orig_size: Tuple[int, ...],
|
|
||||||
) -> MaskData:
|
|
||||||
orig_h, orig_w = orig_size
|
|
||||||
|
|
||||||
data = {"image": images, "height": orig_h, "width": orig_w}
|
|
||||||
points=torch.tensor(points,dtype=torch.float).to(images.device)
|
|
||||||
points = torch.cat([points, points.new_tensor([[0.005, 0.005]]).repeat(len(points), 1)], dim=-1)
|
|
||||||
data['targets'] = [dict()]
|
|
||||||
data['targets'][0]['points']=points
|
|
||||||
data['targets'][0]['pb']=points.new_tensor([0.]*len(points))
|
|
||||||
batch_inputs = [data]
|
|
||||||
if self.enc_features is None:
|
|
||||||
masks, iou_preds,mask_features,multi_scale_features= self.predictor.model.evaluate_demo(batch_inputs,None,None,return_features=True, level=self.level)
|
|
||||||
self.enc_features=(mask_features,multi_scale_features)
|
|
||||||
else:
|
|
||||||
masks, iou_preds= self.predictor.model.evaluate_demo(batch_inputs,None,None,self.enc_features[0],self.enc_features[1], level=self.level)
|
|
||||||
|
|
||||||
data = MaskData(
|
|
||||||
masks=masks,
|
|
||||||
iou_preds=iou_preds.flatten(),
|
|
||||||
points=torch.as_tensor(points[:,None].repeat(1,len(self.level), 1).view(-1,4)),
|
|
||||||
)
|
|
||||||
del masks
|
|
||||||
# Filter by predicted IoU
|
|
||||||
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
|
||||||
data.filter(keep_mask)
|
|
||||||
|
|
||||||
# Calculate stability score
|
|
||||||
data["stability_score"] = calculate_stability_score(
|
|
||||||
data["masks"], 0.0, self.stability_score_offset
|
|
||||||
)
|
|
||||||
# if self.stability_score_thresh > 0.0:
|
|
||||||
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
|
||||||
data.filter(keep_mask)
|
|
||||||
|
|
||||||
# Threshold masks and calculate boxes
|
|
||||||
data["masks"] = data["masks"] > 0.0
|
|
||||||
data["boxes"] = batched_mask_to_box(data["masks"])
|
|
||||||
|
|
||||||
# Filter boxes that touch crop boundaries
|
|
||||||
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
|
||||||
if not torch.all(keep_mask):
|
|
||||||
data.filter(keep_mask)
|
|
||||||
|
|
||||||
# Compress to RLE
|
|
||||||
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
|
||||||
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
|
||||||
del data["masks"]
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def postprocess_small_regions(
|
|
||||||
mask_data: MaskData, min_area: int, nms_thresh: float
|
|
||||||
) -> MaskData:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in masks, then reruns
|
|
||||||
box NMS to remove any new duplicates.
|
|
||||||
|
|
||||||
Edits mask_data in place.
|
|
||||||
|
|
||||||
Requires open-cv as a dependency.
|
|
||||||
"""
|
|
||||||
if len(mask_data["rles"]) == 0:
|
|
||||||
return mask_data
|
|
||||||
|
|
||||||
# Filter small disconnected regions and holes
|
|
||||||
new_masks = []
|
|
||||||
scores = []
|
|
||||||
for rle in mask_data["rles"]:
|
|
||||||
mask = rle_to_mask(rle)
|
|
||||||
|
|
||||||
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
|
||||||
unchanged = not changed
|
|
||||||
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
|
||||||
unchanged = unchanged and not changed
|
|
||||||
|
|
||||||
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
|
||||||
# Give score=0 to changed masks and score=1 to unchanged masks
|
|
||||||
# so NMS will prefer ones that didn't need postprocessing
|
|
||||||
scores.append(float(unchanged))
|
|
||||||
|
|
||||||
# Recalculate boxes and remove any new duplicates
|
|
||||||
masks = torch.cat(new_masks, dim=0)
|
|
||||||
boxes = batched_mask_to_box(masks)
|
|
||||||
keep_by_nms = batched_nms(
|
|
||||||
boxes.float(),
|
|
||||||
torch.as_tensor(scores),
|
|
||||||
torch.zeros(len(boxes)), # categories
|
|
||||||
iou_threshold=nms_thresh,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only recalculate RLEs for masks that have changed
|
|
||||||
for i_mask in keep_by_nms:
|
|
||||||
if scores[i_mask] == 0.0:
|
|
||||||
mask_torch = masks[i_mask].unsqueeze(0)
|
|
||||||
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
|
||||||
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
|
||||||
mask_data.filter(keep_by_nms)
|
|
||||||
|
|
||||||
return mask_data
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import cv2
|
|
||||||
import io
|
|
||||||
from .automatic_mask_generator import SemanticSamAutomaticMaskGenerator
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
def inference_semsam_m2m_auto(model, image, level, all_classes, all_parts, thresh, text_size, hole_scale, island_scale, semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image)
|
|
||||||
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
|
|
||||||
mask_generator = SemanticSamAutomaticMaskGenerator(model,points_per_side=32,
|
|
||||||
pred_iou_thresh=0.88,
|
|
||||||
stability_score_thresh=0.92,
|
|
||||||
min_mask_region_area=10,
|
|
||||||
level=level,
|
|
||||||
)
|
|
||||||
outputs = mask_generator.generate(images)
|
|
||||||
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
sorted_anns = sorted(outputs, key=(lambda x: x['area']), reverse=True)
|
|
||||||
label = 1
|
|
||||||
# for ann in sorted_anns:
|
|
||||||
# mask = ann['segmentation']
|
|
||||||
# color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# # color_mask = [int(c*255) for c in color_mask]
|
|
||||||
# demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# label += 1
|
|
||||||
# im = demo.get_image()
|
|
||||||
|
|
||||||
mask_map = np.zeros(image_ori.shape, dtype=np.uint8)
|
|
||||||
for i, ann in enumerate(sorted_anns):
|
|
||||||
mask = ann['segmentation']
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
# color_mask = [int(c*255) for c in color_mask]
|
|
||||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
# assign the mask to the mask_map
|
|
||||||
mask_map[mask == 1] = label
|
|
||||||
label += 1
|
|
||||||
im = demo.get_image()
|
|
||||||
# fig=plt.figure(figsize=(10, 10))
|
|
||||||
# plt.imshow(image_ori)
|
|
||||||
# show_anns(outputs)
|
|
||||||
# fig.canvas.draw()
|
|
||||||
# im=Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
|
|
||||||
return im, sorted_anns
|
|
||||||
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
|
|
||||||
def show_anns(anns):
|
|
||||||
if len(anns) == 0:
|
|
||||||
return
|
|
||||||
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
|
||||||
ax = plt.gca()
|
|
||||||
ax.set_autoscale_on(False)
|
|
||||||
polygons = []
|
|
||||||
color = []
|
|
||||||
for ann in sorted_anns:
|
|
||||||
m = ann['segmentation']
|
|
||||||
img = np.ones((m.shape[0], m.shape[1], 3))
|
|
||||||
color_mask = np.random.random((1, 3)).tolist()[0]
|
|
||||||
for i in range(3):
|
|
||||||
img[:,:,i] = color_mask[i]
|
|
||||||
ax.imshow(np.dstack((img, m*0.35)))
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
from detectron2.structures import BitMasks
|
|
||||||
from semantic_sam.utils import box_ops
|
|
||||||
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
def interactive_infer_image_box(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image['image'])
|
|
||||||
mask_ori = transform1(image['mask'])
|
|
||||||
width = image_ori.size[0]
|
|
||||||
height = image_ori.size[1]
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
|
||||||
|
|
||||||
|
|
||||||
data = {"image": images, "height": height, "width": width}
|
|
||||||
|
|
||||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
|
||||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
|
||||||
flaten_mask = mask_ori.unsqueeze(0)
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
points=mask_ori.nonzero().float().to(images.device)
|
|
||||||
if len(points)==0:
|
|
||||||
point_=point=points.new_tensor([[0.5,0.5,0.5,0.5]])
|
|
||||||
else:
|
|
||||||
mean_point=points.mean(0)[None]
|
|
||||||
box_xyxy = BitMasks(flaten_mask > 0).get_bounding_boxes().tensor
|
|
||||||
h = mask_ori.shape[0]
|
|
||||||
w = mask_ori.shape[1]
|
|
||||||
box_xywh = (box_ops.box_xyxy_to_cxcywh(box_xyxy) / torch.as_tensor([w, h, w, h])).cuda()
|
|
||||||
|
|
||||||
# point_=points.mean(0)[None]
|
|
||||||
# point=point_.clone()
|
|
||||||
# point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
|
||||||
# point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
|
||||||
# point = point[:, [1, 0]]
|
|
||||||
point=box_xywh
|
|
||||||
data['targets'] = [dict()]
|
|
||||||
data['targets'][0]['points']=point
|
|
||||||
data['targets'][0]['pb']=point.new_tensor([1.])
|
|
||||||
|
|
||||||
|
|
||||||
batch_inputs = [data]
|
|
||||||
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts, task='demo_box')
|
|
||||||
|
|
||||||
pred_masks_poses = masks
|
|
||||||
reses=[]
|
|
||||||
ious=ious[0,0]
|
|
||||||
ids=torch.argsort(ious,descending=True)
|
|
||||||
|
|
||||||
text_res=''
|
|
||||||
try:
|
|
||||||
thresh=float(thresh)
|
|
||||||
except Exception:
|
|
||||||
thresh=0.0
|
|
||||||
mask_ls=[]
|
|
||||||
ious_res=[]
|
|
||||||
areas=[]
|
|
||||||
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
|
|
||||||
iou=round(float(iou),2)
|
|
||||||
texts=f'{iou}'
|
|
||||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
|
||||||
area=mask.sum()
|
|
||||||
conti=False
|
|
||||||
if iou<thresh:
|
|
||||||
conti=True
|
|
||||||
for m in mask_ls:
|
|
||||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
|
||||||
conti=True
|
|
||||||
break
|
|
||||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
|
||||||
conti=False
|
|
||||||
if conti:
|
|
||||||
continue
|
|
||||||
ious_res.append(iou)
|
|
||||||
mask_ls.append(mask)
|
|
||||||
areas.append(area)
|
|
||||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
|
||||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
|
||||||
mask=(mask).astype(np.float)
|
|
||||||
out_txt = texts
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
color=[0.,0.,1.0]
|
|
||||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
|
||||||
demo = visual.draw_box(box_xyxy[0])
|
|
||||||
res = demo.get_image()
|
|
||||||
# point_x0=max(0,int(point_[0, 1])-3)
|
|
||||||
# point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
|
||||||
# point_y0 = max(0, int(point_[0, 0]) - 3)
|
|
||||||
# point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
|
||||||
# res[point_y0:point_y1,point_x0:point_x1,0]=255
|
|
||||||
# res[point_y0:point_y1,point_x0:point_x1,1]=0
|
|
||||||
# res[point_y0:point_y1,point_x0:point_x1,2]=0
|
|
||||||
reses.append(Image.fromarray(res))
|
|
||||||
text_res=text_res+';'+out_txt
|
|
||||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
|
||||||
ids = [int(i) for i in ids]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return reses,[reses[i] for i in ids]
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
@@ -1,322 +0,0 @@
|
|||||||
# --------------------------------------------------------
|
|
||||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
|
||||||
# Copyright (c) 2023 Microsoft
|
|
||||||
# Licensed under The MIT License [see LICENSE for details]
|
|
||||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
|
||||||
# --------------------------------------------------------
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
def interactive_infer_image(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None, label_mode='1', alpha=0.1, anno_mode=['Mask']):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image['image'])
|
|
||||||
mask_ori = transform1(image['mask'])
|
|
||||||
width = image_ori.size[0]
|
|
||||||
height = image_ori.size[1]
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
|
||||||
|
|
||||||
|
|
||||||
data = {"image": images, "height": height, "width": width}
|
|
||||||
|
|
||||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
|
||||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
|
||||||
points=mask_ori.nonzero().float().to(images.device)
|
|
||||||
if len(points)==0:
|
|
||||||
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
|
|
||||||
else:
|
|
||||||
point_=points.mean(0)[None]
|
|
||||||
point=point_.clone()
|
|
||||||
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
|
||||||
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
|
||||||
point = point[:, [1, 0]]
|
|
||||||
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
|
|
||||||
data['targets'] = [dict()]
|
|
||||||
data['targets'][0]['points']=point
|
|
||||||
data['targets'][0]['pb']=point.new_tensor([0.])
|
|
||||||
|
|
||||||
|
|
||||||
batch_inputs = [data]
|
|
||||||
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts)
|
|
||||||
|
|
||||||
pred_masks_poses = masks
|
|
||||||
reses=[]
|
|
||||||
ious=ious[0,0]
|
|
||||||
ids=torch.argsort(ious,descending=True)
|
|
||||||
|
|
||||||
text_res=''
|
|
||||||
try:
|
|
||||||
thresh=float(thresh)
|
|
||||||
except Exception:
|
|
||||||
thresh=0.0
|
|
||||||
mask_ls=[]
|
|
||||||
ious_res=[]
|
|
||||||
areas=[]
|
|
||||||
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
|
|
||||||
iou=round(float(iou),2)
|
|
||||||
texts=f'{iou}'
|
|
||||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
|
||||||
area=mask.sum()
|
|
||||||
conti=False
|
|
||||||
if iou<thresh:
|
|
||||||
conti=True
|
|
||||||
for m in mask_ls:
|
|
||||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
|
||||||
conti=True
|
|
||||||
break
|
|
||||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
|
||||||
conti=False
|
|
||||||
if conti:
|
|
||||||
continue
|
|
||||||
ious_res.append(iou)
|
|
||||||
mask_ls.append(mask)
|
|
||||||
areas.append(area)
|
|
||||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
|
||||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
|
||||||
mask=(mask).astype(np.float)
|
|
||||||
out_txt = texts
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
color=[0.,0.,1.0]
|
|
||||||
# demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
|
||||||
demo = visual.draw_binary_mask_with_number(mask, text=str(label), label_mode=label_mode, alpha=alpha, anno_mode=anno_mode)
|
|
||||||
res = demo.get_image()
|
|
||||||
point_x0=max(0,int(point_[0, 1])-3)
|
|
||||||
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
|
||||||
point_y0 = max(0, int(point_[0, 0]) - 3)
|
|
||||||
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
|
||||||
# res[point_y0:point_y1,point_x0:point_x1,0]=255
|
|
||||||
# res[point_y0:point_y1,point_x0:point_x1,1]=0
|
|
||||||
# res[point_y0:point_y1,point_x0:point_x1,2]=0
|
|
||||||
reses.append(Image.fromarray(res))
|
|
||||||
text_res=text_res+';'+out_txt
|
|
||||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
|
||||||
ids = [int(i) for i in ids]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return reses,[reses[i] for i in ids]
|
|
||||||
|
|
||||||
def interactive_infer_image_3l(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image['image'])
|
|
||||||
mask_ori = transform1(image['mask'])
|
|
||||||
width = image_ori.size[0]
|
|
||||||
height = image_ori.size[1]
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
|
||||||
|
|
||||||
|
|
||||||
data = {"image": images, "height": height, "width": width}
|
|
||||||
|
|
||||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
|
||||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
|
||||||
points=mask_ori.nonzero().float().to(images.device)
|
|
||||||
if len(points)==0:
|
|
||||||
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
|
|
||||||
else:
|
|
||||||
point_=points.mean(0)[None]
|
|
||||||
point=point_.clone()
|
|
||||||
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
|
||||||
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
|
||||||
point = point[:, [1, 0]]
|
|
||||||
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
|
|
||||||
data['targets'] = [dict()]
|
|
||||||
data['targets'][0]['points']=point
|
|
||||||
data['targets'][0]['pb']=point.new_tensor([0.])
|
|
||||||
|
|
||||||
|
|
||||||
batch_inputs = [data]
|
|
||||||
masks, ious, pred_class, pred_class_score = model.model.evaluate_demo(batch_inputs,all_classes,all_parts, level=[0,1,2])
|
|
||||||
|
|
||||||
pred_masks_poses = masks
|
|
||||||
reses=[]
|
|
||||||
ious=ious[0,0]
|
|
||||||
ids=torch.argsort(ious,descending=True)
|
|
||||||
|
|
||||||
text_res=''
|
|
||||||
try:
|
|
||||||
thresh=float(thresh)
|
|
||||||
except Exception:
|
|
||||||
thresh=0.0
|
|
||||||
mask_ls=[]
|
|
||||||
ious_res=[]
|
|
||||||
areas=[]
|
|
||||||
new_pred_class = []
|
|
||||||
new_pred_class_score = []
|
|
||||||
for i in ids:
|
|
||||||
new_pred_class_score.append(pred_class_score[i])
|
|
||||||
new_pred_class.append(pred_class[i])
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
for i,(pred_masks_pos,iou, cls_name, cls_score) in enumerate(zip(pred_masks_poses[ids],ious[ids], new_pred_class, new_pred_class_score)):
|
|
||||||
iou=round(float(iou),2)
|
|
||||||
texts=f'{iou}_{cls_name}_{cls_score}'
|
|
||||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
|
||||||
area=mask.sum()
|
|
||||||
conti=False
|
|
||||||
if iou<thresh:
|
|
||||||
conti=True
|
|
||||||
for m in mask_ls:
|
|
||||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
|
||||||
conti=True
|
|
||||||
break
|
|
||||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
|
||||||
conti=False
|
|
||||||
if conti:
|
|
||||||
continue
|
|
||||||
ious_res.append(iou)
|
|
||||||
mask_ls.append(mask)
|
|
||||||
areas.append(area)
|
|
||||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
|
||||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
|
||||||
mask=(mask).astype(np.float)
|
|
||||||
out_txt = texts
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
color=[0.,0.,1.0]
|
|
||||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
|
||||||
res = demo.get_image()
|
|
||||||
point_x0=max(0,int(point_[0, 1])-3)
|
|
||||||
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
|
||||||
point_y0 = max(0, int(point_[0, 0]) - 3)
|
|
||||||
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
|
||||||
res[point_y0:point_y1,point_x0:point_x1,0]=255
|
|
||||||
res[point_y0:point_y1,point_x0:point_x1,1]=0
|
|
||||||
res[point_y0:point_y1,point_x0:point_x1,2]=0
|
|
||||||
reses.append(Image.fromarray(res))
|
|
||||||
text_res=text_res+';'+out_txt
|
|
||||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
|
||||||
ids = [int(i) for i in ids]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return reses,[reses[i] for i in ids]
|
|
||||||
|
|
||||||
def interactive_infer_image_semantic(model, image,all_classes,all_parts, thresh,text_size,hole_scale,island_scale,semantic, refimg=None, reftxt=None, audio_pth=None, video_pth=None):
|
|
||||||
t = []
|
|
||||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
|
||||||
transform1 = transforms.Compose(t)
|
|
||||||
image_ori = transform1(image['image'])
|
|
||||||
mask_ori = transform1(image['mask'])
|
|
||||||
width = image_ori.size[0]
|
|
||||||
height = image_ori.size[1]
|
|
||||||
image_ori = np.asarray(image_ori)
|
|
||||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
|
||||||
all_classes, all_parts=all_classes.strip().strip("\"[]").split(':'),all_parts.strip().strip("\"[]").split(':')
|
|
||||||
|
|
||||||
|
|
||||||
data = {"image": images, "height": height, "width": width}
|
|
||||||
|
|
||||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
|
||||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)[0]
|
|
||||||
points=mask_ori.nonzero().float().to(images.device)
|
|
||||||
if len(points)==0:
|
|
||||||
point_=point=points.new_tensor([[0.5,0.5,0.006,0.006]])
|
|
||||||
else:
|
|
||||||
point_=points.mean(0)[None]
|
|
||||||
point=point_.clone()
|
|
||||||
point[0, 0] = point_[0, 0] / mask_ori.shape[0]
|
|
||||||
point[0, 1] = point_[0, 1] / mask_ori.shape[1]
|
|
||||||
point = point[:, [1, 0]]
|
|
||||||
point=torch.cat([point,points.new_tensor([[0.005,0.005]])],dim=-1)
|
|
||||||
data['targets'] = [dict()]
|
|
||||||
data['targets'][0]['points']=point
|
|
||||||
data['targets'][0]['pb']=point.new_tensor([0.])
|
|
||||||
data['targets'][0]['pb']=point.new_tensor([1.])
|
|
||||||
|
|
||||||
|
|
||||||
batch_inputs = [data]
|
|
||||||
masks,ious = model.model.evaluate_demo(batch_inputs,all_classes,all_parts)
|
|
||||||
|
|
||||||
pred_masks_poses = masks
|
|
||||||
reses=[]
|
|
||||||
ious=ious[0,0]
|
|
||||||
ids=torch.argsort(ious,descending=True)
|
|
||||||
|
|
||||||
text_res=''
|
|
||||||
try:
|
|
||||||
thresh=float(thresh)
|
|
||||||
except Exception:
|
|
||||||
thresh=0.0
|
|
||||||
mask_ls=[]
|
|
||||||
ious_res=[]
|
|
||||||
areas=[]
|
|
||||||
for i,(pred_masks_pos,iou) in enumerate(zip(pred_masks_poses[ids],ious[ids])):
|
|
||||||
iou=round(float(iou),2)
|
|
||||||
texts=f'{iou}'
|
|
||||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
|
||||||
area=mask.sum()
|
|
||||||
conti=False
|
|
||||||
if iou<thresh:
|
|
||||||
conti=True
|
|
||||||
for m in mask_ls:
|
|
||||||
if np.logical_and(mask,m).sum()/np.logical_or(mask,m).sum()>0.95:
|
|
||||||
conti=True
|
|
||||||
break
|
|
||||||
if i == len(pred_masks_poses[ids])-1 and mask_ls==[]:
|
|
||||||
conti=False
|
|
||||||
if conti:
|
|
||||||
continue
|
|
||||||
ious_res.append(iou)
|
|
||||||
mask_ls.append(mask)
|
|
||||||
areas.append(area)
|
|
||||||
mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
|
||||||
mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
|
||||||
mask=(mask).astype(np.float)
|
|
||||||
out_txt = texts
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
color=[0.,0.,1.0]
|
|
||||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
|
||||||
res = demo.get_image()
|
|
||||||
point_x0=max(0,int(point_[0, 1])-3)
|
|
||||||
point_x1=min(mask_ori.shape[1],int(point_[0, 1])+3)
|
|
||||||
point_y0 = max(0, int(point_[0, 0]) - 3)
|
|
||||||
point_y1 = min(mask_ori.shape[0], int(point_[0, 0]) + 3)
|
|
||||||
res[point_y0:point_y1,point_x0:point_x1,0]=255
|
|
||||||
res[point_y0:point_y1,point_x0:point_x1,1]=0
|
|
||||||
res[point_y0:point_y1,point_x0:point_x1,2]=0
|
|
||||||
reses.append(Image.fromarray(res))
|
|
||||||
text_res=text_res+';'+out_txt
|
|
||||||
ids=list(torch.argsort(torch.tensor(areas),descending=False))
|
|
||||||
ids = [int(i) for i in ids]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return reses,[reses[i] for i in ids]
|
|
||||||
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
from torchvision import transforms
|
|
||||||
from task_adapter.utils.visualizer import Visualizer
|
|
||||||
from typing import Tuple
|
|
||||||
from PIL import Image
|
|
||||||
from detectron2.data import MetadataCatalog
|
|
||||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
|
||||||
|
|
||||||
|
|
||||||
class SemanticSAMPredictor:
|
|
||||||
def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100):
|
|
||||||
"""
|
|
||||||
thresh: iou thresh to filter low confidence objects
|
|
||||||
text_size: resize the input image short edge for the model to process
|
|
||||||
hole_scale: fill in small holes as in SAM
|
|
||||||
island_scale: remove small regions as in SAM
|
|
||||||
"""
|
|
||||||
self.model = model
|
|
||||||
self.thresh = thresh
|
|
||||||
self.text_size = hole_scale
|
|
||||||
self.hole_scale = hole_scale
|
|
||||||
self.island_scale = island_scale
|
|
||||||
self.point = None
|
|
||||||
|
|
||||||
def predict(self, image_ori, image, point=None):
|
|
||||||
"""
|
|
||||||
produce up to 6 prediction results for each click
|
|
||||||
"""
|
|
||||||
width = image_ori.shape[0]
|
|
||||||
height = image_ori.shape[1]
|
|
||||||
|
|
||||||
data = {"image": image, "height": height, "width": width}
|
|
||||||
# import ipdb; ipdb.set_trace()
|
|
||||||
if point is None:
|
|
||||||
point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda()
|
|
||||||
else:
|
|
||||||
point = torch.tensor(point).cuda()
|
|
||||||
point_ = point
|
|
||||||
point = point_.clone()
|
|
||||||
point[0, 0] = point_[0, 0]
|
|
||||||
point[0, 1] = point_[0, 1]
|
|
||||||
# point = point[:, [1, 0]]
|
|
||||||
point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)
|
|
||||||
|
|
||||||
self.point = point[:, :2].clone()*(torch.tensor([width, height]).to(point))
|
|
||||||
|
|
||||||
data['targets'] = [dict()]
|
|
||||||
data['targets'][0]['points'] = point
|
|
||||||
data['targets'][0]['pb'] = point.new_tensor([0.])
|
|
||||||
|
|
||||||
batch_inputs = [data]
|
|
||||||
masks, ious = self.model.model.evaluate_demo(batch_inputs)
|
|
||||||
|
|
||||||
return masks, ious
|
|
||||||
|
|
||||||
def process_multi_mask(self, masks, ious, image_ori):
|
|
||||||
pred_masks_poses = masks
|
|
||||||
reses = []
|
|
||||||
ious = ious[0, 0]
|
|
||||||
ids = torch.argsort(ious, descending=True)
|
|
||||||
|
|
||||||
text_res = ''
|
|
||||||
mask_ls = []
|
|
||||||
ious_res = []
|
|
||||||
areas = []
|
|
||||||
for i, (pred_masks_pos, iou) in enumerate(zip(pred_masks_poses[ids], ious[ids])):
|
|
||||||
iou = round(float(iou), 2)
|
|
||||||
texts = f'{iou}'
|
|
||||||
mask = (pred_masks_pos > 0.0).cpu().numpy()
|
|
||||||
area = mask.sum()
|
|
||||||
conti = False
|
|
||||||
if iou < self.thresh:
|
|
||||||
conti = True
|
|
||||||
for m in mask_ls:
|
|
||||||
if np.logical_and(mask, m).sum() / np.logical_or(mask, m).sum() > 0.95:
|
|
||||||
conti = True
|
|
||||||
break
|
|
||||||
if i == len(pred_masks_poses[ids]) - 1 and mask_ls == []:
|
|
||||||
conti = False
|
|
||||||
if conti:
|
|
||||||
continue
|
|
||||||
ious_res.append(iou)
|
|
||||||
mask_ls.append(mask)
|
|
||||||
areas.append(area)
|
|
||||||
mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes")
|
|
||||||
mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands")
|
|
||||||
mask = (mask).astype(np.float)
|
|
||||||
out_txt = texts
|
|
||||||
visual = Visualizer(image_ori, metadata=metadata)
|
|
||||||
color = [0., 0., 1.0]
|
|
||||||
demo = visual.draw_binary_mask(mask, color=color, text=texts)
|
|
||||||
res = demo.get_image()
|
|
||||||
point_x0 = max(0, int(self.point[0, 0]) - 3)
|
|
||||||
point_x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3)
|
|
||||||
point_y0 = max(0, int(self.point[0, 1]) - 3)
|
|
||||||
point_y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3)
|
|
||||||
res[point_y0:point_y1, point_x0:point_x1, 0] = 255
|
|
||||||
res[point_y0:point_y1, point_x0:point_x1, 1] = 0
|
|
||||||
res[point_y0:point_y1, point_x0:point_x1, 2] = 0
|
|
||||||
reses.append(Image.fromarray(res))
|
|
||||||
text_res = text_res + ';' + out_txt
|
|
||||||
ids = list(torch.argsort(torch.tensor(areas), descending=False))
|
|
||||||
ids = [int(i) for i in ids]
|
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
return reses, [reses[i] for i in ids]
|
|
||||||
|
|
||||||
def predict_masks(self, image_ori, image, point=None):
|
|
||||||
masks, ious = self.predict(image_ori, image, point)
|
|
||||||
return self.process_multi_mask(masks, ious, image_ori)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def remove_small_regions(
|
|
||||||
mask: np.ndarray, area_thresh: float, mode: str
|
|
||||||
) -> Tuple[np.ndarray, bool]:
|
|
||||||
"""
|
|
||||||
Removes small disconnected regions and holes in a mask. Returns the
|
|
||||||
mask and an indicator of if the mask has been modified.
|
|
||||||
"""
|
|
||||||
import cv2 # type: ignore
|
|
||||||
|
|
||||||
assert mode in ["holes", "islands"]
|
|
||||||
correct_holes = mode == "holes"
|
|
||||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
|
||||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
|
||||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
|
||||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
|
||||||
if len(small_regions) == 0:
|
|
||||||
return mask, False
|
|
||||||
fill_labels = [0] + small_regions
|
|
||||||
if not correct_holes:
|
|
||||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
|
||||||
# If every region is below threshold, keep largest
|
|
||||||
if len(fill_labels) == 0:
|
|
||||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
|
||||||
mask = np.isin(regions, fill_labels)
|
|
||||||
return mask, True
|
|
||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user