Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9851374master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a | |||
| size 245864 | |||
| @@ -23,6 +23,8 @@ class Models(object): | |||
| panoptic_segmentation = 'swinL-panoptic-segmentation' | |||
| image_reid_person = 'passvitb' | |||
| video_summarization = 'pgl-video-summarization' | |||
| swinL_semantic_segmentation = 'swinL-semantic-segmentation' | |||
| vitadapter_semantic_segmentation = 'vitadapter-semantic-segmentation' | |||
| # nlp models | |||
| bert = 'bert' | |||
| @@ -117,6 +119,7 @@ class Pipelines(object): | |||
| video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' | |||
| image_panoptic_segmentation = 'image-panoptic-segmentation' | |||
| video_summarization = 'googlenet_pgl_video_summarization' | |||
| image_semantic_segmentation = 'image-semantic-segmentation' | |||
| image_reid_person = 'passvitb-image-reid-person' | |||
| # nlp tasks | |||
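These registry keys are what the pipeline factory resolves at runtime: a model whose configuration names Models.swinL_semantic_segmentation or Models.vitadapter_semantic_segmentation under Tasks.image_segmentation is routed to the new 'image-semantic-segmentation' pipeline. A minimal usage sketch; the model id and image path below are placeholders, not part of this change:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # '<model_id>' is a placeholder for a ModelScope model card that uses one of
    # the two semantic-segmentation models registered in this change.
    segmentor = pipeline(Tasks.image_segmentation, model='<model_id>')
    result = segmentor('path/to/image.jpg')   # placeholder image path
    print(result['labels'])                   # per-mask class names
    print(len(result['masks']))               # one binary mask per predicted class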
| @@ -4,8 +4,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, | |||
| face_generation, image_classification, image_color_enhance, | |||
| image_colorization, image_denoise, image_instance_segmentation, | |||
| image_panoptic_segmentation, image_portrait_enhancement, | |||
| image_reid_person, image_to_image_generation, | |||
| image_to_image_translation, object_detection, | |||
| product_retrieval_embedding, salient_detection, | |||
| super_resolution, video_single_object_tracking, | |||
| video_summarization, virual_tryon) | |||
| image_reid_person, image_semantic_segmentation, | |||
| image_to_image_generation, image_to_image_translation, | |||
| object_detection, product_retrieval_embedding, | |||
| salient_detection, super_resolution, | |||
| video_single_object_tracking, video_summarization, virual_tryon) | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .semantic_seg_model import SemanticSegmentation | |||
| else: | |||
| _import_structure = { | |||
| 'semantic_seg_model': ['SemanticSegmentation'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
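The LazyImportModule wrapper keeps `import modelscope.models.cv` cheap: the mmdet/mmcv-heavy semantic_seg_model module is only imported when SemanticSegmentation is first accessed. A minimal sketch of the intended behaviour, using the module path added in this change:

    # importing the package does not yet pull in mmdet/mmcv
    from modelscope.models.cv import image_semantic_segmentation

    # the first attribute access triggers the real import of
    # image_semantic_segmentation.semantic_seg_model
    SemanticSegmentation = image_semantic_segmentation.SemanticSegmentation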
| @@ -0,0 +1 @@ | |||
| from .maskformer_semantic_head import MaskFormerSemanticHead | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| from abc import ABCMeta, abstractmethod | |||
| from mmcv.runner import BaseModule | |||
| from mmdet.models.builder import build_loss | |||
| class BasePanopticFusionHead(BaseModule, metaclass=ABCMeta): | |||
| """Base class for panoptic heads.""" | |||
| def __init__(self, | |||
| num_things_classes=80, | |||
| num_stuff_classes=53, | |||
| test_cfg=None, | |||
| loss_panoptic=None, | |||
| init_cfg=None, | |||
| **kwargs): | |||
| super(BasePanopticFusionHead, self).__init__(init_cfg) | |||
| self.num_things_classes = num_things_classes | |||
| self.num_stuff_classes = num_stuff_classes | |||
| self.num_classes = num_things_classes + num_stuff_classes | |||
| self.test_cfg = test_cfg | |||
| if loss_panoptic: | |||
| self.loss_panoptic = build_loss(loss_panoptic) | |||
| else: | |||
| self.loss_panoptic = None | |||
| @property | |||
| def with_loss(self): | |||
| """bool: whether the panoptic head contains loss function.""" | |||
| return self.loss_panoptic is not None | |||
| @abstractmethod | |||
| def forward_train(self, gt_masks=None, gt_semantic_seg=None, **kwargs): | |||
| """Forward function during training.""" | |||
| @abstractmethod | |||
| def simple_test(self, | |||
| img_metas, | |||
| det_labels, | |||
| mask_preds, | |||
| seg_preds, | |||
| det_bboxes, | |||
| cfg=None, | |||
| **kwargs): | |||
| """Test without augmentation.""" | |||
| @@ -0,0 +1,57 @@ | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from mmdet.models.builder import HEADS | |||
| from .base_panoptic_fusion_head import BasePanopticFusionHead | |||
| @HEADS.register_module() | |||
| class MaskFormerSemanticHead(BasePanopticFusionHead): | |||
| def __init__(self, | |||
| num_things_classes=80, | |||
| num_stuff_classes=53, | |||
| test_cfg=None, | |||
| loss_panoptic=None, | |||
| init_cfg=None, | |||
| **kwargs): | |||
| super().__init__(num_things_classes, num_stuff_classes, test_cfg, | |||
| loss_panoptic, init_cfg, **kwargs) | |||
| def forward_train(self, **kwargs): | |||
| """MaskFormerFusionHead has no training loss.""" | |||
| return dict() | |||
| def simple_test(self, | |||
| mask_cls_results, | |||
| mask_pred_results, | |||
| img_metas, | |||
| rescale=False, | |||
| **kwargs): | |||
| results = [] | |||
| for mask_cls_result, mask_pred_result, meta in zip( | |||
| mask_cls_results, mask_pred_results, img_metas): | |||
| # remove padding | |||
| img_height, img_width = meta['img_shape'][:2] | |||
| mask_pred_result = mask_pred_result[:, :img_height, :img_width] | |||
| if rescale: | |||
| # return result in original resolution | |||
| ori_height, ori_width = meta['ori_shape'][:2] | |||
| mask_pred_result = F.interpolate( | |||
| mask_pred_result[:, None], | |||
| size=(ori_height, ori_width), | |||
| mode='bilinear', | |||
| align_corners=False)[:, 0] | |||
| # semantic inference | |||
| cls_score = F.softmax(mask_cls_result, dim=-1)[..., :-1] | |||
| mask_pred = mask_pred_result.sigmoid() | |||
| seg_mask = torch.einsum('qc,qhw->chw', cls_score, mask_pred) | |||
| # still need softmax and argmax | |||
| seg_logit = F.softmax(seg_mask, dim=0) | |||
| seg_pred = seg_logit.argmax(dim=0) | |||
| seg_pred = seg_pred.cpu().numpy() | |||
| results.append(seg_pred) | |||
| return results | |||
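The einsum in simple_test is the MaskFormer-style semantic inference: per-query class probabilities (q, c) are combined with per-query mask probabilities (q, h, w) into per-class maps (c, h, w), which are then argmax-ed into a label map. A toy shape check with made-up sizes:

    import torch
    import torch.nn.functional as F

    num_queries, num_classes, h, w = 100, 150, 4, 4
    mask_cls_result = torch.randn(num_queries, num_classes + 1)   # last channel is 'no object'
    mask_pred_result = torch.randn(num_queries, h, w)

    cls_score = F.softmax(mask_cls_result, dim=-1)[..., :-1]      # (q, c)
    mask_pred = mask_pred_result.sigmoid()                        # (q, h, w)
    seg_mask = torch.einsum('qc,qhw->chw', cls_score, mask_pred)  # (c, h, w)
    seg_pred = F.softmax(seg_mask, dim=0).argmax(dim=0)           # (h, w) label map
    assert seg_pred.shape == (h, w)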
| @@ -0,0 +1,76 @@ | |||
| import os.path as osp | |||
| import numpy as np | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base.base_torch_model import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.models.cv.image_semantic_segmentation import (pan_merge, | |||
| vit_adapter) | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| @MODELS.register_module( | |||
| Tasks.image_segmentation, module_name=Models.swinL_semantic_segmentation) | |||
| @MODELS.register_module( | |||
| Tasks.image_segmentation, | |||
| module_name=Models.vitadapter_semantic_segmentation) | |||
| class SemanticSegmentation(TorchModel): | |||
| def __init__(self, model_dir: str, **kwargs): | |||
| """str -- model file root.""" | |||
| super().__init__(model_dir, **kwargs) | |||
| from mmcv.runner import load_checkpoint | |||
| import mmcv | |||
| from mmdet.models import build_detector | |||
| config = osp.join(model_dir, 'mmcv_config.py') | |||
| cfg = mmcv.Config.fromfile(config) | |||
| if 'pretrained' in cfg.model: | |||
| cfg.model.pretrained = None | |||
| elif 'init_cfg' in cfg.model.backbone: | |||
| cfg.model.backbone.init_cfg = None | |||
| # build model | |||
| cfg.model.train_cfg = None | |||
| self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) | |||
| # load model | |||
| model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE) | |||
| _ = load_checkpoint(self.model, model_path, map_location='cpu') | |||
| self.CLASSES = cfg['CLASSES'] # list | |||
| self.PALETTE = cfg['PALETTE'] # list | |||
| self.num_classes = len(self.CLASSES) | |||
| self.cfg = cfg | |||
| def forward(self, Inputs): | |||
| return self.model(**Inputs) | |||
| def postprocess(self, Inputs): | |||
| semantic_result = Inputs[0] | |||
| ids = np.unique(semantic_result)[::-1] | |||
| legal_indices = ids != self.model.num_classes # for VOID label | |||
| ids = ids[legal_indices] | |||
| segms = (semantic_result[None] == ids[:, None, None]) | |||
| masks = [it.astype(np.int64) for it in segms] # np.int is deprecated in recent numpy | |||
| labels_txt = np.array(self.CLASSES)[ids].tolist() | |||
| results = { | |||
| OutputKeys.MASKS: masks, | |||
| OutputKeys.LABELS: labels_txt, | |||
| OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))] | |||
| } | |||
| return results | |||
| def inference(self, data): | |||
| with torch.no_grad(): | |||
| results = self.model(return_loss=False, rescale=True, **data) | |||
| return results | |||
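postprocess turns a single (H, W) label map into the ModelScope output dict by splitting it into one binary mask per predicted class id. The same logic on a synthetic label map (class names and values are made up for illustration):

    import numpy as np

    CLASSES = ['sky', 'road', 'tree']            # made-up class list
    semantic_result = np.array([[0, 0, 1],
                                [2, 2, 1],
                                [2, 2, 1]])      # fake (H, W) prediction

    ids = np.unique(semantic_result)[::-1]       # [2, 1, 0]; VOID filtering omitted here
    segms = semantic_result[None] == ids[:, None, None]   # (num_ids, H, W) boolean masks
    masks = [seg.astype(np.int64) for seg in segms]
    labels = np.array(CLASSES)[ids].tolist()     # ['tree', 'road', 'sky']
    print(labels, [m.sum() for m in masks])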
| @@ -0,0 +1,3 @@ | |||
| from .models import backbone, decode_heads, segmentors | |||
| from .utils import (ResizeToMultiple, add_prefix, build_pixel_sampler, | |||
| seg_resize) | |||
| @@ -0,0 +1,3 @@ | |||
| from .backbone import BASEBEiT, BEiTAdapter | |||
| from .decode_heads import Mask2FormerHeadFromMMSeg | |||
| from .segmentors import EncoderDecoderMask2Former | |||
| @@ -0,0 +1,4 @@ | |||
| from .base import BASEBEiT | |||
| from .beit_adapter import BEiTAdapter | |||
| __all__ = ['BEiTAdapter', 'BASEBEiT'] | |||
| @@ -0,0 +1,523 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| import logging | |||
| from functools import partial | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.utils.checkpoint as cp | |||
| from mmdet.models.utils.transformer import MultiScaleDeformableAttention | |||
| from timm.models.layers import DropPath | |||
| _logger = logging.getLogger(__name__) | |||
| def get_reference_points(spatial_shapes, device): | |||
| reference_points_list = [] | |||
| for lvl, (H_, W_) in enumerate(spatial_shapes): | |||
| ref_y, ref_x = torch.meshgrid( | |||
| torch.linspace( | |||
| 0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), | |||
| torch.linspace( | |||
| 0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) | |||
| ref_y = ref_y.reshape(-1)[None] / H_ | |||
| ref_x = ref_x.reshape(-1)[None] / W_ | |||
| ref = torch.stack((ref_x, ref_y), -1) | |||
| reference_points_list.append(ref) | |||
| reference_points = torch.cat(reference_points_list, 1) | |||
| reference_points = reference_points[:, :, None] | |||
| return reference_points | |||
| def deform_inputs(x): | |||
| bs, c, h, w = x.shape | |||
| spatial_shapes = torch.as_tensor([(h // 8, w // 8), (h // 16, w // 16), | |||
| (h // 32, w // 32)], | |||
| dtype=torch.long, | |||
| device=x.device) | |||
| level_start_index = torch.cat((spatial_shapes.new_zeros( | |||
| (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) | |||
| reference_points = get_reference_points([(h // 16, w // 16)], x.device) | |||
| deform_inputs1 = [reference_points, spatial_shapes, level_start_index] | |||
| spatial_shapes = torch.as_tensor([(h // 16, w // 16)], | |||
| dtype=torch.long, | |||
| device=x.device) | |||
| level_start_index = torch.cat((spatial_shapes.new_zeros( | |||
| (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) | |||
| reference_points = get_reference_points([(h // 8, w // 8), | |||
| (h // 16, w // 16), | |||
| (h // 32, w // 32)], x.device) | |||
| deform_inputs2 = [reference_points, spatial_shapes, level_start_index] | |||
| return deform_inputs1, deform_inputs2 | |||
| class ConvFFN(nn.Module): | |||
| def __init__(self, | |||
| in_features, | |||
| hidden_features=None, | |||
| out_features=None, | |||
| act_layer=nn.GELU, | |||
| drop=0.): | |||
| super().__init__() | |||
| out_features = out_features or in_features | |||
| hidden_features = hidden_features or in_features | |||
| self.fc1 = nn.Linear(in_features, hidden_features) | |||
| self.dwconv = DWConv(hidden_features) | |||
| self.act = act_layer() | |||
| self.fc2 = nn.Linear(hidden_features, out_features) | |||
| self.drop = nn.Dropout(drop) | |||
| def forward(self, x, H, W): | |||
| x = self.fc1(x) | |||
| x = self.dwconv(x, H, W) | |||
| x = self.act(x) | |||
| x = self.drop(x) | |||
| x = self.fc2(x) | |||
| x = self.drop(x) | |||
| return x | |||
| class DWConv(nn.Module): | |||
| def __init__(self, dim=768): | |||
| super().__init__() | |||
| self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) | |||
| def forward(self, x, H, W): | |||
| B, N, C = x.shape | |||
| n = N // 21 | |||
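| # N is the concatenation of the 1/8, 1/16 and 1/32 spatial-prior tokens, | |||
| # i.e. N = 16n + 4n + n = 21n with n = (H * W) / 4, so the three slices | |||
| # below recover the (2H, 2W), (H, W) and (H/2, W/2) maps for the dwconv. | |||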
| x1 = x[:, 0:16 * n, :].transpose(1, 2).view(B, C, H * 2, | |||
| W * 2).contiguous() | |||
| x2 = x[:, 16 * n:20 * n, :].transpose(1, 2).view(B, C, H, | |||
| W).contiguous() | |||
| x3 = x[:, 20 * n:, :].transpose(1, 2).view(B, C, H // 2, | |||
| W // 2).contiguous() | |||
| x1 = self.dwconv(x1).flatten(2).transpose(1, 2) | |||
| x2 = self.dwconv(x2).flatten(2).transpose(1, 2) | |||
| x3 = self.dwconv(x3).flatten(2).transpose(1, 2) | |||
| x = torch.cat([x1, x2, x3], dim=1) | |||
| return x | |||
| class Extractor(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads=6, | |||
| n_points=4, | |||
| n_levels=1, | |||
| deform_ratio=1.0, | |||
| with_cffn=True, | |||
| cffn_ratio=0.25, | |||
| drop=0., | |||
| drop_path=0., | |||
| norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
| with_cp=False): | |||
| super().__init__() | |||
| self.query_norm = norm_layer(dim) | |||
| self.feat_norm = norm_layer(dim) | |||
| self.attn = MultiScaleDeformableAttention( | |||
| embed_dims=dim, | |||
| num_heads=num_heads, | |||
| num_levels=n_levels, | |||
| num_points=n_points, | |||
| batch_first=True) | |||
| # modify to fit the deform_ratio | |||
| value_proj_in_features = self.attn.value_proj.weight.shape[0] | |||
| value_proj_out_features = int(value_proj_in_features * deform_ratio) | |||
| self.attn.value_proj = nn.Linear(value_proj_in_features, | |||
| value_proj_out_features) | |||
| self.attn.output_proj = nn.Linear(value_proj_out_features, | |||
| value_proj_in_features) | |||
| self.with_cffn = with_cffn | |||
| self.with_cp = with_cp | |||
| if with_cffn: | |||
| self.ffn = ConvFFN( | |||
| in_features=dim, | |||
| hidden_features=int(dim * cffn_ratio), | |||
| drop=drop) | |||
| self.ffn_norm = norm_layer(dim) | |||
| self.drop_path = DropPath( | |||
| drop_path) if drop_path > 0. else nn.Identity() | |||
| def forward(self, query, reference_points, feat, spatial_shapes, | |||
| level_start_index, H, W): | |||
| def _inner_forward(query, feat): | |||
| attn = self.attn( | |||
| query=self.query_norm(query), | |||
| key=None, | |||
| value=self.feat_norm(feat), | |||
| identity=None, | |||
| query_pos=None, | |||
| key_padding_mask=None, | |||
| reference_points=reference_points, | |||
| spatial_shapes=spatial_shapes, | |||
| level_start_index=level_start_index) | |||
| query = query + attn | |||
| if self.with_cffn: | |||
| query = query + self.drop_path( | |||
| self.ffn(self.ffn_norm(query), H, W)) | |||
| return query | |||
| if self.with_cp and query.requires_grad: | |||
| query = cp.checkpoint(_inner_forward, query, feat) | |||
| else: | |||
| query = _inner_forward(query, feat) | |||
| return query | |||
| class Injector(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads=6, | |||
| n_points=4, | |||
| n_levels=1, | |||
| deform_ratio=1.0, | |||
| norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
| init_values=0., | |||
| with_cp=False): | |||
| super().__init__() | |||
| self.with_cp = with_cp | |||
| self.query_norm = norm_layer(dim) | |||
| self.feat_norm = norm_layer(dim) | |||
| self.attn = MultiScaleDeformableAttention( | |||
| embed_dims=dim, | |||
| num_heads=num_heads, | |||
| num_levels=n_levels, | |||
| num_points=n_points, | |||
| batch_first=True) | |||
| # modify to fit the deform_ratio | |||
| value_proj_in_features = self.attn.value_proj.weight.shape[0] | |||
| value_proj_out_features = int(value_proj_in_features * deform_ratio) | |||
| self.attn.value_proj = nn.Linear(value_proj_in_features, | |||
| value_proj_out_features) | |||
| self.attn.output_proj = nn.Linear(value_proj_out_features, | |||
| value_proj_in_features) | |||
| self.gamma = nn.Parameter( | |||
| init_values * torch.ones((dim)), requires_grad=True) | |||
| def forward(self, query, reference_points, feat, spatial_shapes, | |||
| level_start_index): | |||
| def _inner_forward(query, feat): | |||
| input_query = self.query_norm(query) | |||
| input_value = self.feat_norm(feat) | |||
| attn = self.attn( | |||
| query=input_query, | |||
| key=None, | |||
| value=input_value, | |||
| identity=None, | |||
| query_pos=None, | |||
| key_padding_mask=None, | |||
| reference_points=reference_points, | |||
| spatial_shapes=spatial_shapes, | |||
| level_start_index=level_start_index) | |||
| return query + self.gamma * attn | |||
| if self.with_cp and query.requires_grad: | |||
| query = cp.checkpoint(_inner_forward, query, feat) | |||
| else: | |||
| query = _inner_forward(query, feat) | |||
| return query | |||
| class InteractionBlock(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads=6, | |||
| n_points=4, | |||
| norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
| drop=0., | |||
| drop_path=0., | |||
| with_cffn=True, | |||
| cffn_ratio=0.25, | |||
| init_values=0., | |||
| deform_ratio=1.0, | |||
| extra_extractor=False, | |||
| with_cp=False): | |||
| super().__init__() | |||
| self.injector = Injector( | |||
| dim=dim, | |||
| n_levels=3, | |||
| num_heads=num_heads, | |||
| init_values=init_values, | |||
| n_points=n_points, | |||
| norm_layer=norm_layer, | |||
| deform_ratio=deform_ratio, | |||
| with_cp=with_cp) | |||
| self.extractor = Extractor( | |||
| dim=dim, | |||
| n_levels=1, | |||
| num_heads=num_heads, | |||
| n_points=n_points, | |||
| norm_layer=norm_layer, | |||
| deform_ratio=deform_ratio, | |||
| with_cffn=with_cffn, | |||
| cffn_ratio=cffn_ratio, | |||
| drop=drop, | |||
| drop_path=drop_path, | |||
| with_cp=with_cp) | |||
| if extra_extractor: | |||
| self.extra_extractors = nn.Sequential(*[ | |||
| Extractor( | |||
| dim=dim, | |||
| num_heads=num_heads, | |||
| n_points=n_points, | |||
| norm_layer=norm_layer, | |||
| with_cffn=with_cffn, | |||
| cffn_ratio=cffn_ratio, | |||
| deform_ratio=deform_ratio, | |||
| drop=drop, | |||
| drop_path=drop_path, | |||
| with_cp=with_cp) for _ in range(2) | |||
| ]) | |||
| else: | |||
| self.extra_extractors = None | |||
| def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H, W): | |||
| x = self.injector( | |||
| query=x, | |||
| reference_points=deform_inputs1[0], | |||
| feat=c, | |||
| spatial_shapes=deform_inputs1[1], | |||
| level_start_index=deform_inputs1[2]) | |||
| for idx, blk in enumerate(blocks): | |||
| x = blk(x, H, W) | |||
| c = self.extractor( | |||
| query=c, | |||
| reference_points=deform_inputs2[0], | |||
| feat=x, | |||
| spatial_shapes=deform_inputs2[1], | |||
| level_start_index=deform_inputs2[2], | |||
| H=H, | |||
| W=W) | |||
| if self.extra_extractors is not None: | |||
| for extractor in self.extra_extractors: | |||
| c = extractor( | |||
| query=c, | |||
| reference_points=deform_inputs2[0], | |||
| feat=x, | |||
| spatial_shapes=deform_inputs2[1], | |||
| level_start_index=deform_inputs2[2], | |||
| H=H, | |||
| W=W) | |||
| return x, c | |||
| class InteractionBlockWithCls(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads=6, | |||
| n_points=4, | |||
| norm_layer=partial(nn.LayerNorm, eps=1e-6), | |||
| drop=0., | |||
| drop_path=0., | |||
| with_cffn=True, | |||
| cffn_ratio=0.25, | |||
| init_values=0., | |||
| deform_ratio=1.0, | |||
| extra_extractor=False, | |||
| with_cp=False): | |||
| super().__init__() | |||
| self.injector = Injector( | |||
| dim=dim, | |||
| n_levels=3, | |||
| num_heads=num_heads, | |||
| init_values=init_values, | |||
| n_points=n_points, | |||
| norm_layer=norm_layer, | |||
| deform_ratio=deform_ratio, | |||
| with_cp=with_cp) | |||
| self.extractor = Extractor( | |||
| dim=dim, | |||
| n_levels=1, | |||
| num_heads=num_heads, | |||
| n_points=n_points, | |||
| norm_layer=norm_layer, | |||
| deform_ratio=deform_ratio, | |||
| with_cffn=with_cffn, | |||
| cffn_ratio=cffn_ratio, | |||
| drop=drop, | |||
| drop_path=drop_path, | |||
| with_cp=with_cp) | |||
| if extra_extractor: | |||
| self.extra_extractors = nn.Sequential(*[ | |||
| Extractor( | |||
| dim=dim, | |||
| num_heads=num_heads, | |||
| n_points=n_points, | |||
| norm_layer=norm_layer, | |||
| with_cffn=with_cffn, | |||
| cffn_ratio=cffn_ratio, | |||
| deform_ratio=deform_ratio, | |||
| drop=drop, | |||
| drop_path=drop_path, | |||
| with_cp=with_cp) for _ in range(2) | |||
| ]) | |||
| else: | |||
| self.extra_extractors = None | |||
| def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H, W): | |||
| x = self.injector( | |||
| query=x, | |||
| reference_points=deform_inputs1[0], | |||
| feat=c, | |||
| spatial_shapes=deform_inputs1[1], | |||
| level_start_index=deform_inputs1[2]) | |||
| x = torch.cat((cls, x), dim=1) | |||
| for idx, blk in enumerate(blocks): | |||
| x = blk(x, H, W) | |||
| cls, x = x[:, :1, ], x[:, 1:, ] | |||
| c = self.extractor( | |||
| query=c, | |||
| reference_points=deform_inputs2[0], | |||
| feat=x, | |||
| spatial_shapes=deform_inputs2[1], | |||
| level_start_index=deform_inputs2[2], | |||
| H=H, | |||
| W=W) | |||
| if self.extra_extractors is not None: | |||
| for extractor in self.extra_extractors: | |||
| c = extractor( | |||
| query=c, | |||
| reference_points=deform_inputs2[0], | |||
| feat=x, | |||
| spatial_shapes=deform_inputs2[1], | |||
| level_start_index=deform_inputs2[2], | |||
| H=H, | |||
| W=W) | |||
| return x, c, cls | |||
| class SpatialPriorModule(nn.Module): | |||
| def __init__(self, inplanes=64, embed_dim=384, with_cp=False): | |||
| super().__init__() | |||
| self.with_cp = with_cp | |||
| self.stem = nn.Sequential(*[ | |||
| nn.Conv2d( | |||
| 3, inplanes, kernel_size=3, stride=2, padding=1, bias=False), | |||
| nn.SyncBatchNorm(inplanes), | |||
| nn.ReLU(inplace=True), | |||
| nn.Conv2d( | |||
| inplanes, | |||
| inplanes, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False), | |||
| nn.SyncBatchNorm(inplanes), | |||
| nn.ReLU(inplace=True), | |||
| nn.Conv2d( | |||
| inplanes, | |||
| inplanes, | |||
| kernel_size=3, | |||
| stride=1, | |||
| padding=1, | |||
| bias=False), | |||
| nn.SyncBatchNorm(inplanes), | |||
| nn.ReLU(inplace=True), | |||
| nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | |||
| ]) | |||
| self.conv2 = nn.Sequential(*[ | |||
| nn.Conv2d( | |||
| inplanes, | |||
| 2 * inplanes, | |||
| kernel_size=3, | |||
| stride=2, | |||
| padding=1, | |||
| bias=False), | |||
| nn.SyncBatchNorm(2 * inplanes), | |||
| nn.ReLU(inplace=True) | |||
| ]) | |||
| self.conv3 = nn.Sequential(*[ | |||
| nn.Conv2d( | |||
| 2 * inplanes, | |||
| 4 * inplanes, | |||
| kernel_size=3, | |||
| stride=2, | |||
| padding=1, | |||
| bias=False), | |||
| nn.SyncBatchNorm(4 * inplanes), | |||
| nn.ReLU(inplace=True) | |||
| ]) | |||
| self.conv4 = nn.Sequential(*[ | |||
| nn.Conv2d( | |||
| 4 * inplanes, | |||
| 4 * inplanes, | |||
| kernel_size=3, | |||
| stride=2, | |||
| padding=1, | |||
| bias=False), | |||
| nn.SyncBatchNorm(4 * inplanes), | |||
| nn.ReLU(inplace=True) | |||
| ]) | |||
| self.fc1 = nn.Conv2d( | |||
| inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True) | |||
| self.fc2 = nn.Conv2d( | |||
| 2 * inplanes, | |||
| embed_dim, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| bias=True) | |||
| self.fc3 = nn.Conv2d( | |||
| 4 * inplanes, | |||
| embed_dim, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| bias=True) | |||
| self.fc4 = nn.Conv2d( | |||
| 4 * inplanes, | |||
| embed_dim, | |||
| kernel_size=1, | |||
| stride=1, | |||
| padding=0, | |||
| bias=True) | |||
| def forward(self, x): | |||
| def _inner_forward(x): | |||
| c1 = self.stem(x) | |||
| c2 = self.conv2(c1) | |||
| c3 = self.conv3(c2) | |||
| c4 = self.conv4(c3) | |||
| c1 = self.fc1(c1) | |||
| c2 = self.fc2(c2) | |||
| c3 = self.fc3(c3) | |||
| c4 = self.fc4(c4) | |||
| bs, dim, _, _ = c1.shape | |||
| c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s | |||
| c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s | |||
| c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s | |||
| return c1, c2, c3, c4 | |||
| if self.with_cp and x.requires_grad: | |||
| outs = cp.checkpoint(_inner_forward, x) | |||
| else: | |||
| outs = _inner_forward(x) | |||
| return outs | |||
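deform_inputs above builds two sets of deformable-attention metadata: deform_inputs1 lets the 1/16 ViT tokens attend to all three spatial-prior levels (used by the Injector, n_levels=3), while deform_inputs2 lets the concatenated 1/8 + 1/16 + 1/32 prior tokens attend to the single 1/16 level (used by the Extractor, n_levels=1). A quick shape check for a hypothetical 512x512 input (run with deform_inputs imported from this module):

    import torch

    x = torch.zeros(1, 3, 512, 512)
    d1, d2 = deform_inputs(x)

    # Injector side: 32*32 query points, three value levels
    assert d1[0].shape == (1, 32 * 32, 1, 2)               # reference points
    assert d1[1].tolist() == [[64, 64], [32, 32], [16, 16]]
    assert d1[2].tolist() == [0, 4096, 5120]               # level start indices

    # Extractor side: 64^2 + 32^2 + 16^2 query points, one value level
    assert d2[0].shape == (1, 64 * 64 + 32 * 32 + 16 * 16, 1, 2)
    assert d2[1].tolist() == [[32, 32]]
    assert d2[2].tolist() == [0]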
| @@ -0,0 +1,3 @@ | |||
| from .beit import BASEBEiT | |||
| __all__ = ['BASEBEiT'] | |||
| @@ -0,0 +1,476 @@ | |||
| # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) | |||
| # Github source: https://github.com/microsoft/unilm/tree/master/beit | |||
| # This implementation refers to | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| import math | |||
| from functools import partial | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torch.utils.checkpoint as cp | |||
| from mmcv.runner import _load_checkpoint | |||
| from mmdet.models.builder import BACKBONES | |||
| from mmdet.utils import get_root_logger | |||
| from timm.models.layers import drop_path, to_2tuple, trunc_normal_ | |||
| class DropPath(nn.Module): | |||
| """Drop paths (Stochastic Depth) per sample (when applied in main path of | |||
| residual blocks).""" | |||
| def __init__(self, drop_prob=None): | |||
| super(DropPath, self).__init__() | |||
| self.drop_prob = drop_prob | |||
| def forward(self, x): | |||
| return drop_path(x, self.drop_prob, self.training) | |||
| def extra_repr(self) -> str: | |||
| return 'p={}'.format(self.drop_prob) | |||
| class Mlp(nn.Module): | |||
| def __init__(self, | |||
| in_features, | |||
| hidden_features=None, | |||
| out_features=None, | |||
| act_layer=nn.GELU, | |||
| drop=0.): | |||
| super().__init__() | |||
| out_features = out_features or in_features | |||
| hidden_features = hidden_features or in_features | |||
| self.fc1 = nn.Linear(in_features, hidden_features) | |||
| self.act = act_layer() | |||
| self.fc2 = nn.Linear(hidden_features, out_features) | |||
| self.drop = nn.Dropout(drop) | |||
| def forward(self, x): | |||
| x = self.fc1(x) | |||
| x = self.act(x) | |||
| # dropout after the activation is omitted, following the original BERT implementation | |||
| x = self.fc2(x) | |||
| x = self.drop(x) | |||
| return x | |||
| class Attention(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads=8, | |||
| qkv_bias=False, | |||
| qk_scale=None, | |||
| attn_drop=0., | |||
| proj_drop=0., | |||
| window_size=None, | |||
| attn_head_dim=None): | |||
| super().__init__() | |||
| self.num_heads = num_heads | |||
| head_dim = dim // num_heads | |||
| if attn_head_dim is not None: | |||
| head_dim = attn_head_dim | |||
| all_head_dim = head_dim * self.num_heads | |||
| # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights | |||
| self.scale = qk_scale or head_dim**-0.5 | |||
| self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) | |||
| if qkv_bias: | |||
| self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) | |||
| self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) | |||
| else: | |||
| self.q_bias = None | |||
| self.v_bias = None | |||
| if window_size: | |||
| self.window_size = window_size | |||
| self.num_relative_distance = (2 * window_size[0] | |||
| - 1) * (2 * window_size[1] - 1) + 3 | |||
| self.relative_position_bias_table = nn.Parameter( | |||
| torch.zeros(self.num_relative_distance, | |||
| num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||
| # cls to token & token to cls & cls to cls | |||
| # get pair-wise relative position index for each token inside the window | |||
| coords_h = torch.arange(window_size[0]) | |||
| coords_w = torch.arange(window_size[1]) | |||
| coords = torch.stack(torch.meshgrid([coords_h, | |||
| coords_w])) # 2, Wh, Ww | |||
| coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
| relative_coords = coords_flatten[:, :, | |||
| None] - coords_flatten[:, | |||
| None, :] # 2, Wh*Ww, Wh*Ww | |||
| relative_coords = relative_coords.permute( | |||
| 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||
| relative_coords[:, :, | |||
| 0] += window_size[0] - 1 # shift to start from 0 | |||
| relative_coords[:, :, 1] += window_size[1] - 1 | |||
| relative_coords[:, :, 0] *= 2 * window_size[1] - 1 | |||
| relative_position_index = \ | |||
| torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) | |||
| relative_position_index[1:, 1:] = relative_coords.sum( | |||
| -1) # Wh*Ww, Wh*Ww | |||
| relative_position_index[0, 0:] = self.num_relative_distance - 3 | |||
| relative_position_index[0:, 0] = self.num_relative_distance - 2 | |||
| relative_position_index[0, 0] = self.num_relative_distance - 1 | |||
| self.register_buffer('relative_position_index', | |||
| relative_position_index) | |||
| else: | |||
| self.window_size = None | |||
| self.relative_position_bias_table = None | |||
| self.relative_position_index = None | |||
| self.attn_drop = nn.Dropout(attn_drop) | |||
| self.proj = nn.Linear(all_head_dim, dim) | |||
| self.proj_drop = nn.Dropout(proj_drop) | |||
| def forward(self, x, rel_pos_bias=None): | |||
| B, N, C = x.shape | |||
| qkv_bias = None | |||
| if self.q_bias is not None: | |||
| qkv_bias = torch.cat( | |||
| (self.q_bias, | |||
| torch.zeros_like(self.v_bias, | |||
| requires_grad=False), self.v_bias)) | |||
| qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) | |||
| qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) | |||
| q, k, v = qkv[0], qkv[1], qkv[ | |||
| 2] # make torchscript happy (cannot use tensor as tuple) | |||
| q = q * self.scale | |||
| attn = (q @ k.transpose(-2, -1)) | |||
| if self.relative_position_bias_table is not None: | |||
| relative_position_bias = \ | |||
| self.relative_position_bias_table[self.relative_position_index.view(-1)].view( | |||
| self.window_size[0] * self.window_size[1] + 1, | |||
| self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH | |||
| relative_position_bias = relative_position_bias.permute( | |||
| 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||
| attn = attn + relative_position_bias.unsqueeze(0) | |||
| if rel_pos_bias is not None: | |||
| attn = attn + rel_pos_bias | |||
| attn = attn.softmax(dim=-1) | |||
| attn = self.attn_drop(attn) | |||
| x = (attn @ v).transpose(1, 2).reshape(B, N, -1) | |||
| x = self.proj(x) | |||
| x = self.proj_drop(x) | |||
| return x | |||
| class Block(nn.Module): | |||
| def __init__(self, | |||
| dim, | |||
| num_heads, | |||
| mlp_ratio=4., | |||
| qkv_bias=False, | |||
| qk_scale=None, | |||
| drop=0., | |||
| attn_drop=0., | |||
| drop_path=0., | |||
| init_values=None, | |||
| act_layer=nn.GELU, | |||
| norm_layer=nn.LayerNorm, | |||
| window_size=None, | |||
| attn_head_dim=None, | |||
| with_cp=False): | |||
| super().__init__() | |||
| self.with_cp = with_cp | |||
| self.norm1 = norm_layer(dim) | |||
| self.attn = Attention( | |||
| dim, | |||
| num_heads=num_heads, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| attn_drop=attn_drop, | |||
| proj_drop=drop, | |||
| window_size=window_size, | |||
| attn_head_dim=attn_head_dim) | |||
| # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | |||
| self.drop_path = DropPath( | |||
| drop_path) if drop_path > 0. else nn.Identity() | |||
| self.norm2 = norm_layer(dim) | |||
| mlp_hidden_dim = int(dim * mlp_ratio) | |||
| self.mlp = Mlp( | |||
| in_features=dim, | |||
| hidden_features=mlp_hidden_dim, | |||
| act_layer=act_layer, | |||
| drop=drop) | |||
| if init_values is not None: | |||
| self.gamma_1 = nn.Parameter( | |||
| init_values * torch.ones((dim)), requires_grad=True) | |||
| self.gamma_2 = nn.Parameter( | |||
| init_values * torch.ones((dim)), requires_grad=True) | |||
| else: | |||
| self.gamma_1, self.gamma_2 = None, None | |||
| def forward(self, x, H, W, rel_pos_bias=None): | |||
| def _inner_forward(x): | |||
| if self.gamma_1 is None: | |||
| x = x + self.drop_path( | |||
| self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) | |||
| x = x + self.drop_path(self.mlp(self.norm2(x))) | |||
| else: | |||
| x = x + self.drop_path(self.gamma_1 * self.attn( | |||
| self.norm1(x), rel_pos_bias=rel_pos_bias)) | |||
| x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) | |||
| return x | |||
| if self.with_cp and x.requires_grad: | |||
| x = cp.checkpoint(_inner_forward, x) | |||
| else: | |||
| x = _inner_forward(x) | |||
| return x | |||
| class PatchEmbed(nn.Module): | |||
| """ Image to Patch Embedding | |||
| """ | |||
| def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): | |||
| super().__init__() | |||
| img_size = to_2tuple(img_size) | |||
| patch_size = to_2tuple(patch_size) | |||
| num_patches = (img_size[1] // patch_size[1]) * ( | |||
| img_size[0] // patch_size[0]) | |||
| self.patch_shape = (img_size[0] // patch_size[0], | |||
| img_size[1] // patch_size[1]) | |||
| self.img_size = img_size | |||
| self.patch_size = patch_size | |||
| self.num_patches = num_patches | |||
| self.proj = nn.Conv2d( | |||
| in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) | |||
| def forward(self, x, **kwargs): | |||
| B, C, H, W = x.shape | |||
| # FIXME look at relaxing size constraints | |||
| # assert H == self.img_size[0] and W == self.img_size[1], \ | |||
| # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." | |||
| x = self.proj(x) | |||
| Hp, Wp = x.shape[2], x.shape[3] | |||
| x = x.flatten(2).transpose(1, 2) | |||
| return x, Hp, Wp | |||
| class HybridEmbed(nn.Module): | |||
| """ CNN Feature Map Embedding | |||
| Extract feature map from CNN, flatten, project to embedding dim. | |||
| """ | |||
| def __init__(self, | |||
| backbone, | |||
| img_size=224, | |||
| feature_size=None, | |||
| in_chans=3, | |||
| embed_dim=768): | |||
| super().__init__() | |||
| assert isinstance(backbone, nn.Module) | |||
| img_size = to_2tuple(img_size) | |||
| self.img_size = img_size | |||
| self.backbone = backbone | |||
| if feature_size is None: | |||
| with torch.no_grad(): | |||
| # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature | |||
| # map for all networks, the feature metadata has reliable channel and stride info, but using | |||
| # stride to calc feature dim requires info about padding of each stage that isn't captured. | |||
| training = backbone.training | |||
| if training: | |||
| backbone.eval() | |||
| o = self.backbone( | |||
| torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] | |||
| feature_size = o.shape[-2:] | |||
| feature_dim = o.shape[1] | |||
| backbone.train(training) | |||
| else: | |||
| feature_size = to_2tuple(feature_size) | |||
| feature_dim = self.backbone.feature_info.channels()[-1] | |||
| self.num_patches = feature_size[0] * feature_size[1] | |||
| self.proj = nn.Linear(feature_dim, embed_dim) | |||
| def forward(self, x): | |||
| x = self.backbone(x)[-1] | |||
| x = x.flatten(2).transpose(1, 2) | |||
| x = self.proj(x) | |||
| return x | |||
| class RelativePositionBias(nn.Module): | |||
| def __init__(self, window_size, num_heads): | |||
| super().__init__() | |||
| self.window_size = window_size | |||
| self.num_relative_distance = (2 * window_size[0] | |||
| - 1) * (2 * window_size[1] - 1) + 3 | |||
| self.relative_position_bias_table = nn.Parameter( | |||
| torch.zeros(self.num_relative_distance, | |||
| num_heads)) # 2*Wh-1 * 2*Ww-1, nH | |||
| # cls to token & token to cls & cls to cls | |||
| # get pair-wise relative position index for each token inside the window | |||
| coords_h = torch.arange(window_size[0]) | |||
| coords_w = torch.arange(window_size[1]) | |||
| coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww | |||
| coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww | |||
| relative_coords = coords_flatten[:, :, | |||
| None] - coords_flatten[:, | |||
| None, :] # 2, Wh*Ww, Wh*Ww | |||
| relative_coords = relative_coords.permute( | |||
| 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 | |||
| relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 | |||
| relative_coords[:, :, 1] += window_size[1] - 1 | |||
| relative_coords[:, :, 0] *= 2 * window_size[1] - 1 | |||
| relative_position_index = \ | |||
| torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) | |||
| relative_position_index[1:, | |||
| 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww | |||
| relative_position_index[0, 0:] = self.num_relative_distance - 3 | |||
| relative_position_index[0:, 0] = self.num_relative_distance - 2 | |||
| relative_position_index[0, 0] = self.num_relative_distance - 1 | |||
| self.register_buffer('relative_position_index', | |||
| relative_position_index) | |||
| def forward(self): | |||
| relative_position_bias = \ | |||
| self.relative_position_bias_table[self.relative_position_index.view(-1)].view( | |||
| self.window_size[0] * self.window_size[1] + 1, | |||
| self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH | |||
| return relative_position_bias.permute( | |||
| 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww | |||
| @BACKBONES.register_module() | |||
| class BASEBEiT(nn.Module): | |||
| """ Vision Transformer with support for patch or hybrid CNN input stage | |||
| """ | |||
| def __init__(self, | |||
| img_size=512, | |||
| patch_size=16, | |||
| in_chans=3, | |||
| num_classes=80, | |||
| embed_dim=768, | |||
| depth=12, | |||
| num_heads=12, | |||
| mlp_ratio=4., | |||
| qkv_bias=False, | |||
| qk_scale=None, | |||
| drop_rate=0., | |||
| attn_drop_rate=0., | |||
| drop_path_rate=0., | |||
| hybrid_backbone=None, | |||
| norm_layer=None, | |||
| init_values=None, | |||
| use_checkpoint=False, | |||
| use_abs_pos_emb=False, | |||
| use_rel_pos_bias=True, | |||
| use_shared_rel_pos_bias=False, | |||
| pretrained=None, | |||
| with_cp=False): | |||
| super().__init__() | |||
| norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) | |||
| self.norm_layer = norm_layer | |||
| self.num_classes = num_classes | |||
| self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models | |||
| self.drop_path_rate = drop_path_rate | |||
| if hybrid_backbone is not None: | |||
| self.patch_embed = HybridEmbed( | |||
| hybrid_backbone, | |||
| img_size=img_size, | |||
| in_chans=in_chans, | |||
| embed_dim=embed_dim) | |||
| else: | |||
| self.patch_embed = PatchEmbed( | |||
| img_size=img_size, | |||
| patch_size=patch_size, | |||
| in_chans=in_chans, | |||
| embed_dim=embed_dim) | |||
| num_patches = self.patch_embed.num_patches | |||
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) | |||
| if use_abs_pos_emb: | |||
| self.pos_embed = nn.Parameter( | |||
| torch.zeros(1, num_patches + 1, embed_dim)) | |||
| else: | |||
| self.pos_embed = None | |||
| self.pos_drop = nn.Dropout(p=drop_rate) | |||
| if use_shared_rel_pos_bias: | |||
| self.rel_pos_bias = RelativePositionBias( | |||
| window_size=self.patch_embed.patch_shape, num_heads=num_heads) | |||
| else: | |||
| self.rel_pos_bias = None | |||
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) | |||
| ] # stochastic depth decay rule | |||
| self.use_rel_pos_bias = use_rel_pos_bias | |||
| self.use_checkpoint = use_checkpoint | |||
| self.blocks = nn.ModuleList([ | |||
| Block( | |||
| dim=embed_dim, | |||
| num_heads=num_heads, | |||
| mlp_ratio=mlp_ratio, | |||
| qkv_bias=qkv_bias, | |||
| qk_scale=qk_scale, | |||
| drop=drop_rate, | |||
| attn_drop=attn_drop_rate, | |||
| drop_path=dpr[i], | |||
| norm_layer=norm_layer, | |||
| with_cp=with_cp, | |||
| init_values=init_values, | |||
| window_size=self.patch_embed.patch_shape | |||
| if use_rel_pos_bias else None) for i in range(depth) | |||
| ]) | |||
| trunc_normal_(self.cls_token, std=.02) | |||
| self.apply(self._init_weights) | |||
| self.init_weights(pretrained) | |||
| def init_weights(self, pretrained=None): | |||
| """Initialize the weights in backbone. | |||
| Args: | |||
| pretrained (str, optional): Path to pre-trained weights. | |||
| Defaults to None. | |||
| """ | |||
| if isinstance(pretrained, str): | |||
| logger = get_root_logger() | |||
| init_cfg = dict(type='Pretrained', checkpoint=pretrained) | |||
| checkpoint = _load_checkpoint( | |||
| init_cfg['checkpoint'], logger=logger, map_location='cpu') | |||
| state_dict = self.resize_rel_pos_embed(checkpoint) | |||
| self.load_state_dict(state_dict, False) | |||
| def fix_init_weight(self): | |||
| def rescale(param, layer_id): | |||
| param.div_(math.sqrt(2.0 * layer_id)) | |||
| for layer_id, layer in enumerate(self.blocks): | |||
| rescale(layer.attn.proj.weight.data, layer_id + 1) | |||
| rescale(layer.mlp.fc2.weight.data, layer_id + 1) | |||
| def _init_weights(self, m): | |||
| if isinstance(m, nn.Linear): | |||
| trunc_normal_(m.weight, std=.02) | |||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||
| nn.init.constant_(m.bias, 0) | |||
| elif isinstance(m, nn.LayerNorm): | |||
| nn.init.constant_(m.bias, 0) | |||
| nn.init.constant_(m.weight, 1.0) | |||
| def get_num_layers(self): | |||
| return len(self.blocks) | |||
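The relative-position machinery in Attention and RelativePositionBias follows the BEiT convention: (2*Wh - 1) * (2*Ww - 1) pairwise offsets plus three extra entries for cls-to-token, token-to-cls and cls-to-cls. A small numeric check for a hypothetical 32x32 patch grid (512 input, patch_size 16):

    Wh = Ww = 32
    num_relative_distance = (2 * Wh - 1) * (2 * Ww - 1) + 3
    print(num_relative_distance)   # 3972 rows in relative_position_bias_table
    # the bias returned by RelativePositionBias.forward() has shape
    # (num_heads, Wh*Ww + 1, Wh*Ww + 1), i.e. it also covers the cls token.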
| @@ -0,0 +1,169 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| import logging | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from mmdet.models.builder import BACKBONES | |||
| from mmdet.models.utils.transformer import MultiScaleDeformableAttention | |||
| from timm.models.layers import DropPath, trunc_normal_ | |||
| from torch.nn.init import normal_ | |||
| from .adapter_modules import InteractionBlockWithCls as InteractionBlock | |||
| from .adapter_modules import SpatialPriorModule, deform_inputs | |||
| from .base.beit import BASEBEiT | |||
| _logger = logging.getLogger(__name__) | |||
| @BACKBONES.register_module() | |||
| class BEiTAdapter(BASEBEiT): | |||
| def __init__(self, | |||
| pretrain_size=224, | |||
| conv_inplane=64, | |||
| n_points=4, | |||
| deform_num_heads=6, | |||
| init_values=0., | |||
| cffn_ratio=0.25, | |||
| deform_ratio=1.0, | |||
| with_cffn=True, | |||
| interaction_indexes=None, | |||
| add_vit_feature=True, | |||
| with_cp=False, | |||
| *args, | |||
| **kwargs): | |||
| super().__init__( | |||
| init_values=init_values, with_cp=with_cp, *args, **kwargs) | |||
| self.num_block = len(self.blocks) | |||
| self.pretrain_size = (pretrain_size, pretrain_size) | |||
| self.flags = [ | |||
| i for i in range(-1, self.num_block, self.num_block // 4) | |||
| ][1:] | |||
| self.interaction_indexes = interaction_indexes | |||
| self.add_vit_feature = add_vit_feature | |||
| embed_dim = self.embed_dim | |||
| self.level_embed = nn.Parameter(torch.zeros(3, embed_dim)) | |||
| self.spm = SpatialPriorModule( | |||
| inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False) | |||
| self.interactions = nn.Sequential(*[ | |||
| InteractionBlock( | |||
| dim=embed_dim, | |||
| num_heads=deform_num_heads, | |||
| n_points=n_points, | |||
| init_values=init_values, | |||
| drop_path=self.drop_path_rate, | |||
| norm_layer=self.norm_layer, | |||
| with_cffn=with_cffn, | |||
| cffn_ratio=cffn_ratio, | |||
| deform_ratio=deform_ratio, | |||
| extra_extractor=True if i == len(interaction_indexes) | |||
| - 1 else False, | |||
| with_cp=with_cp) for i in range(len(interaction_indexes)) | |||
| ]) | |||
| self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2) | |||
| self.norm1 = nn.SyncBatchNorm(embed_dim) | |||
| self.norm2 = nn.SyncBatchNorm(embed_dim) | |||
| self.norm3 = nn.SyncBatchNorm(embed_dim) | |||
| self.norm4 = nn.SyncBatchNorm(embed_dim) | |||
| self.up.apply(self._init_weights) | |||
| self.spm.apply(self._init_weights) | |||
| self.interactions.apply(self._init_weights) | |||
| self.apply(self._init_deform_weights) | |||
| normal_(self.level_embed) | |||
| def _init_weights(self, m): | |||
| if isinstance(m, nn.Linear): | |||
| trunc_normal_(m.weight, std=.02) | |||
| if isinstance(m, nn.Linear) and m.bias is not None: | |||
| nn.init.constant_(m.bias, 0) | |||
| elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d): | |||
| nn.init.constant_(m.bias, 0) | |||
| nn.init.constant_(m.weight, 1.0) | |||
| elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): | |||
| fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |||
| fan_out //= m.groups | |||
| m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) | |||
| if m.bias is not None: | |||
| m.bias.data.zero_() | |||
| def _get_pos_embed(self, pos_embed, H, W): | |||
| pos_embed = pos_embed.reshape(1, self.pretrain_size[0] // 16, | |||
| self.pretrain_size[1] // 16, | |||
| -1).permute(0, 3, 1, 2) | |||
| pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False). \ | |||
| reshape(1, -1, H * W).permute(0, 2, 1) | |||
| return pos_embed | |||
| def _init_deform_weights(self, m): | |||
| if isinstance(m, MultiScaleDeformableAttention): | |||
| m.init_weights() | |||
| def _add_level_embed(self, c2, c3, c4): | |||
| c2 = c2 + self.level_embed[0] | |||
| c3 = c3 + self.level_embed[1] | |||
| c4 = c4 + self.level_embed[2] | |||
| return c2, c3, c4 | |||
| def forward(self, x): | |||
| deform_inputs1, deform_inputs2 = deform_inputs(x) | |||
| # SPM forward | |||
| c1, c2, c3, c4 = self.spm(x) | |||
| c2, c3, c4 = self._add_level_embed(c2, c3, c4) | |||
| c = torch.cat([c2, c3, c4], dim=1) | |||
| # Patch Embedding forward | |||
| x, H, W = self.patch_embed(x) | |||
| bs, n, dim = x.shape | |||
| cls = self.cls_token.expand( | |||
| bs, -1, -1) # stole cls_tokens impl from Phil Wang, thanks | |||
| if self.pos_embed is not None: | |||
| pos_embed = self._get_pos_embed(self.pos_embed, H, W) | |||
| x = x + pos_embed | |||
| x = self.pos_drop(x) | |||
| # Interaction | |||
| outs = list() | |||
| for i, layer in enumerate(self.interactions): | |||
| indexes = self.interaction_indexes[i] | |||
| x, c, cls = layer(x, c, cls, | |||
| self.blocks[indexes[0]:indexes[-1] + 1], | |||
| deform_inputs1, deform_inputs2, H, W) | |||
| outs.append(x.transpose(1, 2).view(bs, dim, H, W).contiguous()) | |||
| # Split & Reshape | |||
| c2 = c[:, 0:c2.size(1), :] | |||
| c3 = c[:, c2.size(1):c2.size(1) + c3.size(1), :] | |||
| c4 = c[:, c2.size(1) + c3.size(1):, :] | |||
| c2 = c2.transpose(1, 2).view(bs, dim, H * 2, W * 2).contiguous() | |||
| c3 = c3.transpose(1, 2).view(bs, dim, H, W).contiguous() | |||
| c4 = c4.transpose(1, 2).view(bs, dim, H // 2, W // 2).contiguous() | |||
| c1 = self.up(c2) + c1 | |||
| if self.add_vit_feature: | |||
| x1, x2, x3, x4 = outs | |||
| x1 = F.interpolate( | |||
| x1, scale_factor=4, mode='bilinear', align_corners=False) | |||
| x2 = F.interpolate( | |||
| x2, scale_factor=2, mode='bilinear', align_corners=False) | |||
| x4 = F.interpolate( | |||
| x4, scale_factor=0.5, mode='bilinear', align_corners=False) | |||
| c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4 | |||
| # Final Norm | |||
| f1 = self.norm1(c1) | |||
| f2 = self.norm2(c2) | |||
| f3 = self.norm3(c3) | |||
| f4 = self.norm4(c4) | |||
| return [f1, f2, f3, f4] | |||
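BEiTAdapter is instantiated through mmdet's BACKBONES registry from the model's mmcv_config.py. A hypothetical config fragment in the usual ViT-Adapter style; every value here is illustrative and not taken from this change:

    # hypothetical backbone block of an mmcv config; real values live in mmcv_config.py
    model = dict(
        backbone=dict(
            type='BEiTAdapter',
            img_size=512,
            patch_size=16,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            use_rel_pos_bias=True,
            init_values=1e-6,
            conv_inplane=64,
            deform_num_heads=16,
            cffn_ratio=0.25,
            deform_ratio=0.5,
            interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]]))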
| @@ -0,0 +1,3 @@ | |||
| from .mask2former_head_from_mmseg import Mask2FormerHeadFromMMSeg | |||
| __all__ = ['Mask2FormerHeadFromMMSeg'] | |||
| @@ -0,0 +1,267 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| from abc import ABCMeta, abstractmethod | |||
| import torch | |||
| import torch.nn as nn | |||
| from mmcv.runner import BaseModule, auto_fp16, force_fp32 | |||
| from mmdet.models.builder import build_loss | |||
| from mmdet.models.losses import accuracy | |||
| from ...utils import build_pixel_sampler, seg_resize | |||
| class BaseDecodeHead(BaseModule, metaclass=ABCMeta): | |||
| """Base class for BaseDecodeHead. | |||
| Args: | |||
| in_channels (int|Sequence[int]): Input channels. | |||
| channels (int): Channels after modules, before conv_seg. | |||
| num_classes (int): Number of classes. | |||
| dropout_ratio (float): Ratio of dropout layer. Default: 0.1. | |||
| conv_cfg (dict|None): Config of conv layers. Default: None. | |||
| norm_cfg (dict|None): Config of norm layers. Default: None. | |||
| act_cfg (dict): Config of activation layers. | |||
| Default: dict(type='ReLU') | |||
| in_index (int|Sequence[int]): Input feature index. Default: -1 | |||
| input_transform (str|None): Transformation type of input features. | |||
| Options: 'resize_concat', 'multiple_select', None. | |||
| 'resize_concat': Multiple feature maps will be resized to the | |||
| same size as the first one and then concatenated together. | |||
| Usually used in FCN head of HRNet. | |||
| 'multiple_select': Multiple feature maps will be bundled into | |||
| a list and passed into the decode head. | |||
| None: Only a single selected feature map is allowed. | |||
| Default: None. | |||
| loss_decode (dict | Sequence[dict]): Config of decode loss. | |||
| The `loss_name` is a property of the corresponding loss function, | |||
| shown in the training log. If you want a loss item to be | |||
| included in the backward graph, its name must start with the | |||
| `loss_` prefix. Defaults to 'loss_ce'. | |||
| e.g. dict(type='CrossEntropyLoss'), | |||
| [dict(type='CrossEntropyLoss', loss_name='loss_ce'), | |||
| dict(type='DiceLoss', loss_name='loss_dice')] | |||
| Default: dict(type='CrossEntropyLoss'). | |||
| ignore_index (int | None): The label index to be ignored. When using | |||
| masked BCE loss, ignore_index should be set to None. Default: 255. | |||
| sampler (dict|None): The config of segmentation map sampler. | |||
| Default: None. | |||
| align_corners (bool): align_corners argument of F.interpolate. | |||
| Default: False. | |||
| init_cfg (dict or list[dict], optional): Initialization config dict. | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| channels, | |||
| *, | |||
| num_classes, | |||
| dropout_ratio=0.1, | |||
| conv_cfg=None, | |||
| norm_cfg=None, | |||
| act_cfg=dict(type='ReLU'), | |||
| in_index=-1, | |||
| input_transform=None, | |||
| loss_decode=dict( | |||
| type='CrossEntropyLoss', | |||
| use_sigmoid=False, | |||
| loss_weight=1.0), | |||
| ignore_index=255, | |||
| sampler=None, | |||
| align_corners=False, | |||
| init_cfg=dict( | |||
| type='Normal', std=0.01, override=dict(name='conv_seg'))): | |||
| super(BaseDecodeHead, self).__init__(init_cfg) | |||
| self._init_inputs(in_channels, in_index, input_transform) | |||
| self.channels = channels | |||
| self.num_classes = num_classes | |||
| self.dropout_ratio = dropout_ratio | |||
| self.conv_cfg = conv_cfg | |||
| self.norm_cfg = norm_cfg | |||
| self.act_cfg = act_cfg | |||
| self.in_index = in_index | |||
| self.ignore_index = ignore_index | |||
| self.align_corners = align_corners | |||
| if isinstance(loss_decode, dict): | |||
| self.loss_decode = build_loss(loss_decode) | |||
| elif isinstance(loss_decode, (list, tuple)): | |||
| self.loss_decode = nn.ModuleList() | |||
| for loss in loss_decode: | |||
| self.loss_decode.append(build_loss(loss)) | |||
| else: | |||
| raise TypeError(f'loss_decode must be a dict or sequence of dict,\ | |||
| but got {type(loss_decode)}') | |||
| if sampler is not None: | |||
| self.sampler = build_pixel_sampler(sampler, context=self) | |||
| else: | |||
| self.sampler = None | |||
| self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1) | |||
| if dropout_ratio > 0: | |||
| self.dropout = nn.Dropout2d(dropout_ratio) | |||
| else: | |||
| self.dropout = None | |||
| self.fp16_enabled = False | |||
| def extra_repr(self): | |||
| """Extra repr.""" | |||
| s = f'input_transform={self.input_transform}, ' \ | |||
| f'ignore_index={self.ignore_index}, ' \ | |||
| f'align_corners={self.align_corners}' | |||
| return s | |||
| def _init_inputs(self, in_channels, in_index, input_transform): | |||
| """Check and initialize input transforms. | |||
| The in_channels, in_index and input_transform must match. | |||
| Specifically, when input_transform is None, only a single feature map | |||
| will be selected, so in_channels and in_index must be of type int. | |||
| When input_transform is not None, in_channels and in_index must be | |||
| sequences (list or tuple) of the same length. | |||
| Args: | |||
| in_channels (int|Sequence[int]): Input channels. | |||
| in_index (int|Sequence[int]): Input feature index. | |||
| input_transform (str|None): Transformation type of input features. | |||
| Options: 'resize_concat', 'multiple_select', None. | |||
| 'resize_concat': Multiple feature maps will be resized to the | |||
| same size as the first one and then concatenated together. | |||
| Usually used in FCN head of HRNet. | |||
| 'multiple_select': Multiple feature maps will be bundled into | |||
| a list and passed into the decode head. | |||
| None: Only a single selected feature map is allowed. | |||
| """ | |||
| if input_transform is not None: | |||
| assert input_transform in ['resize_concat', 'multiple_select'] | |||
| self.input_transform = input_transform | |||
| self.in_index = in_index | |||
| if input_transform is not None: | |||
| assert isinstance(in_channels, (list, tuple)) | |||
| assert isinstance(in_index, (list, tuple)) | |||
| assert len(in_channels) == len(in_index) | |||
| if input_transform == 'resize_concat': | |||
| self.in_channels = sum(in_channels) | |||
| else: | |||
| self.in_channels = in_channels | |||
| else: | |||
| assert isinstance(in_channels, int) | |||
| assert isinstance(in_index, int) | |||
| self.in_channels = in_channels | |||
| def _transform_inputs(self, inputs): | |||
| """Transform inputs for decoder. | |||
| Args: | |||
| inputs (list[Tensor]): List of multi-level img features. | |||
| Returns: | |||
| Tensor: The transformed inputs | |||
| """ | |||
| if self.input_transform == 'resize_concat': | |||
| inputs = [inputs[i] for i in self.in_index] | |||
| upsampled_inputs = [ | |||
| seg_resize( | |||
| input=x, | |||
| size=inputs[0].shape[2:], | |||
| mode='bilinear', | |||
| align_corners=self.align_corners) for x in inputs | |||
| ] | |||
| inputs = torch.cat(upsampled_inputs, dim=1) | |||
| elif self.input_transform == 'multiple_select': | |||
| inputs = [inputs[i] for i in self.in_index] | |||
| else: | |||
| inputs = inputs[self.in_index] | |||
| return inputs | |||
| @auto_fp16() | |||
| @abstractmethod | |||
| def forward(self, inputs): | |||
| """Placeholder of forward function.""" | |||
| pass | |||
| def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg): | |||
| """Forward function for training. | |||
| Args: | |||
| inputs (list[Tensor]): List of multi-level img features. | |||
| img_metas (list[dict]): List of image info dict where each dict | |||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
| For details on the values of these keys see | |||
| `mmseg/datasets/pipelines/formatting.py:Collect`. | |||
| gt_semantic_seg (Tensor): Semantic segmentation masks | |||
| used if the architecture supports semantic segmentation task. | |||
| train_cfg (dict): The training config. | |||
| Returns: | |||
| dict[str, Tensor]: a dictionary of loss components | |||
| """ | |||
| seg_logits = self.forward(inputs) | |||
| losses = self.losses(seg_logits, gt_semantic_seg) | |||
| return losses | |||
| def forward_test(self, inputs, img_metas, test_cfg): | |||
| """Forward function for testing. | |||
| Args: | |||
| inputs (list[Tensor]): List of multi-level img features. | |||
| img_metas (list[dict]): List of image info dict where each dict | |||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
| For details on the values of these keys see | |||
| `mmseg/datasets/pipelines/formatting.py:Collect`. | |||
| test_cfg (dict): The testing config. | |||
| Returns: | |||
| Tensor: Output segmentation map. | |||
| """ | |||
| return self.forward(inputs) | |||
| def cls_seg(self, feat): | |||
| """Classify each pixel.""" | |||
| if self.dropout is not None: | |||
| feat = self.dropout(feat) | |||
| output = self.conv_seg(feat) | |||
| return output | |||
| @force_fp32(apply_to=('seg_logit', )) | |||
| def losses(self, seg_logit, seg_label): | |||
| """Compute segmentation loss.""" | |||
| loss = dict() | |||
| seg_logit = seg_resize( | |||
| input=seg_logit, | |||
| size=seg_label.shape[2:], | |||
| mode='bilinear', | |||
| align_corners=self.align_corners) | |||
| if self.sampler is not None: | |||
| seg_weight = self.sampler.sample(seg_logit, seg_label) | |||
| else: | |||
| seg_weight = None | |||
| seg_label = seg_label.squeeze(1) | |||
| if not isinstance(self.loss_decode, nn.ModuleList): | |||
| losses_decode = [self.loss_decode] | |||
| else: | |||
| losses_decode = self.loss_decode | |||
| for loss_decode in losses_decode: | |||
| if loss_decode.loss_name not in loss: | |||
| loss[loss_decode.loss_name] = loss_decode( | |||
| seg_logit, | |||
| seg_label, | |||
| weight=seg_weight, | |||
| ignore_index=self.ignore_index) | |||
| else: | |||
| loss[loss_decode.loss_name] += loss_decode( | |||
| seg_logit, | |||
| seg_label, | |||
| weight=seg_weight, | |||
| ignore_index=self.ignore_index) | |||
| loss['acc_seg'] = accuracy( | |||
| seg_logit, seg_label, ignore_index=self.ignore_index) | |||
| return loss | |||
| @@ -0,0 +1,581 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| import copy | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init | |||
| from mmcv.cnn.bricks.transformer import (build_positional_encoding, | |||
| build_transformer_layer_sequence) | |||
| from mmcv.ops import point_sample | |||
| from mmcv.runner import ModuleList, force_fp32 | |||
| from mmdet.core import build_assigner, build_sampler, multi_apply, reduce_mean | |||
| from mmdet.models.builder import HEADS, build_loss | |||
| from mmdet.models.utils import get_uncertain_point_coords_with_randomness | |||
| from .base_decode_head import BaseDecodeHead | |||
| @HEADS.register_module() | |||
| class Mask2FormerHeadFromMMSeg(BaseDecodeHead): | |||
| """Implements the Mask2Former head. | |||
| See `Masked-attention Mask Transformer for Universal Image | |||
| Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details. | |||
| Args: | |||
| in_channels (list[int]): Number of channels in the input feature map. | |||
| feat_channels (int): Number of channels for features. | |||
| out_channels (int): Number of channels for output. | |||
| num_things_classes (int): Number of things. | |||
| num_stuff_classes (int): Number of stuff. | |||
| num_queries (int): Number of queries in the Transformer decoder. | |||
| pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel | |||
| decoder. Defaults to None. | |||
| enforce_decoder_input_project (bool, optional): Whether to add | |||
| a layer to change the embed_dim of transformer encoder in | |||
| pixel decoder to the embed_dim of transformer decoder. | |||
| Defaults to False. | |||
| transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for | |||
| transformer decoder. Defaults to None. | |||
| positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for | |||
| transformer decoder position encoding. Defaults to None. | |||
| loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification | |||
| loss. Defaults to None. | |||
| loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss. | |||
| Defaults to None. | |||
| loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss. | |||
| Defaults to None. | |||
| train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of | |||
| Mask2Former head. | |||
| test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of | |||
| Mask2Former head. | |||
| init_cfg (dict or list[dict], optional): Initialization config dict. | |||
| Defaults to None. | |||
| """ | |||
| def __init__(self, | |||
| in_channels, | |||
| feat_channels, | |||
| out_channels, | |||
| num_things_classes=80, | |||
| num_stuff_classes=53, | |||
| num_queries=100, | |||
| num_transformer_feat_level=3, | |||
| pixel_decoder=None, | |||
| enforce_decoder_input_project=False, | |||
| transformer_decoder=None, | |||
| positional_encoding=None, | |||
| loss_cls=None, | |||
| loss_mask=None, | |||
| loss_dice=None, | |||
| train_cfg=None, | |||
| test_cfg=None, | |||
| init_cfg=None, | |||
| **kwargs): | |||
| super(Mask2FormerHeadFromMMSeg, self).__init__( | |||
| in_channels=in_channels, | |||
| channels=feat_channels, | |||
| num_classes=(num_things_classes + num_stuff_classes), | |||
| init_cfg=init_cfg, | |||
| input_transform='multiple_select', | |||
| **kwargs) | |||
| self.num_things_classes = num_things_classes | |||
| self.num_stuff_classes = num_stuff_classes | |||
| self.num_classes = self.num_things_classes + self.num_stuff_classes | |||
| self.num_queries = num_queries | |||
| self.num_transformer_feat_level = num_transformer_feat_level | |||
| self.num_heads = transformer_decoder.transformerlayers. \ | |||
| attn_cfgs.num_heads | |||
| self.num_transformer_decoder_layers = transformer_decoder.num_layers | |||
| assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level | |||
| pixel_decoder_ = copy.deepcopy(pixel_decoder) | |||
| pixel_decoder_.update( | |||
| in_channels=in_channels, | |||
| feat_channels=feat_channels, | |||
| out_channels=out_channels) | |||
| self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1] | |||
| self.transformer_decoder = build_transformer_layer_sequence( | |||
| transformer_decoder) | |||
| self.decoder_embed_dims = self.transformer_decoder.embed_dims | |||
| self.decoder_input_projs = ModuleList() | |||
| # from low resolution to high resolution | |||
| for _ in range(num_transformer_feat_level): | |||
| if (self.decoder_embed_dims != feat_channels | |||
| or enforce_decoder_input_project): | |||
| self.decoder_input_projs.append( | |||
| Conv2d( | |||
| feat_channels, self.decoder_embed_dims, kernel_size=1)) | |||
| else: | |||
| self.decoder_input_projs.append(nn.Identity()) | |||
| self.decoder_positional_encoding = build_positional_encoding( | |||
| positional_encoding) | |||
| self.query_embed = nn.Embedding(self.num_queries, feat_channels) | |||
| self.query_feat = nn.Embedding(self.num_queries, feat_channels) | |||
| # from low resolution to high resolution | |||
| self.level_embed = nn.Embedding(self.num_transformer_feat_level, | |||
| feat_channels) | |||
| self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1) | |||
| self.mask_embed = nn.Sequential( | |||
| nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), | |||
| nn.Linear(feat_channels, feat_channels), nn.ReLU(inplace=True), | |||
| nn.Linear(feat_channels, out_channels)) | |||
| self.conv_seg = None # conv_seg from the base class is not used by this head | |||
| self.test_cfg = test_cfg | |||
| self.train_cfg = train_cfg | |||
| if train_cfg: | |||
| self.assigner = build_assigner(self.train_cfg.assigner) | |||
| self.sampler = build_sampler(self.train_cfg.sampler, context=self) | |||
| self.num_points = self.train_cfg.get('num_points', 12544) | |||
| self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0) | |||
| self.importance_sample_ratio = self.train_cfg.get( | |||
| 'importance_sample_ratio', 0.75) | |||
| self.class_weight = loss_cls.class_weight | |||
| self.loss_cls = build_loss(loss_cls) | |||
| self.loss_mask = build_loss(loss_mask) | |||
| self.loss_dice = build_loss(loss_dice) | |||
| def init_weights(self): | |||
| for m in self.decoder_input_projs: | |||
| if isinstance(m, Conv2d): | |||
| caffe2_xavier_init(m, bias=0) | |||
| self.pixel_decoder.init_weights() | |||
| for p in self.transformer_decoder.parameters(): | |||
| if p.dim() > 1: | |||
| nn.init.xavier_normal_(p) | |||
| def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, | |||
| gt_masks_list, img_metas): | |||
| """Compute classification and mask targets for all images for a decoder | |||
| layer. | |||
| Args: | |||
| cls_scores_list (list[Tensor]): Mask score logits from a single | |||
| decoder layer for all images. Each with shape [num_queries, | |||
| cls_out_channels]. | |||
| mask_preds_list (list[Tensor]): Mask logits from a single decoder | |||
| layer for all images. Each with shape [num_queries, h, w]. | |||
| gt_labels_list (list[Tensor]): Ground truth class indices for all | |||
| images. Each with shape (n, ), where n is the sum of the number | |||
| of stuff types and the number of instances in an image. | |||
| gt_masks_list (list[Tensor]): Ground truth mask for each image, | |||
| each with shape (n, h, w). | |||
| img_metas (list[dict]): List of image meta information. | |||
| Returns: | |||
| tuple[list[Tensor]]: a tuple containing the following targets. | |||
| - labels_list (list[Tensor]): Labels of all images. | |||
| Each with shape [num_queries, ]. | |||
| - label_weights_list (list[Tensor]): Label weights of all | |||
| images.Each with shape [num_queries, ]. | |||
| - mask_targets_list (list[Tensor]): Mask targets of all images. | |||
| Each with shape [num_queries, h, w]. | |||
| - mask_weights_list (list[Tensor]): Mask weights of all images. | |||
| Each with shape [num_queries, ]. | |||
| - num_total_pos (int): Number of positive samples in all | |||
| images. | |||
| - num_total_neg (int): Number of negative samples in all | |||
| images. | |||
| """ | |||
| (labels_list, label_weights_list, mask_targets_list, mask_weights_list, | |||
| pos_inds_list, | |||
| neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list, | |||
| mask_preds_list, gt_labels_list, | |||
| gt_masks_list, img_metas) | |||
| num_total_pos = sum((inds.numel() for inds in pos_inds_list)) | |||
| num_total_neg = sum((inds.numel() for inds in neg_inds_list)) | |||
| return (labels_list, label_weights_list, mask_targets_list, | |||
| mask_weights_list, num_total_pos, num_total_neg) | |||
| def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, | |||
| img_metas): | |||
| """Compute classification and mask targets for one image. | |||
| Args: | |||
| cls_score (Tensor): Mask score logits from a single decoder layer | |||
| for one image. Shape (num_queries, cls_out_channels). | |||
| mask_pred (Tensor): Mask logits for a single decoder layer for one | |||
| image. Shape (num_queries, h, w). | |||
| gt_labels (Tensor): Ground truth class indices for one image with | |||
| shape (num_gts, ). | |||
| gt_masks (Tensor): Ground truth mask for each image, each with | |||
| shape (num_gts, h, w). | |||
| img_metas (dict): Image information. | |||
| Returns: | |||
| tuple[Tensor]: A tuple containing the following for one image. | |||
| - labels (Tensor): Labels of each image. \ | |||
| shape (num_queries, ). | |||
| - label_weights (Tensor): Label weights of each image. \ | |||
| shape (num_queries, ). | |||
| - mask_targets (Tensor): Mask targets of each image. \ | |||
| shape (num_queries, h, w). | |||
| - mask_weights (Tensor): Mask weights of each image. \ | |||
| shape (num_queries, ). | |||
| - pos_inds (Tensor): Sampled positive indices for each \ | |||
| image. | |||
| - neg_inds (Tensor): Sampled negative indices for each \ | |||
| image. | |||
| """ | |||
| # sample points | |||
| num_queries = cls_score.shape[0] | |||
| num_gts = gt_labels.shape[0] | |||
| point_coords = torch.rand((1, self.num_points, 2), | |||
| device=cls_score.device) | |||
| # shape (num_queries, num_points) | |||
| mask_points_pred = point_sample( | |||
| mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, | |||
| 1)).squeeze(1) | |||
| # shape (num_gts, num_points) | |||
| gt_points_masks = point_sample( | |||
| gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, | |||
| 1)).squeeze(1) | |||
| # assign and sample | |||
| assign_result = self.assigner.assign(cls_score, mask_points_pred, | |||
| gt_labels, gt_points_masks, | |||
| img_metas) | |||
| sampling_result = self.sampler.sample(assign_result, mask_pred, | |||
| gt_masks) | |||
| pos_inds = sampling_result.pos_inds | |||
| neg_inds = sampling_result.neg_inds | |||
| # label target | |||
| labels = gt_labels.new_full((self.num_queries, ), | |||
| self.num_classes, | |||
| dtype=torch.long) | |||
| labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] | |||
| label_weights = gt_labels.new_ones((self.num_queries, )) | |||
| # mask target | |||
| mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds] | |||
| mask_weights = mask_pred.new_zeros((self.num_queries, )) | |||
| mask_weights[pos_inds] = 1.0 | |||
| return (labels, label_weights, mask_targets, mask_weights, pos_inds, | |||
| neg_inds) | |||
| def loss_single(self, cls_scores, mask_preds, gt_labels_list, | |||
| gt_masks_list, img_metas): | |||
| """Loss function for outputs from a single decoder layer. | |||
| Args: | |||
| cls_scores (Tensor): Mask score logits from a single decoder layer | |||
| for all images. Shape (batch_size, num_queries, | |||
| cls_out_channels). Note `cls_out_channels` should include | |||
| background. | |||
| mask_preds (Tensor): Mask logits from a single decoder layer for all | |||
| images. Shape (batch_size, num_queries, h, w). | |||
| gt_labels_list (list[Tensor]): Ground truth class indices for each | |||
| image, each with shape (num_gts, ). | |||
| gt_masks_list (list[Tensor]): Ground truth mask for each image, | |||
| each with shape (num_gts, h, w). | |||
| img_metas (list[dict]): List of image meta information. | |||
| Returns: | |||
| tuple[Tensor]: Loss components for outputs from a single \ | |||
| decoder layer. | |||
| """ | |||
| num_imgs = cls_scores.size(0) | |||
| cls_scores_list = [cls_scores[i] for i in range(num_imgs)] | |||
| mask_preds_list = [mask_preds[i] for i in range(num_imgs)] | |||
| (labels_list, label_weights_list, mask_targets_list, mask_weights_list, | |||
| num_total_pos, | |||
| num_total_neg) = self.get_targets(cls_scores_list, mask_preds_list, | |||
| gt_labels_list, gt_masks_list, | |||
| img_metas) | |||
| # shape (batch_size, num_queries) | |||
| labels = torch.stack(labels_list, dim=0) | |||
| # shape (batch_size, num_queries) | |||
| label_weights = torch.stack(label_weights_list, dim=0) | |||
| # shape (num_total_gts, h, w) | |||
| mask_targets = torch.cat(mask_targets_list, dim=0) | |||
| # shape (batch_size, num_queries) | |||
| mask_weights = torch.stack(mask_weights_list, dim=0) | |||
| # classification loss | |||
| # shape (batch_size * num_queries, ) | |||
| cls_scores = cls_scores.flatten(0, 1) | |||
| labels = labels.flatten(0, 1) | |||
| label_weights = label_weights.flatten(0, 1) | |||
| class_weight = cls_scores.new_tensor(self.class_weight) | |||
| loss_cls = self.loss_cls( | |||
| cls_scores, | |||
| labels, | |||
| label_weights, | |||
| avg_factor=class_weight[labels].sum()) | |||
| num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos])) | |||
| num_total_masks = max(num_total_masks, 1) | |||
| # extract positive ones | |||
| # shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w) | |||
| mask_preds = mask_preds[mask_weights > 0] | |||
| if mask_targets.shape[0] == 0: | |||
| # zero match | |||
| loss_dice = mask_preds.sum() | |||
| loss_mask = mask_preds.sum() | |||
| return loss_cls, loss_mask, loss_dice | |||
| with torch.no_grad(): | |||
| points_coords = get_uncertain_point_coords_with_randomness( | |||
| mask_preds.unsqueeze(1), None, self.num_points, | |||
| self.oversample_ratio, self.importance_sample_ratio) | |||
| # shape (num_total_gts, h, w) -> (num_total_gts, num_points) | |||
| mask_point_targets = point_sample( | |||
| mask_targets.unsqueeze(1).float(), points_coords).squeeze(1) | |||
| # shape (num_queries, h, w) -> (num_queries, num_points) | |||
| mask_point_preds = point_sample( | |||
| mask_preds.unsqueeze(1), points_coords).squeeze(1) | |||
| # dice loss | |||
| loss_dice = self.loss_dice( | |||
| mask_point_preds, mask_point_targets, avg_factor=num_total_masks) | |||
| # mask loss | |||
| # shape (num_queries, num_points) -> (num_queries * num_points, ) | |||
| mask_point_preds = mask_point_preds.reshape(-1, 1) | |||
| # shape (num_total_gts, num_points) -> (num_total_gts * num_points, ) | |||
| mask_point_targets = mask_point_targets.reshape(-1) | |||
| loss_mask = self.loss_mask( | |||
| mask_point_preds, | |||
| mask_point_targets, | |||
| avg_factor=num_total_masks * self.num_points) | |||
| return loss_cls, loss_mask, loss_dice | |||
| @force_fp32(apply_to=('all_cls_scores', 'all_mask_preds')) | |||
| def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, | |||
| gt_masks_list, img_metas): | |||
| """Loss function. | |||
| Args: | |||
| all_cls_scores (Tensor): Classification scores for all decoder | |||
| layers with shape [num_decoder, batch_size, num_queries, | |||
| cls_out_channels]. | |||
| all_mask_preds (Tensor): Mask scores for all decoder layers with | |||
| shape [num_decoder, batch_size, num_queries, h, w]. | |||
| gt_labels_list (list[Tensor]): Ground truth class indices for each | |||
| image with shape (n, ), where n is the sum of the number of stuff | |||
| types and the number of instances in an image. | |||
| gt_masks_list (list[Tensor]): Ground truth mask for each image with | |||
| shape (n, h, w). | |||
| img_metas (list[dict]): List of image meta information. | |||
| Returns: | |||
| dict[str, Tensor]: A dictionary of loss components. | |||
| """ | |||
| num_dec_layers = len(all_cls_scores) | |||
| all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] | |||
| all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)] | |||
| img_metas_list = [img_metas for _ in range(num_dec_layers)] | |||
| losses_cls, losses_mask, losses_dice = multi_apply( | |||
| self.loss_single, all_cls_scores, all_mask_preds, | |||
| all_gt_labels_list, all_gt_masks_list, img_metas_list) | |||
| loss_dict = dict() | |||
| # loss from the last decoder layer | |||
| loss_dict['loss_cls'] = losses_cls[-1] | |||
| loss_dict['loss_mask'] = losses_mask[-1] | |||
| loss_dict['loss_dice'] = losses_dice[-1] | |||
| # loss from other decoder layers | |||
| num_dec_layer = 0 | |||
| for loss_cls_i, loss_mask_i, loss_dice_i in zip( | |||
| losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]): | |||
| loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i | |||
| loss_dict[f'd{num_dec_layer}.loss_mask'] = loss_mask_i | |||
| loss_dict[f'd{num_dec_layer}.loss_dice'] = loss_dice_i | |||
| num_dec_layer += 1 | |||
| return loss_dict | |||
| def forward_head(self, decoder_out, mask_feature, attn_mask_target_size): | |||
| """Forward for head part which is called after every decoder layer. | |||
| Args: | |||
| decoder_out (Tensor): in shape (num_queries, batch_size, c). | |||
| mask_feature (Tensor): in shape (batch_size, c, h, w). | |||
| attn_mask_target_size (tuple[int, int]): target attention | |||
| mask size. | |||
| Returns: | |||
| tuple: A tuple contain three elements. | |||
| - cls_pred (Tensor): Classification scores in shape \ | |||
| (batch_size, num_queries, cls_out_channels). \ | |||
| Note `cls_out_channels` should include background. | |||
| - mask_pred (Tensor): Mask scores in shape \ | |||
| (batch_size, num_queries, h, w). | |||
| - attn_mask (Tensor): Attention mask in shape \ | |||
| (batch_size * num_heads, num_queries, h, w). | |||
| """ | |||
| decoder_out = self.transformer_decoder.post_norm(decoder_out) | |||
| decoder_out = decoder_out.transpose(0, 1) | |||
| # shape (batch_size, num_queries, c) | |||
| cls_pred = self.cls_embed(decoder_out) | |||
| # shape (batch_size, num_queries, c) | |||
| mask_embed = self.mask_embed(decoder_out) | |||
| # shape (batch_size, num_queries, h, w) | |||
| mask_pred = torch.einsum('bqc,bchw->bqhw', mask_embed, mask_feature) | |||
| attn_mask = F.interpolate( | |||
| mask_pred, | |||
| attn_mask_target_size, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| # shape (batch_size, num_queries, h, w) -> | |||
| # (batch_size * num_heads, num_queries, h*w) | |||
| attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat( | |||
| (1, self.num_heads, 1, 1)).flatten(0, 1) | |||
| attn_mask = attn_mask.sigmoid() < 0.5 | |||
| attn_mask = attn_mask.detach() | |||
| return cls_pred, mask_pred, attn_mask | |||
| def forward(self, feats, img_metas): | |||
| """Forward function. | |||
| Args: | |||
| feats (list[Tensor]): Multi scale Features from the | |||
| upstream network, each is a 4D-tensor. | |||
| img_metas (list[dict]): List of image information. | |||
| Returns: | |||
| tuple: A tuple contains two elements. | |||
| - cls_pred_list (list[Tensor]): Classification logits \ | |||
| for each decoder layer. Each is a 3D-tensor with shape \ | |||
| (batch_size, num_queries, cls_out_channels). \ | |||
| Note `cls_out_channels` should include background. | |||
| - mask_pred_list (list[Tensor]): Mask logits for each \ | |||
| decoder layer. Each with shape (batch_size, num_queries, \ | |||
| h, w). | |||
| """ | |||
| batch_size = len(img_metas) | |||
| mask_features, multi_scale_memorys = self.pixel_decoder(feats) | |||
| # multi_scale_memorys (from low resolution to high resolution) | |||
| decoder_inputs = [] | |||
| decoder_positional_encodings = [] | |||
| for i in range(self.num_transformer_feat_level): | |||
| decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i]) | |||
| # shape (batch_size, c, h, w) -> (h*w, batch_size, c) | |||
| decoder_input = decoder_input.flatten(2).permute(2, 0, 1) | |||
| level_embed = self.level_embed.weight[i].view(1, 1, -1) | |||
| decoder_input = decoder_input + level_embed | |||
| # padding mask of shape (batch_size, h, w); no padding is used here | |||
| mask = decoder_input.new_zeros( | |||
| (batch_size, ) + multi_scale_memorys[i].shape[-2:], | |||
| dtype=torch.bool) | |||
| decoder_positional_encoding = self.decoder_positional_encoding( | |||
| mask) | |||
| decoder_positional_encoding = decoder_positional_encoding.flatten( | |||
| 2).permute(2, 0, 1) | |||
| decoder_inputs.append(decoder_input) | |||
| decoder_positional_encodings.append(decoder_positional_encoding) | |||
| # shape (num_queries, c) -> (num_queries, batch_size, c) | |||
| query_feat = self.query_feat.weight.unsqueeze(1).repeat( | |||
| (1, batch_size, 1)) | |||
| query_embed = self.query_embed.weight.unsqueeze(1).repeat( | |||
| (1, batch_size, 1)) | |||
| cls_pred_list = [] | |||
| mask_pred_list = [] | |||
| cls_pred, mask_pred, attn_mask = self.forward_head( | |||
| query_feat, mask_features, multi_scale_memorys[0].shape[-2:]) | |||
| cls_pred_list.append(cls_pred) | |||
| mask_pred_list.append(mask_pred) | |||
| for i in range(self.num_transformer_decoder_layers): | |||
| level_idx = i % self.num_transformer_feat_level | |||
| # if a mask is all True(all background), then set it all False. | |||
| attn_mask[torch.where( | |||
| attn_mask.sum(-1) == attn_mask.shape[-1])] = False | |||
| # cross_attn + self_attn | |||
| layer = self.transformer_decoder.layers[i] | |||
| attn_masks = [attn_mask, None] | |||
| query_feat = layer( | |||
| query=query_feat, | |||
| key=decoder_inputs[level_idx], | |||
| value=decoder_inputs[level_idx], | |||
| query_pos=query_embed, | |||
| key_pos=decoder_positional_encodings[level_idx], | |||
| attn_masks=attn_masks, | |||
| query_key_padding_mask=None, | |||
| # here we do not apply masking on padded region | |||
| key_padding_mask=None) | |||
| cls_pred, mask_pred, attn_mask = self.forward_head( | |||
| query_feat, mask_features, multi_scale_memorys[ | |||
| (i + 1) % self.num_transformer_feat_level].shape[-2:]) | |||
| cls_pred_list.append(cls_pred) | |||
| mask_pred_list.append(mask_pred) | |||
| return cls_pred_list, mask_pred_list | |||
| def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, | |||
| gt_masks): | |||
| """Forward function for training mode. | |||
| Args: | |||
| x (list[Tensor]): Multi-level features from the upstream network, | |||
| each is a 4D-tensor. | |||
| img_metas (list[Dict]): List of image information. | |||
| gt_semantic_seg (list[Tensor]): Each element is the ground truth | |||
| of semantic segmentation with the shape (N, H, W). | |||
| gt_labels (list[Tensor]): Each element is ground truth labels of | |||
| each box, shape (num_gts,). | |||
| gt_masks (list[BitmapMasks]): Each element is masks of instances | |||
| of an image, shape (num_gts, h, w). | |||
| Returns: | |||
| losses (dict[str, Tensor]): a dictionary of loss components | |||
| """ | |||
| # forward | |||
| all_cls_scores, all_mask_preds = self(x, img_metas) | |||
| # loss | |||
| losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, | |||
| img_metas) | |||
| return losses | |||
| def forward_test(self, inputs, img_metas, test_cfg): | |||
| """Test segment without test-time aumengtation. | |||
| Only the output of last decoder layers was used. | |||
| Args: | |||
| inputs (list[Tensor]): Multi-level features from the | |||
| upstream network, each is a 4D-tensor. | |||
| img_metas (list[dict]): List of image information. | |||
| test_cfg (dict): Testing config. | |||
| Returns: | |||
| seg_mask (Tensor): Predicted semantic segmentation logits. | |||
| """ | |||
| all_cls_scores, all_mask_preds = self(inputs, img_metas) | |||
| cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1] | |||
| ori_h, ori_w, _ = img_metas[0]['ori_shape'] | |||
| # semantic inference | |||
| cls_score = F.softmax(cls_score, dim=-1)[..., :-1] | |||
| mask_pred = mask_pred.sigmoid() | |||
| seg_mask = torch.einsum('bqc,bqhw->bchw', cls_score, mask_pred) | |||
| return seg_mask | |||
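The semantic-inference step in `forward_test` above reduces to a softmax over classes, a sigmoid over masks, and one einsum. A self-contained sketch with random tensors standing in for the decoder outputs (the shapes are assumptions chosen for illustration):

```python
import torch
import torch.nn.functional as F

batch, num_queries, num_classes, h, w = 1, 100, 171, 64, 64
cls_score = torch.randn(batch, num_queries, num_classes + 1)  # last bin = "no object"
mask_pred = torch.randn(batch, num_queries, h, w)

cls_prob = F.softmax(cls_score, dim=-1)[..., :-1]  # drop the "no object" column
mask_prob = mask_pred.sigmoid()
seg_logits = torch.einsum('bqc,bqhw->bchw', cls_prob, mask_prob)
print(seg_logits.shape)  # torch.Size([1, 171, 64, 64]); argmax over dim=1 gives the label map
```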
| @@ -0,0 +1,3 @@ | |||
| from .encoder_decoder_mask2former import EncoderDecoderMask2Former | |||
| __all__ = ['EncoderDecoderMask2Former'] | |||
| @@ -0,0 +1,314 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| import warnings | |||
| from abc import ABCMeta, abstractmethod | |||
| from collections import OrderedDict | |||
| import mmcv | |||
| import numpy as np | |||
| import torch | |||
| import torch.distributed as dist | |||
| from mmcv.runner import BaseModule, auto_fp16 | |||
| class BaseSegmentor(BaseModule, metaclass=ABCMeta): | |||
| """Base class for segmentors.""" | |||
| def __init__(self, init_cfg=None): | |||
| super(BaseSegmentor, self).__init__(init_cfg) | |||
| self.fp16_enabled = False | |||
| @property | |||
| def with_neck(self): | |||
| """bool: whether the segmentor has neck""" | |||
| return hasattr(self, 'neck') and self.neck is not None | |||
| @property | |||
| def with_auxiliary_head(self): | |||
| """bool: whether the segmentor has auxiliary head""" | |||
| return hasattr(self, | |||
| 'auxiliary_head') and self.auxiliary_head is not None | |||
| @property | |||
| def with_decode_head(self): | |||
| """bool: whether the segmentor has decode head""" | |||
| return hasattr(self, 'decode_head') and self.decode_head is not None | |||
| @abstractmethod | |||
| def extract_feat(self, imgs): | |||
| """Placeholder for extract features from images.""" | |||
| pass | |||
| @abstractmethod | |||
| def encode_decode(self, img, img_metas): | |||
| """Placeholder for encode images with backbone and decode into a | |||
| semantic segmentation map of the same size as input.""" | |||
| pass | |||
| @abstractmethod | |||
| def forward_train(self, imgs, img_metas, **kwargs): | |||
| """Placeholder for Forward function for training.""" | |||
| pass | |||
| @abstractmethod | |||
| def simple_test(self, img, img_meta, **kwargs): | |||
| """Placeholder for single image test.""" | |||
| pass | |||
| @abstractmethod | |||
| def aug_test(self, imgs, img_metas, **kwargs): | |||
| """Placeholder for augmentation test.""" | |||
| pass | |||
| def forward_test(self, imgs, img_metas, **kwargs): | |||
| """ | |||
| Args: | |||
| imgs (List[Tensor]): the outer list indicates test-time | |||
| augmentations and inner Tensor should have a shape NxCxHxW, | |||
| which contains all images in the batch. | |||
| img_metas (List[List[dict]]): the outer list indicates test-time | |||
| augs (multiscale, flip, etc.) and the inner list indicates | |||
| images in a batch. | |||
| """ | |||
| for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: | |||
| if not isinstance(var, list): | |||
| raise TypeError(f'{name} must be a list, but got ' | |||
| f'{type(var)}') | |||
| num_augs = len(imgs) | |||
| if num_augs != len(img_metas): | |||
| raise ValueError(f'num of augmentations ({len(imgs)}) != ' | |||
| f'num of image meta ({len(img_metas)})') | |||
| # all images in the same aug batch should have the same ori_shape | |||
| # and pad_shape | |||
| def tensor_to_tuple(input_tensor): | |||
| return tuple(input_tensor.cpu().numpy()) | |||
| for img_meta in img_metas: | |||
| ori_shapes = [_['ori_shape'] for _ in img_meta] | |||
| if isinstance(ori_shapes[0], torch.Tensor): | |||
| assert all( | |||
| tensor_to_tuple(shape) == tensor_to_tuple(ori_shapes[0]) | |||
| for shape in ori_shapes) | |||
| else: | |||
| assert all(shape == ori_shapes[0] for shape in ori_shapes) | |||
| img_shapes = [_['img_shape'] for _ in img_meta] | |||
| if isinstance(img_shapes[0], torch.Tensor): | |||
| assert all( | |||
| tensor_to_tuple(shape) == tensor_to_tuple(img_shapes[0]) | |||
| for shape in img_shapes) | |||
| else: | |||
| assert all(shape == img_shapes[0] for shape in img_shapes) | |||
| pad_shapes = [_['pad_shape'] for _ in img_meta] | |||
| if isinstance(pad_shapes[0], torch.Tensor): | |||
| assert all( | |||
| tensor_to_tuple(shape) == tensor_to_tuple(pad_shapes[0]) | |||
| for shape in pad_shapes) | |||
| else: | |||
| assert all(shape == pad_shapes[0] for shape in pad_shapes) | |||
| if num_augs == 1: | |||
| return self.simple_test(imgs[0], img_metas[0], **kwargs) | |||
| else: | |||
| return self.aug_test(imgs, img_metas, **kwargs) | |||
| @auto_fp16(apply_to=('img', )) | |||
| def forward(self, img, img_metas, return_loss=True, **kwargs): | |||
| """Calls either :func:`forward_train` or :func:`forward_test` depending | |||
| on whether ``return_loss`` is ``True``. | |||
| Note this setting will change the expected inputs. When | |||
| ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor | |||
| and List[dict]), and when ``return_loss=False``, img and img_meta | |||
| should be double nested (i.e. List[Tensor], List[List[dict]]), with | |||
| the outer list indicating test time augmentations. | |||
| """ | |||
| if return_loss: | |||
| return self.forward_train(img, img_metas, **kwargs) | |||
| else: | |||
| return self.forward_test(img, img_metas, **kwargs) | |||
| def train_step(self, data_batch, optimizer, **kwargs): | |||
| """The iteration step during training. | |||
| This method defines an iteration step during training, except for the | |||
| back propagation and optimizer updating, which are done in an optimizer | |||
| hook. Note that in some complicated cases or models, the whole process | |||
| including back propagation and optimizer updating is also defined in | |||
| this method, such as GAN. | |||
| Args: | |||
| data_batch (dict): The output of the dataloader. | |||
| optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of | |||
| runner is passed to ``train_step()``. This argument is unused | |||
| and reserved. | |||
| Returns: | |||
| dict: It should contain at least 3 keys: ``loss``, ``log_vars``, | |||
| ``num_samples``. | |||
| ``loss`` is a tensor for back propagation, which can be a | |||
| weighted sum of multiple losses. | |||
| ``log_vars`` contains all the variables to be sent to the | |||
| logger. | |||
| ``num_samples`` indicates the batch size (when the model is | |||
| DDP, it means the batch size on each GPU), which is used for | |||
| averaging the logs. | |||
| """ | |||
| losses = self(**data_batch) | |||
| loss, log_vars = self._parse_losses(losses) | |||
| outputs = dict( | |||
| loss=loss, | |||
| log_vars=log_vars, | |||
| num_samples=len(data_batch['img_metas'])) | |||
| return outputs | |||
| def val_step(self, data_batch, optimizer=None, **kwargs): | |||
| """The iteration step during validation. | |||
| This method shares the same signature as :func:`train_step`, but used | |||
| during val epochs. Note that the evaluation after training epochs is | |||
| not implemented with this method, but an evaluation hook. | |||
| """ | |||
| losses = self(**data_batch) | |||
| loss, log_vars = self._parse_losses(losses) | |||
| log_vars_ = dict() | |||
| for loss_name, loss_value in log_vars.items(): | |||
| k = loss_name + '_val' | |||
| log_vars_[k] = loss_value | |||
| outputs = dict( | |||
| loss=loss, | |||
| log_vars=log_vars_, | |||
| num_samples=len(data_batch['img_metas'])) | |||
| return outputs | |||
| @staticmethod | |||
| def _parse_losses(losses): | |||
| """Parse the raw outputs (losses) of the network. | |||
| Args: | |||
| losses (dict): Raw output of the network, which usually contains | |||
| losses and other necessary information. | |||
| Returns: | |||
| tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor | |||
| which may be a weighted sum of all losses, log_vars contains | |||
| all the variables to be sent to the logger. | |||
| """ | |||
| log_vars = OrderedDict() | |||
| for loss_name, loss_value in losses.items(): | |||
| if isinstance(loss_value, torch.Tensor): | |||
| log_vars[loss_name] = loss_value.mean() | |||
| elif isinstance(loss_value, list): | |||
| log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) | |||
| else: | |||
| raise TypeError( | |||
| f'{loss_name} is not a tensor or list of tensors') | |||
| loss = sum(_value for _key, _value in log_vars.items() | |||
| if 'loss' in _key) | |||
| # If log_vars has different lengths across ranks, raise an assertion error | |||
| # to prevent GPUs from infinite waiting. | |||
| if dist.is_available() and dist.is_initialized(): | |||
| log_var_length = torch.tensor(len(log_vars), device=loss.device) | |||
| dist.all_reduce(log_var_length) | |||
| message = (f'rank {dist.get_rank()}' | |||
| + f' len(log_vars): {len(log_vars)}' + ' keys: ' | |||
| + ','.join(log_vars.keys()) + '\n') | |||
| assert log_var_length == len(log_vars) * dist.get_world_size(), \ | |||
| 'loss log variables are different across GPUs!\n' + message | |||
| log_vars['loss'] = loss | |||
| for loss_name, loss_value in log_vars.items(): | |||
| # reduce loss when distributed training | |||
| if dist.is_available() and dist.is_initialized(): | |||
| loss_value = loss_value.data.clone() | |||
| dist.all_reduce(loss_value.div_(dist.get_world_size())) | |||
| log_vars[loss_name] = loss_value.item() | |||
| return loss, log_vars | |||
| def show_result(self, | |||
| img, | |||
| result, | |||
| palette=None, | |||
| win_name='', | |||
| show=False, | |||
| wait_time=0, | |||
| out_file=None, | |||
| opacity=0.5): | |||
| """Draw `result` over `img`. | |||
| Args: | |||
| img (str or Tensor): The image to be displayed. | |||
| result (Tensor): The semantic segmentation results to draw over | |||
| `img`. | |||
| palette (list[list[int]] | np.ndarray | None): The palette of | |||
| segmentation map. If None is given, random palette will be | |||
| generated. Default: None | |||
| win_name (str): The window name. | |||
| wait_time (int): Value of waitKey param. | |||
| Default: 0. | |||
| show (bool): Whether to show the image. | |||
| Default: False. | |||
| out_file (str or None): The filename to write the image. | |||
| Default: None. | |||
| opacity(float): Opacity of painted segmentation map. | |||
| Default 0.5. | |||
| Must be in (0, 1] range. | |||
| Returns: | |||
| img (np.ndarray): The drawn image; only returned when `show` is | |||
| False and `out_file` is not specified. | |||
| """ | |||
| img = mmcv.imread(img) | |||
| img = img.copy() | |||
| seg = result[0] | |||
| if palette is None: | |||
| if self.PALETTE is None: | |||
| # Get random state before set seed, | |||
| # and restore random state later. | |||
| # It will prevent loss of randomness, as the palette | |||
| # may be different in each iteration if not specified. | |||
| # See: https://github.com/open-mmlab/mmdetection/issues/5844 | |||
| state = np.random.get_state() | |||
| np.random.seed(42) | |||
| # random palette | |||
| palette = np.random.randint( | |||
| 0, 255, size=(len(self.CLASSES), 3)) | |||
| np.random.set_state(state) | |||
| else: | |||
| palette = self.PALETTE | |||
| palette = np.array(palette) | |||
| assert palette.shape[0] == len(self.CLASSES) | |||
| assert palette.shape[1] == 3 | |||
| assert len(palette.shape) == 2 | |||
| assert 0 < opacity <= 1.0 | |||
| color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) | |||
| for label, color in enumerate(palette): | |||
| color_seg[seg == label, :] = color | |||
| # convert to BGR | |||
| color_seg = color_seg[..., ::-1] | |||
| img = img * (1 - opacity) + color_seg * opacity | |||
| img = img.astype(np.uint8) | |||
| # if out_file specified, do not show image in window | |||
| if out_file is not None: | |||
| show = False | |||
| if show: | |||
| mmcv.imshow(img, win_name, wait_time) | |||
| if out_file is not None: | |||
| mmcv.imwrite(img, out_file) | |||
| if not (show or out_file): | |||
| warnings.warn('show==False and out_file is not specified, only ' | |||
| 'result image will be returned') | |||
| return img | |||
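`_parse_losses` is the piece that turns the nested loss dict into a single backprop scalar plus log variables. A minimal sketch of the single-process case (distributed reduction omitted), using made-up loss values:

```python
import torch
from collections import OrderedDict

raw_losses = {
    'decode.loss_cls': torch.tensor(0.8),
    'decode.loss_mask': [torch.tensor(0.3), torch.tensor(0.1)],  # list -> summed
    'decode.acc_seg': torch.tensor(76.5),                        # logged, not optimized
}

log_vars = OrderedDict()
for name, value in raw_losses.items():
    log_vars[name] = value.mean() if isinstance(value, torch.Tensor) \
        else sum(v.mean() for v in value)

# only keys containing 'loss' contribute to the scalar used for backprop
loss = sum(v for k, v in log_vars.items() if 'loss' in k)
print(float(loss))  # ~1.2; acc_seg is excluded
```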
| @@ -0,0 +1,303 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from mmdet.models import builder | |||
| from mmdet.models.builder import DETECTORS | |||
| from ...utils import add_prefix, seg_resize | |||
| from .base_segmentor import BaseSegmentor | |||
| @DETECTORS.register_module() | |||
| class EncoderDecoderMask2Former(BaseSegmentor): | |||
| """Encoder Decoder segmentors. | |||
| EncoderDecoder typically consists of backbone, decode_head, auxiliary_head. | |||
| Note that auxiliary_head is only used for deep supervision during training | |||
| and can be discarded during inference. | |||
| """ | |||
| def __init__(self, | |||
| backbone, | |||
| decode_head, | |||
| neck=None, | |||
| auxiliary_head=None, | |||
| train_cfg=None, | |||
| test_cfg=None, | |||
| pretrained=None, | |||
| init_cfg=None): | |||
| super(EncoderDecoderMask2Former, self).__init__(init_cfg) | |||
| if pretrained is not None: | |||
| assert backbone.get('pretrained') is None, \ | |||
| 'both backbone and segmentor set pretrained weight' | |||
| backbone.pretrained = pretrained | |||
| self.backbone = builder.build_backbone(backbone) | |||
| if neck is not None: | |||
| self.neck = builder.build_neck(neck) | |||
| decode_head.update(train_cfg=train_cfg) | |||
| decode_head.update(test_cfg=test_cfg) | |||
| self._init_decode_head(decode_head) | |||
| self._init_auxiliary_head(auxiliary_head) | |||
| self.train_cfg = train_cfg | |||
| self.test_cfg = test_cfg | |||
| assert self.with_decode_head | |||
| def _init_decode_head(self, decode_head): | |||
| """Initialize ``decode_head``""" | |||
| self.decode_head = builder.build_head(decode_head) | |||
| self.align_corners = self.decode_head.align_corners | |||
| self.num_classes = self.decode_head.num_classes | |||
| def _init_auxiliary_head(self, auxiliary_head): | |||
| """Initialize ``auxiliary_head``""" | |||
| if auxiliary_head is not None: | |||
| if isinstance(auxiliary_head, list): | |||
| self.auxiliary_head = nn.ModuleList() | |||
| for head_cfg in auxiliary_head: | |||
| self.auxiliary_head.append(builder.build_head(head_cfg)) | |||
| else: | |||
| self.auxiliary_head = builder.build_head(auxiliary_head) | |||
| def extract_feat(self, img): | |||
| """Extract features from images.""" | |||
| x = self.backbone(img) | |||
| if self.with_neck: | |||
| x = self.neck(x) | |||
| return x | |||
| def encode_decode(self, img, img_metas): | |||
| """Encode images with backbone and decode into a semantic segmentation | |||
| map of the same size as input.""" | |||
| x = self.extract_feat(img) | |||
| out = self._decode_head_forward_test(x, img_metas) | |||
| out = seg_resize( | |||
| input=out, | |||
| size=img.shape[2:], | |||
| mode='bilinear', | |||
| align_corners=self.align_corners) | |||
| return out | |||
| def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, | |||
| **kwargs): | |||
| """Run forward function and calculate loss for decode head in | |||
| training.""" | |||
| losses = dict() | |||
| loss_decode = self.decode_head.forward_train(x, img_metas, | |||
| gt_semantic_seg, **kwargs) | |||
| losses.update(add_prefix(loss_decode, 'decode')) | |||
| return losses | |||
| def _decode_head_forward_test(self, x, img_metas): | |||
| """Run forward function and calculate loss for decode head in | |||
| inference.""" | |||
| seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg) | |||
| return seg_logits | |||
| def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg): | |||
| """Run forward function and calculate loss for auxiliary head in | |||
| training.""" | |||
| losses = dict() | |||
| if isinstance(self.auxiliary_head, nn.ModuleList): | |||
| for idx, aux_head in enumerate(self.auxiliary_head): | |||
| loss_aux = aux_head.forward_train(x, img_metas, | |||
| gt_semantic_seg, | |||
| self.train_cfg) | |||
| losses.update(add_prefix(loss_aux, f'aux_{idx}')) | |||
| else: | |||
| loss_aux = self.auxiliary_head.forward_train( | |||
| x, img_metas, gt_semantic_seg, self.train_cfg) | |||
| losses.update(add_prefix(loss_aux, 'aux')) | |||
| return losses | |||
| def forward_dummy(self, img): | |||
| """Dummy forward function.""" | |||
| seg_logit = self.encode_decode(img, None) | |||
| return seg_logit | |||
| def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs): | |||
| """Forward function for training. | |||
| Args: | |||
| img (Tensor): Input images. | |||
| img_metas (list[dict]): List of image info dict where each dict | |||
| has: 'img_shape', 'scale_factor', 'flip', and may also contain | |||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
| For details on the values of these keys see | |||
| `mmseg/datasets/pipelines/formatting.py:Collect`. | |||
| gt_semantic_seg (Tensor): Semantic segmentation masks | |||
| used if the architecture supports semantic segmentation task. | |||
| Returns: | |||
| dict[str, Tensor]: a dictionary of loss components | |||
| """ | |||
| x = self.extract_feat(img) | |||
| losses = dict() | |||
| loss_decode = self._decode_head_forward_train(x, img_metas, | |||
| gt_semantic_seg, | |||
| **kwargs) | |||
| losses.update(loss_decode) | |||
| if self.with_auxiliary_head: | |||
| loss_aux = self._auxiliary_head_forward_train( | |||
| x, img_metas, gt_semantic_seg) | |||
| losses.update(loss_aux) | |||
| return losses | |||
| # TODO refactor | |||
| def slide_inference(self, img, img_meta, rescale): | |||
| """Inference by sliding-window with overlap. | |||
| If h_crop > h_img or w_crop > w_img, the small patch will be used to | |||
| decode without padding. | |||
| """ | |||
| h_stride, w_stride = self.test_cfg.stride | |||
| h_crop, w_crop = self.test_cfg.crop_size | |||
| batch_size, _, h_img, w_img = img.size() | |||
| num_classes = self.num_classes | |||
| h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 | |||
| w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 | |||
| preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) | |||
| count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) | |||
| for h_idx in range(h_grids): | |||
| for w_idx in range(w_grids): | |||
| y1 = h_idx * h_stride | |||
| x1 = w_idx * w_stride | |||
| y2 = min(y1 + h_crop, h_img) | |||
| x2 = min(x1 + w_crop, w_img) | |||
| y1 = max(y2 - h_crop, 0) | |||
| x1 = max(x2 - w_crop, 0) | |||
| crop_img = img[:, :, y1:y2, x1:x2] | |||
| crop_seg_logit = self.encode_decode(crop_img, img_meta) | |||
| preds += F.pad(crop_seg_logit, | |||
| (int(x1), int(preds.shape[3] - x2), int(y1), | |||
| int(preds.shape[2] - y2))) | |||
| count_mat[:, :, y1:y2, x1:x2] += 1 | |||
| assert (count_mat == 0).sum() == 0 | |||
| if torch.onnx.is_in_onnx_export(): | |||
| # cast count_mat to constant while exporting to ONNX | |||
| count_mat = torch.from_numpy( | |||
| count_mat.cpu().detach().numpy()).to(device=img.device) | |||
| preds = preds / count_mat | |||
| def tensor_to_tuple(input_tensor): | |||
| return tuple(input_tensor.cpu().numpy()) | |||
| if rescale: | |||
| preds = seg_resize( | |||
| preds, | |||
| size=tensor_to_tuple(img_meta[0]['ori_shape'])[:2] | |||
| if isinstance(img_meta[0]['ori_shape'], torch.Tensor) else | |||
| img_meta[0]['ori_shape'], | |||
| mode='bilinear', | |||
| align_corners=self.align_corners, | |||
| warning=False) | |||
| return preds | |||
| def whole_inference(self, img, img_meta, rescale): | |||
| """Inference with full image.""" | |||
| seg_logit = self.encode_decode(img, img_meta) | |||
| if rescale: | |||
| # support dynamic shape for onnx | |||
| if torch.onnx.is_in_onnx_export(): | |||
| size = img.shape[2:] | |||
| else: | |||
| size = img_meta[0]['ori_shape'][:2] | |||
| seg_logit = seg_resize( | |||
| seg_logit, | |||
| size=size, | |||
| mode='bilinear', | |||
| align_corners=self.align_corners, | |||
| warning=False) | |||
| return seg_logit | |||
| def inference(self, img, img_meta, rescale): | |||
| """Inference with slide/whole style. | |||
| Args: | |||
| img (Tensor): The input image of shape (N, 3, H, W). | |||
| img_meta (dict): Image info dict where each dict has: 'img_shape', | |||
| 'scale_factor', 'flip', and may also contain | |||
| 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. | |||
| For details on the values of these keys see | |||
| `mmseg/datasets/pipelines/formatting.py:Collect`. | |||
| rescale (bool): Whether rescale back to original shape. | |||
| Returns: | |||
| Tensor: The output segmentation map. | |||
| """ | |||
| assert self.test_cfg.mode in ['slide', 'whole'] | |||
| ori_shape = img_meta[0]['ori_shape'] | |||
| def tensor_to_tuple(input_tensor): | |||
| return tuple(input_tensor.cpu().numpy()) | |||
| if isinstance(ori_shape, torch.Tensor): | |||
| assert all( | |||
| tensor_to_tuple(_['ori_shape']) == tensor_to_tuple(ori_shape) | |||
| for _ in img_meta) | |||
| else: | |||
| assert all(_['ori_shape'] == ori_shape for _ in img_meta) | |||
| if self.test_cfg.mode == 'slide': | |||
| seg_logit = self.slide_inference(img, img_meta, rescale) | |||
| else: | |||
| seg_logit = self.whole_inference(img, img_meta, rescale) | |||
| output = F.softmax(seg_logit, dim=1) | |||
| flip = img_meta[0]['flip'] | |||
| if flip: | |||
| flip_direction = img_meta[0]['flip_direction'] | |||
| assert flip_direction in ['horizontal', 'vertical'] | |||
| if flip_direction == 'horizontal': | |||
| output = output.flip(dims=(3, )) | |||
| elif flip_direction == 'vertical': | |||
| output = output.flip(dims=(2, )) | |||
| return output | |||
| def simple_test(self, img, img_meta, rescale=True): | |||
| """Simple test with single image.""" | |||
| seg_logit = self.inference(img, img_meta, rescale) | |||
| seg_pred = seg_logit.argmax(dim=1) | |||
| if torch.onnx.is_in_onnx_export(): | |||
| # our inference backend only supports 4D output | |||
| seg_pred = seg_pred.unsqueeze(0) | |||
| return seg_pred | |||
| seg_pred = seg_pred.cpu().numpy() | |||
| # unravel batch dim | |||
| seg_pred = list(seg_pred) | |||
| return seg_pred | |||
| def aug_test(self, imgs, img_metas, rescale=True): | |||
| """Test with augmentations. | |||
| Only rescale=True is supported. | |||
| """ | |||
| # aug_test rescale all imgs back to ori_shape for now | |||
| assert rescale | |||
| # to save memory, we get augmented seg logit inplace | |||
| seg_logit = self.inference(imgs[0], img_metas[0], rescale) | |||
| for i in range(1, len(imgs)): | |||
| cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) | |||
| seg_logit += cur_seg_logit | |||
| seg_logit /= len(imgs) | |||
| seg_pred = seg_logit.argmax(dim=1) | |||
| seg_pred = seg_pred.cpu().numpy() | |||
| # unravel batch dim | |||
| seg_pred = list(seg_pred) | |||
| return seg_pred | |||
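The window arithmetic in `slide_inference` clamps each crop to the image border, so the last row/column of windows overlaps its neighbour instead of requiring padding. A small sketch that reproduces only the grid computation (the crop and stride values are assumptions for illustration):

```python
h_img, w_img = 512, 683          # padded input size
h_crop, w_crop = 512, 512        # test_cfg.crop_size (assumed)
h_stride, w_stride = 341, 341    # test_cfg.stride (assumed)

h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

windows = []
for h_idx in range(h_grids):
    for w_idx in range(w_grids):
        y1, x1 = h_idx * h_stride, w_idx * w_stride
        y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
        y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)  # clamp to the border
        windows.append((y1, y2, x1, x2))
print(windows)  # [(0, 512, 0, 512), (0, 512, 171, 683)] -- the two crops overlap by 341 px
```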
| @@ -0,0 +1,7 @@ | |||
| from .builder import build_pixel_sampler | |||
| from .data_process_func import ResizeToMultiple | |||
| from .seg_func import add_prefix, seg_resize | |||
| __all__ = [ | |||
| 'seg_resize', 'add_prefix', 'build_pixel_sampler', 'ResizeToMultiple' | |||
| ] | |||
| @@ -0,0 +1,11 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| from mmcv.utils import Registry, build_from_cfg | |||
| PIXEL_SAMPLERS = Registry('pixel sampler') | |||
| def build_pixel_sampler(cfg, **default_args): | |||
| """Build pixel sampler for segmentation map.""" | |||
| return build_from_cfg(cfg, PIXEL_SAMPLERS, default_args) | |||
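For reference, a hedged sketch of how a sampler could be registered and built through this registry. `RandomPixelSampler` is an invented name used purely for illustration, and the sketch assumes the `PIXEL_SAMPLERS` registry and `build_pixel_sampler` defined above are in scope:

```python
@PIXEL_SAMPLERS.register_module()
class RandomPixelSampler:
    """Toy sampler used only to demonstrate the registry round-trip."""

    def __init__(self, num_samples=1000):
        self.num_samples = num_samples


sampler = build_pixel_sampler(dict(type='RandomPixelSampler', num_samples=512))
print(type(sampler).__name__, sampler.num_samples)  # RandomPixelSampler 512
```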
| @@ -0,0 +1,60 @@ | |||
| # Copyright (c) OpenMMLab. All rights reserved. | |||
| import mmcv | |||
| from mmdet.datasets.builder import PIPELINES | |||
| @PIPELINES.register_module() | |||
| class ResizeToMultiple(object): | |||
| """Resize images & seg to multiple of divisor. | |||
| Args: | |||
| size_divisor (int): images and gt seg maps are resized to a multiple | |||
| of size_divisor. Default: 32. | |||
| interpolation (str, optional): The interpolation mode of image resize. | |||
| Default: None | |||
| """ | |||
| def __init__(self, size_divisor=32, interpolation=None): | |||
| self.size_divisor = size_divisor | |||
| self.interpolation = interpolation | |||
| def __call__(self, results): | |||
| """Call function to resize images, semantic segmentation map to | |||
| multiple of size divisor. | |||
| Args: | |||
| results (dict): Result dict from loading pipeline. | |||
| Returns: | |||
| dict: Resized results, 'img_shape', 'pad_shape' keys are updated. | |||
| """ | |||
| # Align image to multiple of size divisor. | |||
| img = results['img'] | |||
| img = mmcv.imresize_to_multiple( | |||
| img, | |||
| self.size_divisor, | |||
| scale_factor=1, | |||
| interpolation=self.interpolation | |||
| if self.interpolation else 'bilinear') | |||
| results['img'] = img | |||
| results['img_shape'] = img.shape | |||
| results['pad_shape'] = img.shape | |||
| # Align segmentation map to multiple of size divisor. | |||
| for key in results.get('seg_fields', []): | |||
| gt_seg = results[key] | |||
| gt_seg = mmcv.imresize_to_multiple( | |||
| gt_seg, | |||
| self.size_divisor, | |||
| scale_factor=1, | |||
| interpolation='nearest') | |||
| results[key] = gt_seg | |||
| return results | |||
| def __repr__(self): | |||
| repr_str = self.__class__.__name__ | |||
| repr_str += (f'(size_divisor={self.size_divisor}, ' | |||
| f'interpolation={self.interpolation})') | |||
| return repr_str | |||
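A small usage sketch for the transform above (requires mmcv; the input size is arbitrary): a 500x375 image is resized up to the next multiples of 32.

```python
import numpy as np

transform = ResizeToMultiple(size_divisor=32)
results = {
    'img': np.zeros((500, 375, 3), dtype=np.uint8),
    'seg_fields': [],  # no segmentation maps in this toy example
}
results = transform(results)
print(results['img_shape'])  # (512, 384, 3): both sides rounded up to a multiple of 32
```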
| @@ -0,0 +1,48 @@ | |||
| # The implementation refers to the VitAdapter | |||
| # available at | |||
| # https://github.com/czczup/ViT-Adapter.git | |||
| import warnings | |||
| import torch.nn.functional as F | |||
| def seg_resize(input, | |||
| size=None, | |||
| scale_factor=None, | |||
| mode='nearest', | |||
| align_corners=None, | |||
| warning=True): | |||
| if warning: | |||
| if size is not None and align_corners: | |||
| input_h, input_w = tuple(int(x) for x in input.shape[2:]) | |||
| output_h, output_w = tuple(int(x) for x in size) | |||
| if output_h > input_h or output_w > input_w: | |||
| if ((output_h > 1 and output_w > 1 and input_h > 1 | |||
| and input_w > 1) and (output_h - 1) % (input_h - 1) | |||
| and (output_w - 1) % (input_w - 1)): | |||
| warnings.warn( | |||
| f'When align_corners={align_corners}, ' | |||
| 'the output would be more aligned if ' | |||
| f'input size {(input_h, input_w)} is `x+1` and ' | |||
| f'out size {(output_h, output_w)} is `nx+1`') | |||
| return F.interpolate(input, size, scale_factor, mode, align_corners) | |||
| def add_prefix(inputs, prefix): | |||
| """Add prefix for dict. | |||
| Args: | |||
| inputs (dict): The input dict with str keys. | |||
| prefix (str): The prefix to add. | |||
| Returns: | |||
| dict: The dict with keys updated with ``prefix``. | |||
| """ | |||
| outputs = dict() | |||
| for name, value in inputs.items(): | |||
| outputs[f'{prefix}.{name}'] = value | |||
| return outputs | |||
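Two quick usage sketches for the helpers above, assuming `seg_resize` and `add_prefix` are in scope:

```python
import torch

# add_prefix: namespace a per-head loss dict before merging
print(add_prefix({'loss_cls': 1.0, 'acc_seg': 0.9}, 'decode'))
# {'decode.loss_cls': 1.0, 'decode.acc_seg': 0.9}

# seg_resize: bilinear upsampling of logits to the label resolution
logits = torch.randn(2, 19, 32, 32)
out = seg_resize(logits, size=(128, 128), mode='bilinear', align_corners=False)
print(out.shape)  # torch.Size([2, 19, 128, 128])
```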
| @@ -26,6 +26,7 @@ if TYPE_CHECKING: | |||
| from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline | |||
| from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline | |||
| from .image_reid_person_pipeline import ImageReidPersonPipeline | |||
| from .image_semantic_segmentation_pipeline import ImageSemanticSegmentationPipeline | |||
| from .image_style_transfer_pipeline import ImageStyleTransferPipeline | |||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | |||
| from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline | |||
| @@ -66,6 +67,8 @@ else: | |||
| 'image_portrait_enhancement_pipeline': | |||
| ['ImagePortraitEnhancementPipeline'], | |||
| 'image_reid_person_pipeline': ['ImageReidPersonPipeline'], | |||
| 'image_semantic_segmentation_pipeline': | |||
| ['ImageSemanticSegmentationPipeline'], | |||
| 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], | |||
| 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], | |||
| 'image_to_image_translation_pipeline': | |||
| @@ -0,0 +1,95 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import Any, Dict, Union | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Model, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_segmentation, | |||
| module_name=Pipelines.image_semantic_segmentation) | |||
| class ImageSemanticSegmentationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| Use `model` to create an image semantic segmentation pipeline for prediction. | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| logger.info('semantic segmentation model, pipeline init') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| from mmdet.datasets.pipelines import Compose | |||
| from mmcv.parallel import collate, scatter | |||
| from mmdet.datasets import replace_ImageToTensor | |||
| cfg = self.model.cfg | |||
| # build the data pipeline | |||
| if isinstance(input, str): | |||
| # input is a str (image file path); the pipeline uses LoadImageFromFile | |||
| # collect data | |||
| data = dict(img_info=dict(filename=input), img_prefix=None) | |||
| elif isinstance(input, PIL.Image.Image): # PIL image is RGB; convert to BGR | |||
| cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
| img = np.array(input)[:, :, ::-1] | |||
| # collect data | |||
| data = dict(img=img) | |||
| elif isinstance(input, np.ndarray): | |||
| cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
| else: | |||
| img = input | |||
| # collect data | |||
| data = dict(img=img) | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' or np.ndarray, but got {type(input)}') | |||
| cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) | |||
| test_pipeline = Compose(cfg.data.test.pipeline) | |||
| data = test_pipeline(data) | |||
| # collate data, following the mmdet model convention | |||
| data = collate([data], samples_per_gpu=1) | |||
| data['img_metas'] = [ | |||
| img_metas.data[0] for img_metas in data['img_metas'] | |||
| ] | |||
| data['img'] = [img.data[0] for img in data['img']] | |||
| if next(self.model.parameters()).is_cuda: | |||
| # scatter to specified GPU | |||
| data = scatter(data, [next(self.model.parameters()).device])[0] | |||
| return data | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| results = self.model.inference(input) | |||
| return results | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| results = self.model.postprocess(inputs) | |||
| outputs = { | |||
| OutputKeys.MASKS: results[OutputKeys.MASKS], | |||
| OutputKeys.LABELS: results[OutputKeys.LABELS], | |||
| OutputKeys.SCORES: results[OutputKeys.SCORES] | |||
| } | |||
| return outputs | |||
| @@ -153,3 +153,16 @@ def panoptic_seg_masks_to_image(masks): | |||
| draw_img[mask] = color_mask | |||
| return draw_img | |||
| def semantic_seg_masks_to_image(masks): | |||
| from mmdet.core.visualization.palette import get_palette | |||
| mask_palette = get_palette('coco', 133) | |||
| draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3]) | |||
| for i, mask in enumerate(masks): | |||
| color_mask = mask_palette[i] | |||
| mask = mask.astype(bool) | |||
| draw_img[mask] = color_mask | |||
| return draw_img | |||
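A tiny usage sketch for `semantic_seg_masks_to_image` with two synthetic boolean masks (requires mmdet for the palette; the shapes are arbitrary):

```python
import numpy as np

masks = [np.zeros((4, 4), dtype=np.uint8) for _ in range(2)]
masks[0][:2, :] = 1  # "class 0" covers the top half
masks[1][2:, :] = 1  # "class 1" covers the bottom half

draw_img = semantic_seg_masks_to_image(masks)
print(draw_img.shape)  # (4, 4, 3); each half painted with its palette colour
```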
| @@ -0,0 +1,54 @@ | |||
| import unittest | |||
| import cv2 | |||
| import PIL | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.cv.image_utils import semantic_seg_masks_to_image | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.test_utils import test_level | |||
| class ImageSemanticSegmentationTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_image_semantic_segmentation_panmerge(self): | |||
| input_location = 'data/test/images/image_semantic_segmentation.jpg' | |||
| model_id = 'damo/cv_swinL_semantic-segmentation_cocopanmerge' | |||
| segmenter = pipeline(Tasks.image_segmentation, model=model_id) | |||
| result = segmenter(input_location) | |||
| draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
| cv2.imwrite('result.jpg', draw_img) | |||
| print('test_image_semantic_segmentation_panmerge DONE') | |||
| PIL_array = PIL.Image.open(input_location) | |||
| result = segmenter(PIL_array) | |||
| draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
| cv2.imwrite('result.jpg', draw_img) | |||
| print('test_image_semantic_segmentation_panmerge_from_PIL DONE') | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_image_semantic_segmentation_vitadapter(self): | |||
| input_location = 'data/test/images/image_semantic_segmentation.jpg' | |||
| model_id = 'damo/cv_vitadapter_semantic-segmentation_cocostuff164k' | |||
| segmenter = pipeline(Tasks.image_segmentation, model=model_id) | |||
| result = segmenter(input_location) | |||
| draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
| cv2.imwrite('result.jpg', draw_img) | |||
| print('test_image_semantic_segmentation_vitadapter DONE') | |||
| PIL_array = PIL.Image.open(input_location) | |||
| result = segmenter(PIL_array) | |||
| draw_img = semantic_seg_masks_to_image(result[OutputKeys.MASKS]) | |||
| cv2.imwrite('result.jpg', draw_img) | |||
| print('test_image_semantic_segmentation_vitadapter_from_PIL DONE') | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||