Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10189886 (target branch: master)
@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp

 import torch

@@ -49,6 +50,4 @@ class SwinLPanopticSegmentation(TorchModel):
         return results

     def forward(self, Inputs):
-        import pdb
-        pdb.set_trace()
         return self.model(**Inputs)
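Note: dropping the leftover `pdb.set_trace()` is the right call, since it would hang any non-interactive inference job. If interactive debugging of `forward` is still wanted occasionally, a minimal sketch of an opt-in alternative (the `MAAS_DEBUG_FORWARD` variable name is ours, not part of this PR):

```python
import os

def forward(self, Inputs):  # method of SwinLPanopticSegmentation
    # Opt-in debugging: breakpoint() honours PYTHONBREAKPOINT, and the
    # environment-variable gate keeps production runs clean by default.
    if os.environ.get('MAAS_DEBUG_FORWARD'):
        breakpoint()
    return self.model(**Inputs)
```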
@@ -1 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 from .maskformer_semantic_head import MaskFormerSemanticHead

@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import torch
 import torch.nn.functional as F
 from mmdet.models.builder import HEADS

@@ -1,3 +1,4 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp

 import numpy as np

@@ -1,3 +1,5 @@
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from .models import backbone, decode_heads, segmentors
 from .utils import (ResizeToMultiple, add_prefix, build_pixel_sampler,
                     seg_resize)

@@ -1,3 +1,5 @@
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from .backbone import BASEBEiT, BEiTAdapter
 from .decode_heads import Mask2FormerHeadFromMMSeg
 from .segmentors import EncoderDecoderMask2Former

@@ -1,3 +1,5 @@
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from .base import BASEBEiT
 from .beit_adapter import BEiTAdapter

@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 import logging
 from functools import partial

@@ -417,7 +416,7 @@ class SpatialPriorModule(nn.Module):
         self.stem = nn.Sequential(*[
             nn.Conv2d(
                 3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
-            nn.SyncBatchNorm(inplanes),
+            nn.BatchNorm2d(inplanes),
             nn.ReLU(inplace=True),
             nn.Conv2d(
                 inplanes,
@@ -426,7 +425,7 @@ class SpatialPriorModule(nn.Module):
                 stride=1,
                 padding=1,
                 bias=False),
-            nn.SyncBatchNorm(inplanes),
+            nn.BatchNorm2d(inplanes),
             nn.ReLU(inplace=True),
             nn.Conv2d(
                 inplanes,
@@ -435,7 +434,7 @@ class SpatialPriorModule(nn.Module):
                 stride=1,
                 padding=1,
                 bias=False),
-            nn.SyncBatchNorm(inplanes),
+            nn.BatchNorm2d(inplanes),
             nn.ReLU(inplace=True),
             nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         ])
@@ -447,7 +446,7 @@ class SpatialPriorModule(nn.Module):
                 stride=2,
                 padding=1,
                 bias=False),
-            nn.SyncBatchNorm(2 * inplanes),
+            nn.BatchNorm2d(2 * inplanes),
             nn.ReLU(inplace=True)
         ])
         self.conv3 = nn.Sequential(*[
@@ -458,7 +457,7 @@ class SpatialPriorModule(nn.Module):
                 stride=2,
                 padding=1,
                 bias=False),
-            nn.SyncBatchNorm(4 * inplanes),
+            nn.BatchNorm2d(4 * inplanes),
             nn.ReLU(inplace=True)
         ])
         self.conv4 = nn.Sequential(*[
@@ -469,7 +468,7 @@ class SpatialPriorModule(nn.Module):
                 stride=2,
                 padding=1,
                 bias=False),
-            nn.SyncBatchNorm(4 * inplanes),
+            nn.BatchNorm2d(4 * inplanes),
             nn.ReLU(inplace=True)
         ])
         self.fc1 = nn.Conv2d(
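Note: swapping `nn.SyncBatchNorm` for `nn.BatchNorm2d` makes these modules usable without an initialized `torch.distributed` process group, so single-GPU and CPU inference work. Distributed training can still get synchronized statistics by converting after construction; a sketch, assuming the trainer initializes the process group first (the `inplanes` value is illustrative):

```python
import torch
import torch.nn as nn

# Built with plain nn.BatchNorm2d layers, as in this PR.
model = SpatialPriorModule(inplanes=64)
if torch.distributed.is_available() and torch.distributed.is_initialized():
    # Recursively replaces every BatchNorm*D layer with SyncBatchNorm.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
```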
@@ -1,3 +1,5 @@
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from .beit import BASEBEiT

 __all__ = ['BASEBEiT']

@@ -1,7 +1,5 @@
-# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
-# Github source: https://github.com/microsoft/unilm/tree/master/beit
-# This implementation refers to
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 import math
 from functools import partial

@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 import logging
 import math

@@ -69,10 +68,10 @@ class BEiTAdapter(BASEBEiT):
         ])
         self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
-        self.norm1 = nn.SyncBatchNorm(embed_dim)
-        self.norm2 = nn.SyncBatchNorm(embed_dim)
-        self.norm3 = nn.SyncBatchNorm(embed_dim)
-        self.norm4 = nn.SyncBatchNorm(embed_dim)
+        self.norm1 = nn.BatchNorm2d(embed_dim)
+        self.norm2 = nn.BatchNorm2d(embed_dim)
+        self.norm3 = nn.BatchNorm2d(embed_dim)
+        self.norm4 = nn.BatchNorm2d(embed_dim)

         self.up.apply(self._init_weights)
         self.spm.apply(self._init_weights)

@@ -1,3 +1,5 @@
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from .mask2former_head_from_mmseg import Mask2FormerHeadFromMMSeg

 __all__ = ['Mask2FormerHeadFromMMSeg']

@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from abc import ABCMeta, abstractmethod

 import torch

@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 import copy

@@ -1,3 +1,5 @@
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from .encoder_decoder_mask2former import EncoderDecoderMask2Former

 __all__ = ['EncoderDecoderMask2Former']

@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict

@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -1,3 +1,5 @@
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from .builder import build_pixel_sampler
 from .data_process_func import ResizeToMultiple
 from .seg_func import add_prefix, seg_resize

@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 from mmcv.utils import Registry, build_from_cfg

 PIXEL_SAMPLERS = Registry('pixel sampler')
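For context, the `PIXEL_SAMPLERS` registry above is consumed through `build_pixel_sampler`; a minimal sketch of the register-then-build pattern (the sampler class is illustrative, not part of this PR):

```python
from mmcv.utils import Registry, build_from_cfg

PIXEL_SAMPLERS = Registry('pixel sampler')

@PIXEL_SAMPLERS.register_module()
class ThresholdPixelSampler:
    # Toy sampler: would keep pixels whose per-pixel loss exceeds a threshold.
    def __init__(self, thresh=0.7):
        self.thresh = thresh

# build_from_cfg pops the 'type' key, looks it up in the registry and
# instantiates the class with the remaining keys as kwargs.
sampler = build_from_cfg(dict(type='ThresholdPixelSampler', thresh=0.9),
                         PIXEL_SAMPLERS)
```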
@@ -1,6 +1,5 @@
-# The implementation refers to the VitAdapter
-# available at
-# https://github.com/czczup/ViT-Adapter.git
+# The implementation is adopted from VitAdapter,
+# made publicly available under the Apache License at https://github.com/czczup/ViT-Adapter.git
 import warnings

@@ -4,11 +4,13 @@ from typing import Any, Dict, Union
 import cv2
 import numpy as np
 import PIL
+import torch
 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import load_image
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -39,28 +41,24 @@ class ImagePanopticSegmentationPipeline(Pipeline):
         # build the data pipeline
         if isinstance(input, str):
-            # input is str, file names, pipeline loadimagefromfile
-            # collect data
-            data = dict(img_info=dict(filename=input), img_prefix=None)
+            cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+            img = np.array(load_image(input))
+            img = img[:, :, ::-1]  # convert to bgr
         elif isinstance(input, PIL.Image.Image):
             cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
             img = np.array(input.convert('RGB'))
-            # collect data
-            data = dict(img=img)
         elif isinstance(input, np.ndarray):
             cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
             if len(input.shape) == 2:
                 img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
             else:
                 img = input
-            # collect data
-            data = dict(img=img)
+            img = img[:, :, ::-1]  # in rgb order
         else:
             raise TypeError(f'input should be either str, PIL.Image,'
                             f' np.array, but got {type(input)}')
+        # collect data
+        data = dict(img=img)

         cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
         test_pipeline = Compose(cfg.data.test.pipeline)
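Note: every branch now normalizes `input` to an `img` ndarray, and a single `data = dict(img=img)` after the if/else replaces the per-branch copies. A condensed sketch of the resulting flow (the helper name is ours, for illustration only):

```python
import cv2
import numpy as np
import PIL.Image

from modelscope.preprocessors import load_image

def _to_img_ndarray(input):
    if isinstance(input, str):
        # load_image returns an RGB PIL image; flip channels to BGR.
        return np.array(load_image(input))[:, :, ::-1]
    if isinstance(input, PIL.Image.Image):
        return np.array(input.convert('RGB'))
    if isinstance(input, np.ndarray):
        # Grayscale is promoted to 3 channels; the flip is then a no-op.
        img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) if input.ndim == 2 else input
        return img[:, :, ::-1]  # ndarray input is assumed to be in RGB order
    raise TypeError(f'input should be either str, PIL.Image,'
                    f' np.array, but got {type(input)}')
```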
@@ -10,6 +10,7 @@ from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Model, Pipeline
 from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import load_image
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger

@@ -40,28 +41,24 @@ class ImageSemanticSegmentationPipeline(Pipeline):
         # build the data pipeline
         if isinstance(input, str):
-            # input is str, file names, pipeline loadimagefromfile
-            # collect data
-            data = dict(img_info=dict(filename=input), img_prefix=None)
+            cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+            img = np.array(load_image(input))
+            img = img[:, :, ::-1]  # convert to bgr
         elif isinstance(input, PIL.Image.Image):  # BGR
             cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
             img = np.array(input)[:, :, ::-1]
-            # collect data
-            data = dict(img=img)
         elif isinstance(input, np.ndarray):
             cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
             if len(input.shape) == 2:
                 img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
             else:
                 img = input
-            # collect data
-            data = dict(img=img)
         else:
             raise TypeError(f'input should be either str, PIL.Image,'
                             f' np.array, but got {type(input)}')
-        # data = dict(img=input)
+        # collect data
+        data = dict(img=img)

         cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
         test_pipeline = Compose(cfg.data.test.pipeline)
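End-to-end, the reworked preprocessing is exercised through the standard modelscope entry point; a usage sketch (the model id is a placeholder, not from this PR):

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

segmentor = pipeline(Tasks.image_segmentation, model='damo/<model-id>')
result = segmentor('path/to/image.jpg')  # str input now goes through load_image
print(result[OutputKeys.MASKS])
```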
@@ -80,11 +77,9 @@ class ImageSemanticSegmentationPipeline(Pipeline):
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
         results = self.model.inference(input)
         return results

     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         results = self.model.postprocess(inputs)
         outputs = {
             OutputKeys.MASKS: results[OutputKeys.MASKS],