Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10252583 (master)
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec
+size 108357
@@ -42,6 +42,7 @@ class Models(object):
     hand_static = 'hand-static'
     face_human_hand_detection = 'face-human-hand-detection'
     face_emotion = 'face-emotion'
+    product_segmentation = 'product-segmentation'
 
     # EasyCV models
     yolox = 'YOLOX'
@@ -185,6 +186,7 @@ class Pipelines(object):
     hand_static = 'hand-static'
     face_human_hand_detection = 'face-human-hand-detection'
     face_emotion = 'face-emotion'
+    product_segmentation = 'product-segmentation'
 
     # nlp tasks
     sentence_similarity = 'sentence-similarity'
@@ -0,0 +1,20 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .seg_infer import F3NetForProductSegmentation
+
+else:
+    _import_structure = {'seg_infer': ['F3NetForProductSegmentation']}
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
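Reviewer note: with the LazyImportModule shim, loading of the submodule is deferred; a from-import is the first attribute access that actually pulls in seg_infer (and, transitively, torch). A minimal sketch of the intended import path:

# Triggers the lazy load of seg_infer on attribute access.
from modelscope.models.cv.product_segmentation import F3NetForProductSegmentation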
| @@ -0,0 +1,197 @@ | |||||
| # The implementation here is modified based on F3Net, | |||||
| # originally Apache 2.0 License and publicly avaialbe at https://github.com/weijun88/F3Net | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
+
+
+class Bottleneck(nn.Module):
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride=1,
+                 downsample=None,
+                 dilation=1):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            padding=(3 * dilation - 1) // 2,
+            bias=False,
+            dilation=dilation)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.downsample = downsample
+
+    def forward(self, x):
+        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
+        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
+        out = self.bn3(self.conv3(out))
+        if self.downsample is not None:
+            x = self.downsample(x)
+        return F.relu(out + x, inplace=True)
+
+
+class ResNet(nn.Module):
+
+    def __init__(self):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(
+            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.layer1 = self.make_layer(64, 3, stride=1, dilation=1)
+        self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
+        self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
+        self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)
+
+    def make_layer(self, planes, blocks, stride, dilation):
+        downsample = nn.Sequential(
+            nn.Conv2d(
+                self.inplanes,
+                planes * 4,
+                kernel_size=1,
+                stride=stride,
+                bias=False), nn.BatchNorm2d(planes * 4))
+        layers = [
+            Bottleneck(
+                self.inplanes, planes, stride, downsample, dilation=dilation)
+        ]
+        self.inplanes = planes * 4
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = x.reshape(1, 3, 448, 448)
+        out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
+        out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
+        out2 = self.layer1(out1)
+        out3 = self.layer2(out2)
+        out4 = self.layer3(out3)
+        out5 = self.layer4(out4)
+        return out2, out3, out4, out5
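Reviewer note: this backbone follows the ResNet-50 stage layout (3/4/6/3 bottlenecks, no classification head), and the reshape pins inference to batch size 1 at a fixed 448x448 input. A quick shape check of the returned pyramid (sketch, untrained weights):

backbone = ResNet().eval()
with torch.no_grad():
    feats = backbone(torch.zeros(1, 3, 448, 448))
# Strides 4/8/16/32:
# [(1, 256, 112, 112), (1, 512, 56, 56), (1, 1024, 28, 28), (1, 2048, 14, 14)]
print([tuple(f.shape) for f in feats])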
+
+
+class CFM(nn.Module):
+
+    def __init__(self):
+        super(CFM, self).__init__()
+        self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn1h = nn.BatchNorm2d(64)
+        self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn2h = nn.BatchNorm2d(64)
+        self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn3h = nn.BatchNorm2d(64)
+        self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn4h = nn.BatchNorm2d(64)
+        self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn1v = nn.BatchNorm2d(64)
+        self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn2v = nn.BatchNorm2d(64)
+        self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn3v = nn.BatchNorm2d(64)
+        self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.bn4v = nn.BatchNorm2d(64)
+
+    def forward(self, left, down):
+        if down.size()[2:] != left.size()[2:]:
+            down = F.interpolate(down, size=left.size()[2:], mode='bilinear')
+        out1h = F.relu(self.bn1h(self.conv1h(left)), inplace=True)
+        out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True)
+        out1v = F.relu(self.bn1v(self.conv1v(down)), inplace=True)
+        out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True)
+        fuse = out2h * out2v
+        out3h = F.relu(self.bn3h(self.conv3h(fuse)), inplace=True) + out1h
+        out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True)
+        out3v = F.relu(self.bn3v(self.conv3v(fuse)), inplace=True) + out1v
+        out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True)
+        return out4h, out4v
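Note: this is the cross feature module (CFM) from the F3Net paper; the shallower 'left' feature and the upsampled deeper 'down' feature gate each other through element-wise multiplication, with residual links back to the pre-fusion tensors. A shape sketch against the class above:

cfm = CFM().eval()
left = torch.randn(1, 64, 56, 56)  # shallower, higher-resolution branch
down = torch.randn(1, 64, 28, 28)  # deeper branch, upsampled internally to match
with torch.no_grad():
    out_h, out_v = cfm(left, down)  # both come back as (1, 64, 56, 56)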
+
+
+class Decoder(nn.Module):
+
+    def __init__(self):
+        super(Decoder, self).__init__()
+        self.cfm45 = CFM()
+        self.cfm34 = CFM()
+        self.cfm23 = CFM()
+
+    def forward(self, out2h, out3h, out4h, out5v, fback=None):
+        if fback is not None:
+            refine5 = F.interpolate(
+                fback, size=out5v.size()[2:], mode='bilinear')
+            refine4 = F.interpolate(
+                fback, size=out4h.size()[2:], mode='bilinear')
+            refine3 = F.interpolate(
+                fback, size=out3h.size()[2:], mode='bilinear')
+            refine2 = F.interpolate(
+                fback, size=out2h.size()[2:], mode='bilinear')
+            out5v = out5v + refine5
+            out4h, out4v = self.cfm45(out4h + refine4, out5v)
+            out3h, out3v = self.cfm34(out3h + refine3, out4v)
+            out2h, pred = self.cfm23(out2h + refine2, out3v)
+        else:
+            out4h, out4v = self.cfm45(out4h, out5v)
+            out3h, out3v = self.cfm34(out3h, out4v)
+            out2h, pred = self.cfm23(out2h, out3v)
+        return out2h, out3h, out4h, out5v, pred
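Note: the decoder chains three CFMs top-down (5 to 4 to 3 to 2); when fback is given, a previous 64-channel prediction is resampled to every scale and added in before fusing, which is F3Net's cascaded feedback. A sketch of the two-pass scheme used in F3Net.forward below:

f2, f3 = torch.randn(1, 64, 112, 112), torch.randn(1, 64, 56, 56)
f4, f5 = torch.randn(1, 64, 28, 28), torch.randn(1, 64, 14, 14)
dec1, dec2 = Decoder().eval(), Decoder().eval()
with torch.no_grad():
    o2, o3, o4, o5, pred1 = dec1(f2, f3, f4, f5)         # coarse first pass
    o2, o3, o4, o5, pred2 = dec2(o2, o3, o4, o5, pred1)  # feedback-refined pass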
+
+
+class F3Net(nn.Module):
+
+    def __init__(self):
+        super(F3Net, self).__init__()
+        self.bkbone = ResNet()
+        self.squeeze5 = nn.Sequential(
+            nn.Conv2d(2048, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+        self.squeeze4 = nn.Sequential(
+            nn.Conv2d(1024, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+        self.squeeze3 = nn.Sequential(
+            nn.Conv2d(512, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+        self.squeeze2 = nn.Sequential(
+            nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True))
+        self.decoder1 = Decoder()
+        self.decoder2 = Decoder()
+        self.linearp1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
+        self.linearp2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
+        self.linearr2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
+        self.linearr3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
+        self.linearr4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
+        self.linearr5 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)
+
+    def forward(self, x, shape=None):
+        x = x.reshape(1, 3, 448, 448)
+        out2h, out3h, out4h, out5v = self.bkbone(x)
+        out2h, out3h, out4h, out5v = self.squeeze2(out2h), self.squeeze3(
+            out3h), self.squeeze4(out4h), self.squeeze5(out5v)
+        out2h, out3h, out4h, out5v, pred1 = self.decoder1(
+            out2h, out3h, out4h, out5v)
+        out2h, out3h, out4h, out5v, pred2 = self.decoder2(
+            out2h, out3h, out4h, out5v, pred1)
+        shape = x.size()[2:] if shape is None else shape
+        pred1 = F.interpolate(
+            self.linearp1(pred1), size=shape, mode='bilinear')
+        pred2 = F.interpolate(
+            self.linearp2(pred2), size=shape, mode='bilinear')
+        out2h = F.interpolate(
+            self.linearr2(out2h), size=shape, mode='bilinear')
+        out3h = F.interpolate(
+            self.linearr3(out3h), size=shape, mode='bilinear')
+        out4h = F.interpolate(
+            self.linearr4(out4h), size=shape, mode='bilinear')
+        out5h = F.interpolate(
+            self.linearr5(out5v), size=shape, mode='bilinear')
+        return pred1, pred2, out2h, out3h, out4h, out5h
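For reference, a smoke test of the assembled net under the same fixed-shape constraint: forward returns six single-channel maps at input resolution (two predictions plus four deeply supervised side outputs), and the inference helper below only consumes the first of them, pred1.

net = F3Net().eval()
with torch.no_grad():
    outs = net(torch.zeros(1, 3, 448, 448))
print([tuple(o.shape) for o in outs])  # six maps, each (1, 1, 448, 448)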
@@ -0,0 +1,77 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+from .net import F3Net
+
+logger = get_logger()
+
+
+def load_state_dict(model_dir, device):
+    _dict = torch.load(
+        '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
+        map_location=device)
+    state_dict = {}
+    for k, v in _dict.items():
+        if k.startswith('module'):
+            k = k[7:]  # drop the 'module.' prefix written by nn.DataParallel
+        state_dict[k] = v
+    return state_dict
+
+
+@MODELS.register_module(
+    Tasks.product_segmentation, module_name=Models.product_segmentation)
+class F3NetForProductSegmentation(TorchModel):
+
+    def __init__(self, model_dir, device_id=0, *args, **kwargs):
+        super().__init__(
+            model_dir=model_dir, device_id=device_id, *args, **kwargs)
+        self.model = F3Net()
+        if torch.cuda.is_available():
+            self.device = 'cuda'
+            logger.info('Use GPU')
+        else:
+            self.device = 'cpu'
+            logger.info('Use CPU')
+        self.params = load_state_dict(model_dir, self.device)
+        self.model.load_state_dict(self.params)
+        self.model.to(self.device)
+        self.model.eval()
+
+    def forward(self, x):
+        pred_result = self.model(x)
+        return pred_result
+
+
+mean = np.array([[[124.55, 118.90, 102.94]]])
+std = np.array([[[56.77, 55.97, 57.50]]])
+
+
+def inference(model, device, input_path):
+    img = Image.open(input_path)
+    img = np.array(img.convert('RGB')).astype(np.float32)
+    img = (img - mean) / std
+    img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR)
+    img = torch.from_numpy(img)
+    img = img.permute(2, 0, 1)
+    img = img.to(device).float()
+    outputs = model(img)
+    out = outputs[0]  # pred1, the first of the six maps returned by F3Net
+    pred = (torch.sigmoid(out[0, 0]) * 255).cpu().numpy()
+    pred[pred < 20] = 0  # suppress low-confidence background responses
+    pred = pred[:, :, np.newaxis]
+    pred = np.round(pred)
+    logger.info('Inference Done')
+    return pred
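Reviewer note: two things worth flagging here. inference() never enters torch.no_grad(), so activations keep autograd state alive, and the returned mask stays at the fixed 448x448 working resolution rather than being resized back to the source image. A direct-call sketch (the snapshot path is a placeholder; the directory must contain pytorch_model.bin):

model = F3NetForProductSegmentation(model_dir='/path/to/snapshot')
with torch.no_grad():  # compensates for the missing no_grad inside inference()
    mask = inference(model, model.device, 'product.jpg')
print(mask.shape)  # (448, 448, 1), float values in [0, 255]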
@@ -674,5 +674,12 @@ TASK_OUTPUTS = {
     # {
     #     {'output': 'Happiness', 'boxes': (203, 104, 663, 564)}
     # }
-    Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES]
+    Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES],
+
+    # {
+    #     "masks": [
+    #         np.array  # H x W x 1 array with values in [0, 255]
+    #     ]
+    # }
+    Tasks.product_segmentation: [OutputKeys.MASKS],
 }
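Note: since the stored mask is a rounded sigmoid with values below 20 zeroed, it is not strictly binary; callers that need a clean 0/255 mask can threshold it. Sketch, given a pipeline result dict:

mask = result[OutputKeys.MASKS]
binary = (mask >= 128).astype('uint8') * 255  # strict 0/255 foreground mask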
@@ -187,6 +187,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     (Pipelines.face_human_hand_detection,
      'damo/cv_nanodet_face-human-hand-detection'),
     Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'),
+    Tasks.product_segmentation: (Pipelines.product_segmentation,
+                                 'damo/cv_F3Net_product-segmentation'),
 }
@@ -0,0 +1,40 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+from typing import Any, Dict
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.product_segmentation import seg_infer
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.product_segmentation, module_name=Pipelines.product_segmentation)
+class F3NetForProductSegmentationPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """Use `model` to create a product segmentation pipeline for prediction.
+
+        Args:
+            model: model id on the ModelScope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        return input
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        mask = seg_infer.inference(self.model, self.device,
+                                   input['input_path'])
+        return {OutputKeys.MASKS: mask}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
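End-to-end, the registration above makes the task callable through the standard pipeline factory. A usage sketch (the input is a dict carrying 'input_path', since preprocess passes it through untouched; the image path is a placeholder):

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

seg = pipeline(
    Tasks.product_segmentation, model='damo/cv_F3Net_product-segmentation')
result = seg({'input_path': 'product.jpg'})
cv2.imwrite('product_mask.png', result[OutputKeys.MASKS])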
@@ -45,6 +45,7 @@ class CVTasks(object):
     hand_static = 'hand-static'
     face_human_hand_detection = 'face-human-hand-detection'
     face_emotion = 'face-emotion'
+    product_segmentation = 'product-segmentation'
 
     # image editing
     skin_retouching = 'skin-retouching'
@@ -0,0 +1,43 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+import unittest
+from typing import Any, Dict
+
+import cv2
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class ProductSegmentationTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_F3Net_product-segmentation'
+        self.input = {
+            'input_path': 'data/test/images/product_segmentation.jpg'
+        }
+
+    def pipeline_inference(self, pipeline: Pipeline, input: Dict[str, Any]):
+        result = pipeline(input)
+        cv2.imwrite('test_product_segmentation_mask.jpg',
+                    result[OutputKeys.MASKS])
+        logger.info('test done')
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        product_segmentation = pipeline(
+            Tasks.product_segmentation, model=self.model_id)
+        self.pipeline_inference(product_segmentation, self.input)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        product_segmentation = pipeline(Tasks.product_segmentation)
+        self.pipeline_inference(product_segmentation, self.input)
+
+
+if __name__ == '__main__':
+    unittest.main()