Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10252583 (target branch: master)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec
size 108357

@@ -42,6 +42,7 @@ class Models(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'
    # EasyCV models
    yolox = 'YOLOX'

@@ -185,6 +186,7 @@ class Pipelines(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'
    # nlp tasks
    sentence_similarity = 'sentence-similarity'
@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .seg_infer import F3NetForProductSegmentation
else:
    _import_structure = {'seg_infer': ['F3NetForProductSegmentation']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
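Note: the name registered here has to match the class defined in seg_infer.py below (F3NetForProductSegmentation). For context, LazyImportModule defers the real import of seg_infer until the symbol is first touched; a minimal sketch of what that enables (package path matches the import used in the pipeline file later in this diff):

# Hypothetical usage: importing the package is cheap; the first attribute
# access is what actually loads .seg_infer (and its torch dependency).
from modelscope.models.cv import product_segmentation

model_cls = product_segmentation.F3NetForProductSegmentation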
@@ -0,0 +1,197 @@
# The implementation here is modified based on F3Net,
# originally Apache 2.0 License and publicly available at https://github.com/weijun88/F3Net
import torch
import torch.nn as nn
import torch.nn.functional as F
class Bottleneck(nn.Module):
    # Standard ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with optional dilation.

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=(3 * dilation - 1) // 2,
            bias=False,
            dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.downsample = downsample

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        out = self.bn3(self.conv3(out))
        if self.downsample is not None:
            x = self.downsample(x)
        return F.relu(out + x, inplace=True)
class ResNet(nn.Module):
    # ResNet-50 backbone returning the four stage feature maps.

    def __init__(self):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self.make_layer(64, 3, stride=1, dilation=1)
        self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
        self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
        self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)

    def make_layer(self, planes, blocks, stride, dilation):
        downsample = nn.Sequential(
            nn.Conv2d(
                self.inplanes,
                planes * 4,
                kernel_size=1,
                stride=stride,
                bias=False), nn.BatchNorm2d(planes * 4))
        layers = [
            Bottleneck(
                self.inplanes, planes, stride, downsample, dilation=dilation)
        ]
        self.inplanes = planes * 4
        for _ in range(1, blocks):
            layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Inference assumes a single 448x448 RGB image.
        x = x.reshape(1, 3, 448, 448)
        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
        out2 = self.layer1(out1)
        out3 = self.layer2(out2)
        out4 = self.layer3(out3)
        out5 = self.layer4(out4)
        return out2, out3, out4, out5
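A quick shape sanity check for the backbone, as a sketch (random weights; assumes the fixed 448x448 single-image input enforced by the reshape above):

# Stage outputs for a 448x448 input, given strides 4/8/16/32:
# out2 (1, 256, 112, 112), out3 (1, 512, 56, 56),
# out4 (1, 1024, 28, 28), out5 (1, 2048, 14, 14).
bkbone = ResNet().eval()
out2, out3, out4, out5 = bkbone(torch.randn(1, 3, 448, 448))
print(tuple(out5.shape))  # (1, 2048, 14, 14)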
class CFM(nn.Module):
    # Cross Feature Module: fuses a lateral feature (`left`) with an upsampled
    # deeper feature (`down`) and returns refined versions of both branches.

    def __init__(self):
        super(CFM, self).__init__()
        self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1h = nn.BatchNorm2d(64)
        self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2h = nn.BatchNorm2d(64)
        self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3h = nn.BatchNorm2d(64)
        self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4h = nn.BatchNorm2d(64)
        self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn1v = nn.BatchNorm2d(64)
        self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2v = nn.BatchNorm2d(64)
        self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn3v = nn.BatchNorm2d(64)
        self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn4v = nn.BatchNorm2d(64)

    def forward(self, left, down):
        if down.size()[2:] != left.size()[2:]:
            down = F.interpolate(down, size=left.size()[2:], mode='bilinear')
        out1h = F.relu(self.bn1h(self.conv1h(left)), inplace=True)
        out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True)
        out1v = F.relu(self.bn1v(self.conv1v(down)), inplace=True)
        out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True)
        # Multiplicative fusion, then residual refinement of each branch.
        fuse = out2h * out2v
        out3h = F.relu(self.bn3h(self.conv3h(fuse)), inplace=True) + out1h
        out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True)
        out3v = F.relu(self.bn3v(self.conv3v(fuse)), inplace=True) + out1v
        out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True)
        return out4h, out4v
class Decoder(nn.Module):

    def __init__(self):
        super(Decoder, self).__init__()
        self.cfm45 = CFM()
        self.cfm34 = CFM()
        self.cfm23 = CFM()

    def forward(self, out2h, out3h, out4h, out5v, fback=None):
        if fback is not None:
            refine5 = F.interpolate(
                fback, size=out5v.size()[2:], mode='bilinear')
            refine4 = F.interpolate(
                fback, size=out4h.size()[2:], mode='bilinear')
            refine3 = F.interpolate(
                fback, size=out3h.size()[2:], mode='bilinear')
            refine2 = F.interpolate(
                fback, size=out2h.size()[2:], mode='bilinear')
            out5v = out5v + refine5
            out4h, out4v = self.cfm45(out4h + refine4, out5v)
            out3h, out3v = self.cfm34(out3h + refine3, out4v)
            out2h, pred = self.cfm23(out2h + refine2, out3v)
        else:
            out4h, out4v = self.cfm45(out4h, out5v)
            out3h, out3v = self.cfm34(out3h, out4v)
            out2h, pred = self.cfm23(out2h, out3v)
        return out2h, out3h, out4h, out5v, pred
class F3Net(nn.Module):
    # Full F3Net: ResNet-50 backbone, channel squeeze to 64, and two cascaded
    # decoders, where the first prediction refines the second decoding pass.

    def __init__(self):
        super(F3Net, self).__init__()
        self.bkbone = ResNet()
        self.squeeze5 = nn.Sequential(
            nn.Conv2d(2048, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze4 = nn.Sequential(
            nn.Conv2d(1024, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze3 = nn.Sequential(
            nn.Conv2d(512, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.squeeze2 = nn.Sequential(
            nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.decoder1 = Decoder()
        self.decoder2 = Decoder()
        self.linearp1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearp2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
        self.linearr5 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x, shape=None):
        x = x.reshape(1, 3, 448, 448)
        out2h, out3h, out4h, out5v = self.bkbone(x)
        out2h, out3h, out4h, out5v = self.squeeze2(out2h), self.squeeze3(
            out3h), self.squeeze4(out4h), self.squeeze5(out5v)
        out2h, out3h, out4h, out5v, pred1 = self.decoder1(
            out2h, out3h, out4h, out5v)
        # The first prediction is fed back into the second decoder as `fback`.
        out2h, out3h, out4h, out5v, pred2 = self.decoder2(
            out2h, out3h, out4h, out5v, pred1)
        shape = x.size()[2:] if shape is None else shape
        pred1 = F.interpolate(
            self.linearp1(pred1), size=shape, mode='bilinear')
        pred2 = F.interpolate(
            self.linearp2(pred2), size=shape, mode='bilinear')
        out2h = F.interpolate(
            self.linearr2(out2h), size=shape, mode='bilinear')
        out3h = F.interpolate(
            self.linearr3(out3h), size=shape, mode='bilinear')
        out4h = F.interpolate(
            self.linearr4(out4h), size=shape, mode='bilinear')
        out5h = F.interpolate(
            self.linearr5(out5v), size=shape, mode='bilinear')
        return pred1, pred2, out2h, out3h, out4h, out5h
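As a rough smoke test of the full network (random weights here; the real checkpoint is loaded in seg_infer.py below), all six outputs come back as single-channel maps at the input resolution:

net = F3Net().eval()
with torch.no_grad():
    pred1, pred2, out2h, out3h, out4h, out5h = net(torch.randn(1, 3, 448, 448))
print(tuple(pred2.shape))  # (1, 1, 448, 448)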
@@ -0,0 +1,77 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import cv2
import numpy as np
import torch
from PIL import Image

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .net import F3Net

logger = get_logger()


def load_state_dict(model_dir, device):
    _dict = torch.load(
        '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
        map_location=device)
    state_dict = {}
    for k, v in _dict.items():
        # Strip the 'module.' prefix left by nn.DataParallel checkpoints.
        if k.startswith('module'):
            k = k[7:]
        state_dict[k] = v
    return state_dict
@MODELS.register_module(
    Tasks.product_segmentation, module_name=Models.product_segmentation)
class F3NetForProductSegmentation(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        self.model = F3Net()
        if torch.cuda.is_available():
            self.device = 'cuda'
            logger.info('Use GPU')
        else:
            self.device = 'cpu'
            logger.info('Use CPU')
        self.params = load_state_dict(model_dir, self.device)
        self.model.load_state_dict(self.params)
        self.model.to(self.device)
        self.model.eval()

    def forward(self, x):
        pred_result = self.model(x)
        return pred_result
# Per-channel RGB mean/std (0-255 range) used to normalize the input image.
mean = np.array([[[124.55, 118.90, 102.94]]])
std = np.array([[[56.77, 55.97, 57.50]]])


def inference(model, device, input_path):
    img = Image.open(input_path)
    img = np.array(img.convert('RGB')).astype(np.float32)
    img = (img - mean) / std
    img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR)
    img = torch.from_numpy(img)
    img = img.permute(2, 0, 1)  # HWC -> CHW
    img = img.to(device).float()
    outputs = model(img)
    out = outputs[0]
    # Map logits to a 0-255 grayscale mask and suppress low-confidence pixels.
    pred = (torch.sigmoid(out[0, 0]) * 255).cpu().numpy()
    pred[pred < 20] = 0
    pred = pred[:, :, np.newaxis]
    pred = np.round(pred)
    logger.info('Inference Done')
    return pred
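A standalone sketch of this entry point, bypassing the pipeline wrapper (model directory and image path below are hypothetical):

# Hypothetical paths; inference() handles all preprocessing itself.
model = F3NetForProductSegmentation('/path/to/model_dir')
mask = inference(model, model.device, 'product.jpg')
print(mask.shape)  # (448, 448, 1), float values in [0, 255]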
@@ -674,5 +674,12 @@ TASK_OUTPUTS = {
    # {
    #     {'output': 'Happiness', 'boxes': (203, 104, 663, 564)}
    # }
    Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES],

    # {
    #     "masks": np.array  # 2D grayscale mask, values in [0, 255]
    # }
    Tasks.product_segmentation: [OutputKeys.MASKS],
}
@@ -187,6 +187,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    (Pipelines.face_human_hand_detection,
     'damo/cv_nanodet_face-human-hand-detection'),
    Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
    Tasks.product_segmentation: (Pipelines.product_segmentation,
                                 'damo/cv_F3Net_product-segmentation'),
}
@@ -0,0 +1,40 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.product_segmentation import seg_infer
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.product_segmentation, module_name=Pipelines.product_segmentation)
class F3NetForProductSegmentationPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a product segmentation pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        # All preprocessing happens inside seg_infer.inference; pass through.
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        mask = seg_infer.inference(self.model, self.device,
                                   input['input_path'])
        return {OutputKeys.MASKS: mask}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
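End to end, the pipeline is driven with a dict carrying the image path, the same call shape the unit test below uses:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

seg = pipeline(Tasks.product_segmentation,
               model='damo/cv_F3Net_product-segmentation')
result = seg({'input_path': 'data/test/images/product_segmentation.jpg'})
mask = result[OutputKeys.MASKS]  # grayscale mask, values in [0, 255]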
@@ -45,6 +45,7 @@ class CVTasks(object):
    hand_static = 'hand-static'
    face_human_hand_detection = 'face-human-hand-detection'
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'
    # image editing
    skin_retouching = 'skin-retouching'
@@ -0,0 +1,43 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from typing import Any, Dict

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class ProductSegmentationTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_F3Net_product-segmentation'
        self.input = {
            'input_path': 'data/test/images/product_segmentation.jpg'
        }

    def pipeline_inference(self, pipeline: Pipeline, input: Dict[str, Any]):
        result = pipeline(input)
        cv2.imwrite('test_product_segmentation_mask.jpg',
                    result[OutputKeys.MASKS])
        logger.info('test done')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        product_segmentation = pipeline(
            Tasks.product_segmentation, model=self.model_id)
        self.pipeline_inference(product_segmentation, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        product_segmentation = pipeline(Tasks.product_segmentation)
        self.pipeline_inference(product_segmentation, self.input)


if __name__ == '__main__':
    unittest.main()