From 9c4765923aa6efb1dd29f05b705799277f934ade Mon Sep 17 00:00:00 2001 From: "lee.lcy" Date: Mon, 22 Aug 2022 21:57:05 +0800 Subject: [PATCH] [to #42322933] add image-reid-person add image-reid-person Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9818427 --- data/test/images/image_reid_person.jpg | 3 + modelscope/metainfo.py | 2 + modelscope/models/cv/__init__.py | 9 +- .../models/cv/image_reid_person/__init__.py | 22 + .../models/cv/image_reid_person/pass_model.py | 136 ++++++ .../cv/image_reid_person/transreid_model.py | 418 ++++++++++++++++++ modelscope/outputs.py | 6 + modelscope/pipelines/builder.py | 2 + modelscope/pipelines/cv/__init__.py | 2 + .../cv/image_reid_person_pipeline.py | 58 +++ modelscope/utils/constant.py | 3 +- tests/pipelines/test_image_reid_person.py | 53 +++ 12 files changed, 709 insertions(+), 5 deletions(-) create mode 100644 data/test/images/image_reid_person.jpg create mode 100644 modelscope/models/cv/image_reid_person/__init__.py create mode 100644 modelscope/models/cv/image_reid_person/pass_model.py create mode 100644 modelscope/models/cv/image_reid_person/transreid_model.py create mode 100644 modelscope/pipelines/cv/image_reid_person_pipeline.py create mode 100644 tests/pipelines/test_image_reid_person.py diff --git a/data/test/images/image_reid_person.jpg b/data/test/images/image_reid_person.jpg new file mode 100644 index 00000000..078468ec --- /dev/null +++ b/data/test/images/image_reid_person.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4c9a7e42edc7065c16972ff56267aad63f5233e36aa5a699b84939f5bad73276 +size 2451 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 0bc16026..4e759305 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -20,6 +20,7 @@ class Models(object): product_retrieval_embedding = 'product-retrieval-embedding' body_2d_keypoints = 'body-2d-keypoints' crowd_counting = 'HRNetCrowdCounting' + image_reid_person = 'passvitb' # nlp models bert = 'bert' @@ -112,6 +113,7 @@ class Pipelines(object): tinynas_classification = 'tinynas-classification' crowd_counting = 'hrnet-crowd-counting' video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking' + image_reid_person = 'passvitb-image-reid-person' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index f2ecd08e..dd7e6724 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -3,7 +3,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, cartoon, cmdssl_video_embedding, crowd_counting, face_detection, face_generation, image_classification, image_color_enhance, image_colorization, image_denoise, image_instance_segmentation, - image_portrait_enhancement, image_to_image_generation, - image_to_image_translation, object_detection, - product_retrieval_embedding, salient_detection, - super_resolution, video_single_object_tracking, virual_tryon) + image_portrait_enhancement, image_reid_person, + image_to_image_generation, image_to_image_translation, + object_detection, product_retrieval_embedding, + salient_detection, super_resolution, + video_single_object_tracking, virual_tryon) diff --git a/modelscope/models/cv/image_reid_person/__init__.py b/modelscope/models/cv/image_reid_person/__init__.py new file mode 100644 index 00000000..0fe0bede --- /dev/null +++ b/modelscope/models/cv/image_reid_person/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. 
and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .pass_model import PASS + +else: + _import_structure = { + 'pass_model': ['PASS'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_reid_person/pass_model.py b/modelscope/models/cv/image_reid_person/pass_model.py new file mode 100644 index 00000000..2222fedb --- /dev/null +++ b/modelscope/models/cv/image_reid_person/pass_model.py @@ -0,0 +1,136 @@ +# The implementation is also open-sourced by the authors as PASS-reID, and is available publicly on +# https://github.com/CASIA-IVA-Lab/PASS-reID + +import os +from enum import Enum + +import torch +import torch.nn as nn + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from .transreid_model import vit_base_patch16_224_TransReID + + +class Fusions(Enum): + CAT = 'cat' + MEAN = 'mean' + + +@MODELS.register_module( + Tasks.image_reid_person, module_name=Models.image_reid_person) +class PASS(TorchModel): + + def __init__(self, cfg: Config, model_dir: str, **kwargs): + super(PASS, self).__init__(model_dir=model_dir) + size_train = cfg.INPUT.SIZE_TRAIN + sie_coe = cfg.MODEL.SIE_COE + stride_size = cfg.MODEL.STRIDE_SIZE + drop_path = cfg.MODEL.DROP_PATH + drop_out = cfg.MODEL.DROP_OUT + att_drop_rate = cfg.MODEL.ATT_DROP_RATE + gem_pooling = cfg.MODEL.GEM_POOLING + stem_conv = cfg.MODEL.STEM_CONV + weight = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + self.neck_feat = cfg.TEST.NECK_FEAT + self.dropout_rate = cfg.MODEL.DROPOUT_RATE + self.num_classes = cfg.DATASETS.NUM_CLASSES + self.multi_neck = cfg.MODEL.MULTI_NECK + self.feat_fusion = cfg.MODEL.FEAT_FUSION + + self.base = vit_base_patch16_224_TransReID( + img_size=size_train, + sie_xishu=sie_coe, + stride_size=stride_size, + drop_path_rate=drop_path, + drop_rate=drop_out, + attn_drop_rate=att_drop_rate, + gem_pool=gem_pooling, + stem_conv=stem_conv) + self.in_planes = self.base.in_planes + + if self.feat_fusion == Fusions.CAT.value: + self.classifier = nn.Linear( + self.in_planes * 2, self.num_classes, bias=False) + elif self.feat_fusion == Fusions.MEAN.value: + self.classifier = nn.Linear( + self.in_planes, self.num_classes, bias=False) + + if self.multi_neck: + self.bottleneck = nn.BatchNorm1d(self.in_planes) + self.bottleneck.bias.requires_grad_(False) + self.bottleneck_1 = nn.BatchNorm1d(self.in_planes) + self.bottleneck_1.bias.requires_grad_(False) + self.bottleneck_2 = nn.BatchNorm1d(self.in_planes) + self.bottleneck_2.bias.requires_grad_(False) + self.bottleneck_3 = nn.BatchNorm1d(self.in_planes) + self.bottleneck_3.bias.requires_grad_(False) + else: + if self.feat_fusion == Fusions.CAT.value: + self.bottleneck = nn.BatchNorm1d(self.in_planes * 2) + self.bottleneck.bias.requires_grad_(False) + elif self.feat_fusion == Fusions.MEAN.value: + self.bottleneck = nn.BatchNorm1d(self.in_planes) + self.bottleneck.bias.requires_grad_(False) + + self.dropout = nn.Dropout(self.dropout_rate) + + self.load_param(weight) + + def forward(self, input): + + global_feat, local_feat_1, local_feat_2, local_feat_3 = self.base( + input) + + # single-neck, almost the same performance + if not 
self.multi_neck: + if self.feat_fusion == Fusions.MEAN.value: + local_feat = local_feat_1 / 3. + local_feat_2 / 3. + local_feat_3 / 3. + final_feat_before = (global_feat + local_feat) / 2 + elif self.feat_fusion == Fusions.CAT.value: + final_feat_before = torch.cat( + (global_feat, local_feat_1 / 3. + local_feat_2 / 3. + + local_feat_3 / 3.), + dim=1) + + final_feat_after = self.bottleneck(final_feat_before) + # multi-neck + else: + feat = self.bottleneck(global_feat) + local_feat_1_bn = self.bottleneck_1(local_feat_1) + local_feat_2_bn = self.bottleneck_2(local_feat_2) + local_feat_3_bn = self.bottleneck_3(local_feat_3) + + if self.feat_fusion == Fusions.MEAN.value: + final_feat_before = ((global_feat + local_feat_1 / 3 + + local_feat_2 / 3 + local_feat_3 / 3) + / 2.) + final_feat_after = (feat + local_feat_1_bn / 3 + + local_feat_2_bn / 3 + + local_feat_3_bn / 3) / 2. + elif self.feat_fusion == Fusions.CAT.value: + final_feat_before = torch.cat( + (global_feat, local_feat_1 / 3. + local_feat_2 / 3. + + local_feat_3 / 3.), + dim=1) + final_feat_after = torch.cat( + (feat, local_feat_1_bn / 3 + local_feat_2_bn / 3 + + local_feat_3_bn / 3), + dim=1) + + if self.neck_feat == 'after': + return final_feat_after + else: + return final_feat_before + + def load_param(self, trained_path): + param_dict = torch.load(trained_path, map_location='cpu') + for i in param_dict: + try: + self.state_dict()[i.replace('module.', + '')].copy_(param_dict[i]) + except Exception: + continue diff --git a/modelscope/models/cv/image_reid_person/transreid_model.py b/modelscope/models/cv/image_reid_person/transreid_model.py new file mode 100644 index 00000000..275c4e22 --- /dev/null +++ b/modelscope/models/cv/image_reid_person/transreid_model.py @@ -0,0 +1,418 @@ +# The implementation is also open-sourced by the authors as PASS-reID, and is available publicly on +# https://github.com/CASIA-IVA-Lab/PASS-reID + +import collections.abc as container_abcs +from functools import partial +from itertools import repeat + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# From PyTorch internals +def _ntuple(n): + + def parse(x): + if isinstance(x, container_abcs.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) + + +def vit_base_patch16_224_TransReID( + img_size=(256, 128), + stride_size=16, + drop_path_rate=0.1, + camera=0, + view=0, + local_feature=False, + sie_xishu=1.5, + **kwargs): + model = TransReID( + img_size=img_size, + patch_size=16, + stride_size=stride_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + camera=camera, + view=view, + drop_path_rate=drop_path_rate, + sie_xishu=sie_xishu, + local_feature=local_feature, + **kwargs) + return model + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0], ) + (1, ) * ( + x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class TransReID(nn.Module): + """Transformer-based Object Re-Identification + """ + + def __init__(self, + img_size=224, + patch_size=16, + stride_size=16, + in_chans=3, + num_classes=1000, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + camera=0, + view=0, + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + local_feature=False, + sie_xishu=1.0, + hw_ratio=1, + gem_pool=False, + stem_conv=False): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.local_feature = local_feature + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + stride_size=stride_size, + in_chans=in_chans, + embed_dim=embed_dim, + stem_conv=stem_conv) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part_token1 = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part_token2 = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part_token3 = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + self.cls_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part1_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part2_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.part3_pos = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.cam_num = camera + self.view_num = view + self.sie_xishu = sie_xishu + self.in_planes = 768 + self.gem_pool = gem_pool + + # Initialize SIE Embedding + if camera > 1 and view > 1: + self.sie_embed = nn.Parameter( + torch.zeros(camera * view, 1, embed_dim)) + elif camera > 1: + self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim)) + elif view > 1: + self.sie_embed = nn.Parameter(torch.zeros(view, 1, embed_dim)) + + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer) for i in range(depth) + ]) + + self.norm = norm_layer(embed_dim) + + # Classifier head + self.fc = nn.Linear(embed_dim, + num_classes) if num_classes > 0 else nn.Identity() + + self.gem = GeneralizedMeanPooling() + + def forward_features(self, x, camera_id, view_id): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand( + B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + part_tokens1 = self.part_token1.expand(B, -1, -1) + part_tokens2 = self.part_token2.expand(B, -1, -1) + part_tokens3 = self.part_token3.expand(B, -1, -1) + x = torch.cat( + (cls_tokens, part_tokens1, part_tokens2, part_tokens3, x), dim=1) + + if self.cam_num > 0 and self.view_num > 0: + x = x + self.pos_embed + self.sie_xishu * self.sie_embed[ + camera_id * self.view_num + view_id] + elif self.cam_num > 0: + x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id] + elif 
self.view_num > 0: + x = x + self.pos_embed + self.sie_xishu * self.sie_embed[view_id] + else: + x = x + torch.cat((self.cls_pos, self.part1_pos, self.part2_pos, + self.part3_pos, self.pos_embed), + dim=1) + + x = self.pos_drop(x) + + if self.local_feature: + for blk in self.blocks[:-1]: + x = blk(x) + return x + else: + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + if self.gem_pool: + gf = self.gem(x[:, 1:].permute(0, 2, 1)).squeeze() + return x[:, 0] + gf + return x[:, 0], x[:, 1], x[:, 2], x[:, 3] + + def forward(self, x, cam_label=None, view_label=None): + global_feat, local_feat_1, local_feat_2, local_feat_3 = self.forward_features( + x, cam_label, view_label) + return global_feat, local_feat_1, local_feat_2, local_feat_3 + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding with overlapping patches + """ + + def __init__(self, + img_size=224, + patch_size=16, + stride_size=16, + in_chans=3, + embed_dim=768, + stem_conv=False): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + stride_size_tuple = to_2tuple(stride_size) + self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 + self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 + self.num_patches = self.num_x * self.num_y + self.img_size = img_size + self.patch_size = patch_size + + self.stem_conv = stem_conv + if self.stem_conv: + hidden_dim = 64 + stem_stride = 2 + stride_size = patch_size = patch_size[0] // stem_stride + self.conv = nn.Sequential( + nn.Conv2d( + in_chans, + hidden_dim, + kernel_size=7, + stride=stem_stride, + padding=3, + bias=False), + IBN(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d( + hidden_dim, + hidden_dim, + kernel_size=3, + stride=1, + padding=1, + bias=False), + IBN(hidden_dim), + nn.ReLU(inplace=True), + nn.Conv2d( + hidden_dim, + hidden_dim, + kernel_size=3, + stride=1, + padding=1, + bias=False), + nn.BatchNorm2d(hidden_dim), + nn.ReLU(inplace=True), + ) + in_chans = hidden_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) + + def forward(self, x): + if self.stem_conv: + x = self.conv(x) + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) # [64, 8, 768] + + return x + + +class GeneralizedMeanPooling(nn.Module): + """Applies a 2D power-average adaptive pooling over an input signal composed of several input planes. + The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)` + - At p = infinity, one gets Max Pooling + - At p = 1, one gets Average Pooling + The output is of size H x W, for any input size. + The number of output features is equal to the number of input planes. + Args: + output_size: the target output size of the image of the form H x W. + Can be a tuple (H, W) or a single H for a square image H x H + H and W can be either a ``int``, or ``None`` which means the size will + be the same as that of the input. + """ + + def __init__(self, norm=3, output_size=1, eps=1e-6): + super(GeneralizedMeanPooling, self).__init__() + assert norm > 0 + self.p = float(norm) + self.output_size = output_size + self.eps = eps + + def forward(self, x): + x = x.clamp(min=self.eps).pow(self.p) + return F.adaptive_avg_pool1d(x, self.output_size).pow(1. 
/ self.p) + + +class Block(nn.Module): + + def __init__(self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Attention(nn.Module): + + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 200a03cd..640d67fa 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -503,4 +503,10 @@ TASK_OUTPUTS = { # "labels": ["entailment", "contradiction", "neutral"] # } Tasks.visual_entailment: [OutputKeys.SCORES, OutputKeys.LABELS], + + # image person reid result for single sample + # { + # "img_embedding": np.array with shape [1, D], + # } + Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING], } diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 4105e28b..52dfa41b 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -134,6 +134,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.video_single_object_tracking: (Pipelines.video_single_object_tracking, 'damo/cv_vitb_video-single-object-tracking_ostrack'), + Tasks.image_reid_person: (Pipelines.image_reid_person, + 'damo/cv_passvitb_image-reid-person_market'), } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index cee91c8e..4ff1b856 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -24,6 +24,7 @@ if TYPE_CHECKING: from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline from .image_matting_pipeline import ImageMattingPipeline from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline + from .image_reid_person_pipeline import ImageReidPersonPipeline from .image_style_transfer_pipeline import ImageStyleTransferPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline @@ -60,6 +61,7 @@ else: 'image_matting_pipeline': ['ImageMattingPipeline'], 'image_portrait_enhancement_pipeline': ['ImagePortraitEnhancementPipeline'], + 'image_reid_person_pipeline': ['ImageReidPersonPipeline'], 'image_style_transfer_pipeline': ['ImageStyleTransferPipeline'], 'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'], 'image_to_image_translation_pipeline': diff --git a/modelscope/pipelines/cv/image_reid_person_pipeline.py b/modelscope/pipelines/cv/image_reid_person_pipeline.py new file mode 100644 index 00000000..a14666a1 --- /dev/null +++ b/modelscope/pipelines/cv/image_reid_person_pipeline.py @@ -0,0 +1,58 @@ +import math +import os +from typing import Any, Dict + +import torch +import torchvision.transforms as T +from PIL import Image + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.config import Config +from modelscope.utils.constant import 
ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_reid_person, module_name=Pipelines.image_reid_person) +class ImageReidPersonPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + model: model id on modelscope hub. + """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model config from dir {model}') + + cfg_path = os.path.join(model, ModelFile.CONFIGURATION) + cfg = Config.from_file(cfg_path) + cfg = cfg.model.cfg + self.model = self.model.to(self.device) + self.model.eval() + + self.val_transforms = T.Compose([ + T.Resize(cfg.INPUT.SIZE_TEST), + T.ToTensor(), + T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + img = self.val_transforms(img) + img = img.unsqueeze(0) + img = img.to(self.device) + return {'img': img} + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + img = input['img'] + img_embedding = self.model(img) + return {OutputKeys.IMG_EMBEDDING: img_embedding} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 993a3e42..fd679d74 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -62,8 +62,9 @@ class CVTasks(object): virtual_try_on = 'virtual-try-on' crowd_counting = 'crowd-counting' - # video related + # reid and tracking video_single_object_tracking = 'video-single-object-tracking' + image_reid_person = 'image-reid-person' class NLPTasks(object): diff --git a/tests/pipelines/test_image_reid_person.py b/tests/pipelines/test_image_reid_person.py new file mode 100644 index 00000000..c3e8d487 --- /dev/null +++ b/tests/pipelines/test_image_reid_person.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from PIL import Image + +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ImageReidPersonTest(unittest.TestCase): + + def setUp(self) -> None: + self.input_location = 'data/test/images/image_reid_person.jpg' + self.model_id = 'damo/cv_passvitb_image-reid-person_market' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_image_reid_person(self): + image_reid_person = pipeline( + Tasks.image_reid_person, model=self.model_id) + result = image_reid_person(self.input_location) + assert result and OutputKeys.IMG_EMBEDDING in result + print( + f'The shape of img embedding is: {result[OutputKeys.IMG_EMBEDDING].shape}' + ) + print(f'The img embedding is: {result[OutputKeys.IMG_EMBEDDING]}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_image_reid_person_with_image(self): + image_reid_person = pipeline( + Tasks.image_reid_person, model=self.model_id) + img = Image.open(self.input_location) + result = image_reid_person(img) + assert result and OutputKeys.IMG_EMBEDDING in result + print( + f'The shape of img embedding is: {result[OutputKeys.IMG_EMBEDDING].shape}' + ) + print(f'The img embedding is: {result[OutputKeys.IMG_EMBEDDING]}') + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_image_reid_person_with_default_model(self): + image_reid_person = pipeline(Tasks.image_reid_person) + result = image_reid_person(self.input_location) + assert result and OutputKeys.IMG_EMBEDDING in result + print( + f'The shape of img embedding is: {result[OutputKeys.IMG_EMBEDDING].shape}' + ) + print(f'The img embedding is: {result[OutputKeys.IMG_EMBEDDING]}') + + +if __name__ == '__main__': + unittest.main()
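
Usage note (not part of the patch itself): a minimal sketch of how the pipeline added here can be invoked, following the added tests/pipelines/test_image_reid_person.py. The task name, output key, and model id are taken directly from this patch; the [1, D] embedding shape is per the comment added to modelscope/outputs.py.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the person re-identification pipeline with the default model
    # registered in modelscope/pipelines/builder.py by this patch.
    image_reid_person = pipeline(
        Tasks.image_reid_person,
        model='damo/cv_passvitb_image-reid-person_market')

    # The input may be an image path/URL or a PIL.Image; preprocessing
    # resizes and normalizes it according to the model's configuration file.
    result = image_reid_person('data/test/images/image_reid_person.jpg')

    # The feature embedding is returned under OutputKeys.IMG_EMBEDDING
    # with shape [1, D], as documented in TASK_OUTPUTS.
    print(result[OutputKeys.IMG_EMBEDDING].shape)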