diff --git a/data/test/images/product_segmentation.jpg b/data/test/images/product_segmentation.jpg new file mode 100644 index 00000000..c188a69e --- /dev/null +++ b/data/test/images/product_segmentation.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a16038f7809127eb3e03cbae049592d193707e095309daca78f7d108d67fe4ec +size 108357 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index ae8b5297..33273502 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -42,6 +42,7 @@ class Models(object): hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' # EasyCV models yolox = 'YOLOX' @@ -185,6 +186,7 @@ class Pipelines(object): hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/product_segmentation/__init__.py b/modelscope/models/cv/product_segmentation/__init__.py new file mode 100644 index 00000000..e87c8db1 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .seg_infer import F3NetProductSegmentation + +else: + _import_structure = {'seg_infer': ['F3NetProductSegmentation']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/product_segmentation/net.py b/modelscope/models/cv/product_segmentation/net.py new file mode 100644 index 00000000..454c99d8 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/net.py @@ -0,0 +1,197 @@ +# The implementation here is modified based on F3Net, +# originally Apache 2.0 License and publicly avaialbe at https://github.com/weijun88/F3Net + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Bottleneck(nn.Module): + + def __init__(self, + inplanes, + planes, + stride=1, + downsample=None, + dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=(3 * dilation - 1) // 2, + bias=False, + dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.downsample = downsample + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x)), inplace=True) + out = F.relu(self.bn2(self.conv2(out)), inplace=True) + out = self.bn3(self.conv3(out)) + if self.downsample is not None: + x = self.downsample(x) + return F.relu(out + x, inplace=True) + + +class ResNet(nn.Module): + + def __init__(self): + super(ResNet, self).__init__() + self.inplanes = 64 + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self.make_layer(64, 3, stride=1, dilation=1) + self.layer2 = self.make_layer(128, 4, stride=2, dilation=1) + self.layer3 = self.make_layer(256, 6, stride=2, dilation=1) + self.layer4 = self.make_layer(512, 3, stride=2, 
dilation=1) + + def make_layer(self, planes, blocks, stride, dilation): + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * 4, + kernel_size=1, + stride=stride, + bias=False), nn.BatchNorm2d(planes * 4)) + layers = [ + Bottleneck( + self.inplanes, planes, stride, downsample, dilation=dilation) + ] + self.inplanes = planes * 4 + for _ in range(1, blocks): + layers.append(Bottleneck(self.inplanes, planes, dilation=dilation)) + return nn.Sequential(*layers) + + def forward(self, x): + x = x.reshape(1, 3, 448, 448) + out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) + out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1) + out2 = self.layer1(out1) + out3 = self.layer2(out2) + out4 = self.layer3(out3) + out5 = self.layer4(out4) + return out2, out3, out4, out5 + + +class CFM(nn.Module): + + def __init__(self): + super(CFM, self).__init__() + self.conv1h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn1h = nn.BatchNorm2d(64) + self.conv2h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn2h = nn.BatchNorm2d(64) + self.conv3h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn3h = nn.BatchNorm2d(64) + self.conv4h = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn4h = nn.BatchNorm2d(64) + + self.conv1v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn1v = nn.BatchNorm2d(64) + self.conv2v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn2v = nn.BatchNorm2d(64) + self.conv3v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn3v = nn.BatchNorm2d(64) + self.conv4v = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + self.bn4v = nn.BatchNorm2d(64) + + def forward(self, left, down): + if down.size()[2:] != left.size()[2:]: + down = F.interpolate(down, size=left.size()[2:], mode='bilinear') + out1h = F.relu(self.bn1h(self.conv1h(left)), inplace=True) + out2h = F.relu(self.bn2h(self.conv2h(out1h)), inplace=True) + out1v = F.relu(self.bn1v(self.conv1v(down)), inplace=True) + out2v = F.relu(self.bn2v(self.conv2v(out1v)), inplace=True) + fuse = out2h * out2v + out3h = F.relu(self.bn3h(self.conv3h(fuse)), inplace=True) + out1h + out4h = F.relu(self.bn4h(self.conv4h(out3h)), inplace=True) + out3v = F.relu(self.bn3v(self.conv3v(fuse)), inplace=True) + out1v + out4v = F.relu(self.bn4v(self.conv4v(out3v)), inplace=True) + return out4h, out4v + + +class Decoder(nn.Module): + + def __init__(self): + super(Decoder, self).__init__() + self.cfm45 = CFM() + self.cfm34 = CFM() + self.cfm23 = CFM() + + def forward(self, out2h, out3h, out4h, out5v, fback=None): + if fback is not None: + refine5 = F.interpolate( + fback, size=out5v.size()[2:], mode='bilinear') + refine4 = F.interpolate( + fback, size=out4h.size()[2:], mode='bilinear') + refine3 = F.interpolate( + fback, size=out3h.size()[2:], mode='bilinear') + refine2 = F.interpolate( + fback, size=out2h.size()[2:], mode='bilinear') + out5v = out5v + refine5 + out4h, out4v = self.cfm45(out4h + refine4, out5v) + out3h, out3v = self.cfm34(out3h + refine3, out4v) + out2h, pred = self.cfm23(out2h + refine2, out3v) + else: + out4h, out4v = self.cfm45(out4h, out5v) + out3h, out3v = self.cfm34(out3h, out4v) + out2h, pred = self.cfm23(out2h, out3v) + return out2h, out3h, out4h, out5v, pred + + +class F3Net(nn.Module): + + def __init__(self): + super(F3Net, self).__init__() + self.bkbone = ResNet() + self.squeeze5 = nn.Sequential( + nn.Conv2d(2048, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze4 = nn.Sequential( + 
nn.Conv2d(1024, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze3 = nn.Sequential( + nn.Conv2d(512, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + self.squeeze2 = nn.Sequential( + nn.Conv2d(256, 64, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) + + self.decoder1 = Decoder() + self.decoder2 = Decoder() + self.linearp1 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearp2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + + self.linearr2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + self.linearr5 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1) + + def forward(self, x, shape=None): + x = x.reshape(1, 3, 448, 448) + out2h, out3h, out4h, out5v = self.bkbone(x) + out2h, out3h, out4h, out5v = self.squeeze2(out2h), self.squeeze3( + out3h), self.squeeze4(out4h), self.squeeze5(out5v) + out2h, out3h, out4h, out5v, pred1 = self.decoder1( + out2h, out3h, out4h, out5v) + out2h, out3h, out4h, out5v, pred2 = self.decoder2( + out2h, out3h, out4h, out5v, pred1) + + shape = x.size()[2:] if shape is None else shape + pred1 = F.interpolate( + self.linearp1(pred1), size=shape, mode='bilinear') + pred2 = F.interpolate( + self.linearp2(pred2), size=shape, mode='bilinear') + + out2h = F.interpolate( + self.linearr2(out2h), size=shape, mode='bilinear') + out3h = F.interpolate( + self.linearr3(out3h), size=shape, mode='bilinear') + out4h = F.interpolate( + self.linearr4(out4h), size=shape, mode='bilinear') + out5h = F.interpolate( + self.linearr5(out5v), size=shape, mode='bilinear') + return pred1, pred2, out2h, out3h, out4h, out5h diff --git a/modelscope/models/cv/product_segmentation/seg_infer.py b/modelscope/models/cv/product_segmentation/seg_infer.py new file mode 100644 index 00000000..876fac66 --- /dev/null +++ b/modelscope/models/cv/product_segmentation/seg_infer.py @@ -0,0 +1,77 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+ +import cv2 +import numpy as np +import torch +from PIL import Image + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .net import F3Net + +logger = get_logger() + + +def load_state_dict(model_dir, device): + _dict = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device) + state_dict = {} + for k, v in _dict.items(): + if k.startswith('module'): + k = k[7:] + state_dict[k] = v + return state_dict + + +@MODELS.register_module( + Tasks.product_segmentation, module_name=Models.product_segmentation) +class F3NetForProductSegmentation(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + + self.model = F3Net() + if torch.cuda.is_available(): + self.device = 'cuda' + logger.info('Use GPU') + else: + self.device = 'cpu' + logger.info('Use CPU') + + self.params = load_state_dict(model_dir, self.device) + self.model.load_state_dict(self.params) + self.model.to(self.device) + self.model.eval() + self.model.to(self.device) + + def forward(self, x): + pred_result = self.model(x) + return pred_result + + +mean, std = np.array([[[124.55, 118.90, + 102.94]]]), np.array([[[56.77, 55.97, 57.50]]]) + + +def inference(model, device, input_path): + img = Image.open(input_path) + img = np.array(img.convert('RGB')).astype(np.float32) + img = (img - mean) / std + img = cv2.resize(img, dsize=(448, 448), interpolation=cv2.INTER_LINEAR) + img = torch.from_numpy(img) + img = img.permute(2, 0, 1) + img = img.to(device).float() + outputs = model(img) + out = outputs[0] + pred = (torch.sigmoid(out[0, 0]) * 255).cpu().numpy() + pred[pred < 20] = 0 + pred = pred[:, :, np.newaxis] + pred = np.round(pred) + logger.info('Inference Done') + return pred diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 9811811e..d8d2458a 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -674,5 +674,12 @@ TASK_OUTPUTS = { # { # {'output': 'Happiness', 'boxes': (203, 104, 663, 564)} # } - Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES] + Tasks.face_emotion: [OutputKeys.OUTPUT, OutputKeys.BOXES], + + # { + # "masks": [ + # np.array # 2D array containing only 0, 255 + # ] + # } + Tasks.product_segmentation: [OutputKeys.MASKS], } diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index ff56658f..7fa66b5f 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -187,6 +187,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { (Pipelines.face_human_hand_detection, 'damo/cv_nanodet_face-human-hand-detection'), Tasks.face_emotion: (Pipelines.face_emotion, 'damo/cv_face-emotion'), + Tasks.product_segmentation: (Pipelines.product_segmentation, + 'damo/cv_F3Net_product-segmentation'), } diff --git a/modelscope/pipelines/cv/product_segmentation_pipeline.py b/modelscope/pipelines/cv/product_segmentation_pipeline.py new file mode 100644 index 00000000..244b01d7 --- /dev/null +++ b/modelscope/pipelines/cv/product_segmentation_pipeline.py @@ -0,0 +1,40 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+ +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.product_segmentation import seg_infer +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.product_segmentation, module_name=Pipelines.product_segmentation) +class F3NetForProductSegmentationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create product segmentation pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + mask = seg_infer.inference(self.model, self.device, + input['input_path']) + return {OutputKeys.MASKS: mask} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 4aff1d05..7968fcd1 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -45,6 +45,7 @@ class CVTasks(object): hand_static = 'hand-static' face_human_hand_detection = 'face-human-hand-detection' face_emotion = 'face-emotion' + product_segmentation = 'product-segmentation' # image editing skin_retouching = 'skin-retouching' diff --git a/tests/pipelines/test_product_segmentation.py b/tests/pipelines/test_product_segmentation.py new file mode 100644 index 00000000..8f41c13c --- /dev/null +++ b/tests/pipelines/test_product_segmentation.py @@ -0,0 +1,43 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +import unittest + +import cv2 + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from modelscope.utils.test_utils import test_level + +logger = get_logger() + + +class ProductSegmentationTest(unittest.TestCase): + + def setUp(self) -> None: + self.model_id = 'damo/cv_F3Net_product-segmentation' + self.input = { + 'input_path': 'data/test/images/product_segmentation.jpg' + } + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + cv2.imwrite('test_product_segmentation_mask.jpg', + result[OutputKeys.MASKS]) + logger.info('test done') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + product_segmentation = pipeline( + Tasks.product_segmentation, model=self.model_id) + self.pipeline_inference(product_segmentation, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + product_segmentation = pipeline(Tasks.product_segmentation) + self.pipeline_inference(product_segmentation, self.input) + + +if __name__ == '__main__': + unittest.main()
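
Usage note: the sketch below is assembled from the registrations and the test added in this diff (the task name, the default model id 'damo/cv_F3Net_product-segmentation', and the 'input_path' input key all appear above); it is a minimal example of driving the new pipeline, not an additional API.

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the product-segmentation pipeline registered in this diff and run it
# on the test image added under data/test/images.
product_segmentation = pipeline(
    Tasks.product_segmentation, model='damo/cv_F3Net_product-segmentation')
result = product_segmentation(
    {'input_path': 'data/test/images/product_segmentation.jpg'})

# result[OutputKeys.MASKS] is the mask returned by seg_infer.inference();
# the new unit test writes it straight to disk with OpenCV in the same way.
cv2.imwrite('product_segmentation_mask.jpg', result[OutputKeys.MASKS])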
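Output note: seg_infer.inference() returns rounded sigmoid activations scaled to [0, 255] with values below 20 zeroed, so the returned mask is grayscale rather than the strictly 0/255 array the TASK_OUTPUTS comment describes. If a caller needs a hard binary mask, a post-processing step along the lines of the sketch below would do it; the threshold of 128 is an illustrative choice, not something defined in this diff.

import numpy as np

def binarize_mask(mask: np.ndarray, threshold: float = 128.0) -> np.ndarray:
    # Map everything at or above the (assumed) threshold to 255, the rest to 0.
    return np.where(mask >= threshold, 255, 0).astype(np.uint8)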