diff --git a/data/test/images/crowd_counting.jpg b/data/test/images/crowd_counting.jpg
new file mode 100644
index 00000000..0468fe5b
--- /dev/null
+++ b/data/test/images/crowd_counting.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:03c9b0ae20b5000b083e8211e2c119176b88db0ea4f48e29b86dcf2f901e382b
+size 130079
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 220b3c32..a0aab6d3 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -19,6 +19,7 @@ class Models(object):
     gpen = 'gpen'
     product_retrieval_embedding = 'product-retrieval-embedding'
     body_2d_keypoints = 'body-2d-keypoints'
+    crowd_counting = 'HRNetCrowdCounting'
 
     # nlp models
     bert = 'bert'
@@ -107,6 +108,7 @@ class Pipelines(object):
     image_to_image_generation = 'image-to-image-generation'
     skin_retouching = 'unet-skin-retouching'
     tinynas_classification = 'tinynas-classification'
+    crowd_counting = 'hrnet-crowd-counting'
 
     # nlp tasks
     sentence_similarity = 'sentence-similarity'
diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py
index 397c2fba..a05bc57d 100644
--- a/modelscope/models/cv/__init__.py
+++ b/modelscope/models/cv/__init__.py
@@ -1,6 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from . import (action_recognition, animal_recognition, body_2d_keypoints,
-               cartoon, cmdssl_video_embedding, face_detection,
+               cartoon, cmdssl_video_embedding, crowd_counting, face_detection,
                face_generation, image_classification, image_color_enhance,
                image_colorization, image_denoise, image_instance_segmentation,
                image_portrait_enhancement, image_to_image_generation,
diff --git a/modelscope/models/cv/crowd_counting/__init__.py b/modelscope/models/cv/crowd_counting/__init__.py
new file mode 100644
index 00000000..b5eeb937
--- /dev/null
+++ b/modelscope/models/cv/crowd_counting/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .cc_model import HRNetCrowdCounting
+
+else:
+    _import_structure = {
+        'cc_model': ['HRNetCrowdCounting'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/cv/crowd_counting/cc_model.py b/modelscope/models/cv/crowd_counting/cc_model.py
new file mode 100644
index 00000000..4e3d0e9f
--- /dev/null
+++ b/modelscope/models/cv/crowd_counting/cc_model.py
@@ -0,0 +1,34 @@
+import os
+from typing import Any, Dict, Optional, Union
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.utils.constant import Tasks
+
+
+@MODELS.register_module(
+    Tasks.crowd_counting, module_name=Models.crowd_counting)
+class HRNetCrowdCounting(TorchModel):
+
+    def __init__(self, model_dir: str):
+        super().__init__(model_dir)
+
+        from .hrnet_aspp_relu import HighResolutionNet as HRNet_aspp_relu
+
+        domain_center_model = os.path.join(
+            model_dir, 'average_clip_domain_center_54.97.npz')
+        net = HRNet_aspp_relu(
+            attn_weight=1.0,
+            fix_domain=0,
+            domain_center_model=domain_center_model)
+        net.load_state_dict(
+            torch.load(
+                os.path.join(model_dir, 'DCANet_final.pth'),
+                map_location='cpu'))
+        self.model = net
+
+    def forward(self, inputs):
+        return self.model(inputs)
diff --git a/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py
new file mode 100644
index 00000000..982ba939
--- /dev/null
+++ b/modelscope/models/cv/crowd_counting/hrnet_aspp_relu.py
@@ -0,0 +1,638 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+# https://github.com/HRNet/HRNet-Image-Classification/blob/master/lib/models/cls_hrnet.py
+# ------------------------------------------------------------------------------
+
+import functools
+import logging
+import os
+
+import numpy as np
+import torch
+import torch._utils
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modelscope.utils.logger import get_logger
+
+BN_MOMENTUM = 0.01  # 0.01 for seg
+logger = get_logger()
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=1,
+        bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.conv2 = nn.Conv2d(
+            planes,
+            planes,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias=False)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.conv3 = nn.Conv2d(
+            planes, planes * self.expansion, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(
+            planes * self.expansion, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        residual = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+        out = self.relu(out)
+
+        return out
+
+
+class HighResolutionModule(nn.Module):
+
+    def __init__(self,
+                 num_branches,
+                 blocks,
+                 num_blocks,
+                 num_inchannels,
+                 num_channels,
+                 fuse_method,
+                 multi_scale_output=True):
+        super(HighResolutionModule, self).__init__()
+        self._check_branches(num_branches, blocks, num_blocks, num_inchannels,
+                             num_channels)
+
+        self.num_inchannels = num_inchannels
+        self.fuse_method = fuse_method
+        self.num_branches = num_branches
+
+        self.multi_scale_output = multi_scale_output
+
+        self.branches = self._make_branches(num_branches, blocks, num_blocks,
+                                            num_channels)
+        self.fuse_layers = self._make_fuse_layers()
+        self.relu = nn.ReLU(False)
+
+    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels,
+                        num_channels):
+        if num_branches != len(num_blocks):
+            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
+                num_branches, len(num_blocks))
+            logger.info(error_msg)
+            raise ValueError(error_msg)
+
+        if num_branches != len(num_channels):
+            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
+                num_branches, len(num_channels))
+            logger.info(error_msg)
+            raise ValueError(error_msg)
+
+        if num_branches != len(num_inchannels):
+            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
+                num_branches, len(num_inchannels))
+            logger.info(error_msg)
+            raise ValueError(error_msg)
+
+    def _make_one_branch(self,
+                         branch_index,
+                         block,
+                         num_blocks,
+                         num_channels,
+                         stride=1):
+        downsample = None
+        if stride != 1 or \
+           self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.num_inchannels[branch_index],
+                    num_channels[branch_index] * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False),
+                nn.BatchNorm2d(
+                    num_channels[branch_index] * block.expansion,
+                    momentum=BN_MOMENTUM),
+            )
+
+        layers = []
+        layers.append(
+            block(self.num_inchannels[branch_index],
+                  num_channels[branch_index], stride, downsample))
+        self.num_inchannels[branch_index] = \
+            num_channels[branch_index] * block.expansion
+        for i in range(1, num_blocks[branch_index]):
+            layers.append(
+                block(self.num_inchannels[branch_index],
+                      num_channels[branch_index]))
+
+        return nn.Sequential(*layers)
+
+    def _make_branches(self, num_branches, block, num_blocks, num_channels):
+        branches = []
+
+        for i in range(num_branches):
+            branches.append(
+                self._make_one_branch(i, block, num_blocks, num_channels))
+
+        return nn.ModuleList(branches)
+
+    def _make_fuse_layers(self):
+        if self.num_branches == 1:
+            return None
+
+        num_branches = self.num_branches
+        num_inchannels = self.num_inchannels
+        fuse_layers = []
+        for i in range(num_branches if self.multi_scale_output else 1):
+            fuse_layer = []
+            for j in range(num_branches):
+                if j > i:
+                    fuse_layer.append(
+                        nn.Sequential(
+                            nn.Conv2d(
+                                num_inchannels[j],
+                                num_inchannels[i],
+                                1,
+                                1,
+                                0,
+                                bias=False),
+                            nn.BatchNorm2d(
+                                num_inchannels[i], momentum=BN_MOMENTUM),
+                            nn.Upsample(
+                                scale_factor=2**(j - i), mode='nearest')))
+                elif j == i:
+                    fuse_layer.append(None)
+                else:
+                    conv3x3s = []
+                    for k in range(i - j):
+                        if k == i - j - 1:
+                            num_outchannels_conv3x3 = num_inchannels[i]
+                            conv3x3s.append(
+                                nn.Sequential(
+                                    nn.Conv2d(
+                                        num_inchannels[j],
+                                        num_outchannels_conv3x3,
+                                        3,
+                                        2,
+                                        1,
+                                        bias=False),
+                                    nn.BatchNorm2d(
+                                        num_outchannels_conv3x3,
+                                        momentum=BN_MOMENTUM)))
+                        else:
+                            num_outchannels_conv3x3 = num_inchannels[j]
+                            conv3x3s.append(
+                                nn.Sequential(
+                                    nn.Conv2d(
+                                        num_inchannels[j],
+                                        num_outchannels_conv3x3,
+                                        3,
+                                        2,
+                                        1,
+                                        bias=False),
+                                    nn.BatchNorm2d(
+                                        num_outchannels_conv3x3,
+                                        momentum=BN_MOMENTUM), nn.ReLU(False)))
+                    fuse_layer.append(nn.Sequential(*conv3x3s))
+            fuse_layers.append(nn.ModuleList(fuse_layer))
+
+        return nn.ModuleList(fuse_layers)
+
+    def get_num_inchannels(self):
+        return self.num_inchannels
+
+    def forward(self, x):
+        if self.num_branches == 1:
+            return [self.branches[0](x[0])]
+
+        for i in range(self.num_branches):
+            x[i] = self.branches[i](x[i])
+
+        x_fuse = []
+        for i in range(len(self.fuse_layers)):
+            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
+            for j in range(1, self.num_branches):
+                if i == j:
+                    y = y + x[j]
+                else:
+                    y = y + self.fuse_layers[i][j](x[j])
+            x_fuse.append(self.relu(y))
+
+        return x_fuse
+
+
+blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+
+class HighResolutionNet(nn.Module):
+
+    def __init__(self,
+                 leaky_relu=False,
+                 attn_weight=1,
+                 fix_domain=1,
+                 domain_center_model='',
+                 **kwargs):
+        super(HighResolutionNet, self).__init__()
+
+        self.criterion_attn = torch.nn.MSELoss(reduction='sum')
+        self.domain_center_model = domain_center_model
+        self.attn_weight = attn_weight
+        self.fix_domain = fix_domain
+        self.cosine = 1
+
+        self.conv1 = nn.Conv2d(
+            3, 64, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+        self.conv2 = nn.Conv2d(
+            64, 64, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+
+        num_channels = 64
+        block = blocks_dict['BOTTLENECK']
+        num_blocks = 4
+        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+        stage1_out_channel = block.expansion * num_channels
+
+        # -- stage 2
+        self.stage2_cfg = {}
+        self.stage2_cfg['NUM_MODULES'] = 1
+        self.stage2_cfg['NUM_BRANCHES'] = 2
+        self.stage2_cfg['BLOCK'] = 'BASIC'
+        self.stage2_cfg['NUM_BLOCKS'] = [4, 4]
+        self.stage2_cfg['NUM_CHANNELS'] = [40, 80]
+        self.stage2_cfg['FUSE_METHOD'] = 'SUM'
+
+        num_channels = self.stage2_cfg['NUM_CHANNELS']
+        block = blocks_dict[self.stage2_cfg['BLOCK']]
+        num_channels = [
+            num_channels[i] * block.expansion
+            for i in range(len(num_channels))
+        ]
+        self.transition1 = self._make_transition_layer([stage1_out_channel],
+                                                       num_channels)
+        self.stage2, pre_stage_channels = self._make_stage(
+            self.stage2_cfg, num_channels)
+
+        # -- stage 3
+        self.stage3_cfg = {}
+        self.stage3_cfg['NUM_MODULES'] = 4
+        self.stage3_cfg['NUM_BRANCHES'] = 3
+        self.stage3_cfg['BLOCK'] = 'BASIC'
+        self.stage3_cfg['NUM_BLOCKS'] = [4, 4, 4]
+        self.stage3_cfg['NUM_CHANNELS'] = [40, 80, 160]
+        self.stage3_cfg['FUSE_METHOD'] = 'SUM'
+
+        num_channels = self.stage3_cfg['NUM_CHANNELS']
+        block = blocks_dict[self.stage3_cfg['BLOCK']]
+        num_channels = [
+            num_channels[i] * block.expansion
+            for i in range(len(num_channels))
+        ]
+        self.transition2 = self._make_transition_layer(pre_stage_channels,
+                                                       num_channels)
+        self.stage3, pre_stage_channels = self._make_stage(
+            self.stage3_cfg, num_channels)
+        last_inp_channels = int(np.sum(pre_stage_channels)) + 256
+        self.redc_layer = nn.Sequential(
+            nn.Conv2d(
+                in_channels=last_inp_channels,
+                out_channels=128,
+                kernel_size=3,
+                stride=1,
+                padding=1),
+            nn.BatchNorm2d(128, momentum=BN_MOMENTUM),
+            nn.ReLU(True),
+        )
+
+        self.aspp = nn.ModuleList(aspp(in_channel=128))
+
+        # additional layers specific for Phase 3
+        self.pred_conv = nn.Conv2d(128, 512, 3, padding=1)
+        self.pred_bn = nn.BatchNorm2d(512)
+        self.GAP = nn.AdaptiveAvgPool2d(1)
+
+        # Specifically for the hidden domain:
+        # set the domain centers as learnable parameters
+        domain_center_src = np.load(self.domain_center_model)
+        G_SHA = torch.from_numpy(domain_center_src['G_SHA']).view(1, -1, 1, 1)
+        G_SHB = torch.from_numpy(domain_center_src['G_SHB']).view(1, -1, 1, 1)
+        G_QNRF = torch.from_numpy(domain_center_src['G_QNRF']).view(
+            1, -1, 1, 1)
+
+        self.n_domain = 3
+
+        self.G_all = torch.cat(
+            [G_SHA.clone(), G_SHB.clone(),
+             G_QNRF.clone()], dim=0)
+
+        self.G_all = nn.Parameter(self.G_all)
+
+        self.last_layer = nn.Sequential(
+            nn.Conv2d(
+                in_channels=128,
+                out_channels=64,
+                kernel_size=3,
+                stride=1,
+                padding=1),
+            nn.BatchNorm2d(64, momentum=BN_MOMENTUM),
+            nn.ReLU(True),
+            nn.Conv2d(
+                in_channels=64,
+                out_channels=32,
+                kernel_size=3,
+                stride=1,
+                padding=1),
+            nn.BatchNorm2d(32, momentum=BN_MOMENTUM),
+            nn.ReLU(True),
+            nn.Conv2d(
+                in_channels=32,
+                out_channels=1,
+                kernel_size=1,
+                stride=1,
+                padding=0),
+        )
+
+    def _make_transition_layer(self, num_channels_pre_layer,
+                               num_channels_cur_layer):
+        num_branches_cur = len(num_channels_cur_layer)
+        num_branches_pre = len(num_channels_pre_layer)
+
+        transition_layers = []
+        for i in range(num_branches_cur):
+            if i < num_branches_pre:
+                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+                    transition_layers.append(
+                        nn.Sequential(
+                            nn.Conv2d(
+                                num_channels_pre_layer[i],
+                                num_channels_cur_layer[i],
+                                3,
+                                1,
+                                1,
+                                bias=False),
+                            nn.BatchNorm2d(
+                                num_channels_cur_layer[i],
+                                momentum=BN_MOMENTUM), nn.ReLU(inplace=True)))
+                else:
+                    transition_layers.append(None)
+            else:
+                conv3x3s = []
+                for j in range(i + 1 - num_branches_pre):
+                    inchannels = num_channels_pre_layer[-1]
+                    outchannels = num_channels_cur_layer[i] \
+                        if j == i - num_branches_pre else inchannels
+                    conv3x3s.append(
+                        nn.Sequential(
+                            nn.Conv2d(
+                                inchannels, outchannels, 3, 2, 1, bias=False),
+                            nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
+                            nn.ReLU(inplace=True)))
+                transition_layers.append(nn.Sequential(*conv3x3s))
+
+        return nn.ModuleList(transition_layers)
+
+    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    inplanes,
+                    planes * block.expansion,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False),
+                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
+            )
+
+        layers = []
+        layers.append(block(inplanes, planes, stride, downsample))
+        inplanes = planes * block.expansion
+        for i in range(1, blocks):
+            layers.append(block(inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def _make_stage(self,
+                    layer_config,
+                    num_inchannels,
+                    multi_scale_output=True):
+        num_modules = layer_config['NUM_MODULES']
+        num_branches = layer_config['NUM_BRANCHES']
+        num_blocks = layer_config['NUM_BLOCKS']
+        num_channels = layer_config['NUM_CHANNELS']
+        block = blocks_dict[layer_config['BLOCK']]
+        fuse_method = layer_config['FUSE_METHOD']
+
+        modules = []
+        for i in range(num_modules):
+            # multi_scale_output is only used by the last module
+            if not multi_scale_output and i == num_modules - 1:
+                reset_multi_scale_output = False
+            else:
+                reset_multi_scale_output = True
+
+            modules.append(
+                HighResolutionModule(num_branches, block, num_blocks,
+                                     num_inchannels, num_channels, fuse_method,
+                                     reset_multi_scale_output))
+            num_inchannels = modules[-1].get_num_inchannels()
+
+        return nn.Sequential(*modules), num_inchannels
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.layer1(x)
+        x_head_1 = x
+
+        x_list = []
+        for i in range(self.stage2_cfg['NUM_BRANCHES']):
+            if self.transition1[i] is not None:
+                x_list.append(self.transition1[i](x))
+            else:
+                x_list.append(x)
+        y_list = self.stage2(x_list)
+
+        x_list = []
+        for i in range(self.stage3_cfg['NUM_BRANCHES']):
+            if self.transition2[i] is not None:
+                x_list.append(self.transition2[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+
+        x = self.stage3(x_list)
+
+        # Replace the classification header with a custom head
+        # Upsampling
+        x0_h, x0_w = x[0].size(2), x[0].size(3)
+        x1 = F.interpolate(
+            x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
+        x2 = F.interpolate(
+            x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=False)
+        x = torch.cat([x[0], x1, x2, x_head_1], 1)
+        # first, reduce the channel count
+        x = self.redc_layer(x)
+
+        pred_attn = self.GAP(F.relu_(self.pred_bn(self.pred_conv(x))))
+        pred_attn = F.softmax(pred_attn, dim=1)
+        pred_attn_list = torch.chunk(pred_attn, 4, dim=1)
+
+        aspp_out = []
+        for k, v in enumerate(self.aspp):
+            if k % 2 == 0:
+                aspp_out.append(self.aspp[k + 1](v(x)))
+            else:
+                continue
+        # Using ASPP add, with relu applied inside
+        for i in range(4):
+            x = x + F.relu_(aspp_out[i] * 0.25) * pred_attn_list[i]
+
+        bz = x.size(0)
+        # -- Besides, we also need to let the prediction attention be close to the visible domain
+        # -- Calculate the domain distance and get the weights
+        # - First, detach domains
+        G_all_d = self.G_all.detach()  # use detached G_all for calculating
+        pred_attn_d = pred_attn.detach().view(bz, 512, 1, 1)
+
+        if self.cosine == 1:
+            G_A, G_B, G_Q = torch.chunk(G_all_d, self.n_domain, dim=0)
+
+            cos_dis_A = F.cosine_similarity(pred_attn_d, G_A, dim=1).view(-1)
+            cos_dis_B = F.cosine_similarity(pred_attn_d, G_B, dim=1).view(-1)
+            cos_dis_Q = F.cosine_similarity(pred_attn_d, G_Q, dim=1).view(-1)
+
+            cos_dis_all = torch.stack([cos_dis_A, cos_dis_B,
+                                       cos_dis_Q]).view(bz, -1)  # bz*3
+
+            cos_dis_all = F.softmax(cos_dis_all, dim=1)
+
+            target_attn = cos_dis_all.view(bz, self.n_domain, 1, 1, 1).expand(
+                bz, self.n_domain, 512, 1, 1) * self.G_all.view(
+                    1, self.n_domain, 512, 1, 1).expand(
+                        bz, self.n_domain, 512, 1, 1)
+            target_attn = torch.sum(
+                target_attn, dim=1, keepdim=False)  # bz * 512 * 1 * 1
+
+            if self.fix_domain:
+                target_attn = target_attn.detach()
+
+        else:
+            raise ValueError('Non-cosine distance is not implemented yet')
+
+        x = self.last_layer(x)
+        x = F.relu_(x)
+
+        x = F.interpolate(
+            x, size=(x0_h * 2, x0_w * 2), mode='bilinear', align_corners=False)
+
+        return x, pred_attn, target_attn
+
+    def init_weights(
+        self,
+        pretrained='',
+    ):
+        logger.info('=> init weights from normal distribution')
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, std=0.01)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+        if os.path.isfile(pretrained):
+            pretrained_dict = torch.load(pretrained)
+            logger.info(f'=> loading pretrained model {pretrained}')
+            model_dict = self.state_dict()
+            pretrained_dict = {
+                k: v
+                for k, v in pretrained_dict.items() if k in model_dict.keys()
+            }
+            for k, _ in pretrained_dict.items():
+                logger.info(f'=> loading {k} pretrained model {pretrained}')
+            model_dict.update(pretrained_dict)
+            self.load_state_dict(model_dict)
+        else:
+            raise FileNotFoundError(f'{pretrained} does not exist')
+
+
+def aspp(aspp_num=4, aspp_stride=2, in_channel=512, use_bn=True):
+    aspp_list = []
+    for i in range(aspp_num):
+        pad = (i + 1) * aspp_stride
+        dilate = pad
+        conv_aspp = nn.Conv2d(
+            in_channel, in_channel, 3, padding=pad, dilation=dilate)
+        aspp_list.append(conv_aspp)
+        if use_bn:
+            aspp_list.append(nn.BatchNorm2d(in_channel))
+
+    return aspp_list
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 6a45b3e3..f279f311 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -132,6 +132,8 @@
     # image matting result for single sample
     # {
     #   "output_img": np.array with shape(h, w, 4)
+    #                 for matting, (h, w, 3) for general purposes,
+    #                 or (h, w) for crowd counting
     # }
     Tasks.portrait_matting: [OutputKeys.OUTPUT_IMG],
 
@@ -143,6 +145,7 @@
     Tasks.image_color_enhancement: [OutputKeys.OUTPUT_IMG],
     Tasks.image_denoising: [OutputKeys.OUTPUT_IMG],
     Tasks.image_portrait_enhancement: [OutputKeys.OUTPUT_IMG],
+    Tasks.crowd_counting: [OutputKeys.SCORES, OutputKeys.OUTPUT_IMG],
 
     # image generation task result for a single image
     # {"output_img": np.array with shape (h, w, 3)}
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 12d8e4e9..1066fa8d 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -128,6 +128,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
         'damo/cv_convnextTiny_ocr-recognition_damo'),
     Tasks.skin_retouching: (Pipelines.skin_retouching,
                             'damo/cv_unet_skin-retouching'),
+    Tasks.crowd_counting: (Pipelines.crowd_counting,
+                           'damo/cv_hrnet_crowd-counting_dcanet'),
 }
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index c424818b..91a2f1e0 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -8,6 +8,7 @@ if TYPE_CHECKING:
     from .animal_recognition_pipeline import AnimalRecognitionPipeline
     from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
     from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
+    from .crowd_counting_pipeline import CrowdCountingPipeline
    from .image_detection_pipeline import ImageDetectionPipeline
     from .face_detection_pipeline import FaceDetectionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
@@ -40,6 +41,7 @@ else:
         'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
         'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
         'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
+        'crowd_counting_pipeline': ['CrowdCountingPipeline'],
         'image_detection_pipeline': ['ImageDetectionPipeline'],
         'face_detection_pipeline': ['FaceDetectionPipeline'],
         'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
diff --git a/modelscope/pipelines/cv/crowd_counting_pipeline.py b/modelscope/pipelines/cv/crowd_counting_pipeline.py
new file mode 100644
index 00000000..3143825b
--- /dev/null
+++ b/modelscope/pipelines/cv/crowd_counting_pipeline.py
@@ -0,0 +1,153 @@
+import math
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.crowd_counting import HRNetCrowdCounting
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors.image import LoadImage
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.crowd_counting, module_name=Pipelines.crowd_counting)
+class CrowdCountingPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        model: model id on modelscope hub.
+ """ + assert isinstance(model, str), 'model must be a single str' + super().__init__(model=model, auto_collate=False, **kwargs) + logger.info(f'loading model from dir {model}') + self.infer_model = HRNetCrowdCounting(model).to(self.device) + self.infer_model.eval() + logger.info('load model done') + + def resize(self, img): + height = img.size[1] + width = img.size[0] + resize_height = height + resize_width = width + if resize_width >= 2048: + tmp = resize_width + resize_width = 2048 + resize_height = (resize_width / tmp) * resize_height + + if resize_height >= 2048: + tmp = resize_height + resize_height = 2048 + resize_width = (resize_height / tmp) * resize_width + + if resize_height <= 416: + tmp = resize_height + resize_height = 416 + resize_width = (resize_height / tmp) * resize_width + if resize_width <= 416: + tmp = resize_width + resize_width = 416 + resize_height = (resize_width / tmp) * resize_height + + # other constraints + if resize_height < resize_width: + if resize_width / resize_height > 2048 / 416: # 1024/416=2.46 + resize_width = 2048 + resize_height = 416 + else: + if resize_height / resize_width > 2048 / 416: + resize_height = 2048 + resize_width = 416 + + resize_height = math.ceil(resize_height / 32) * 32 + resize_width = math.ceil(resize_width / 32) * 32 + img = transforms.Resize([resize_height, resize_width])(img) + return img + + def merge_crops(self, eval_shape, eval_p, pred_m): + for i in range(3): + for j in range(3): + start_h, start_w = math.floor(eval_shape[2] / 4), math.floor( + eval_shape[3] / 4) + valid_h, valid_w = eval_shape[2] // 2, eval_shape[3] // 2 + pred_h = math.floor( + 3 * eval_shape[2] / 4) + (eval_shape[2] // 2) * ( + i - 1) + pred_w = math.floor( + 3 * eval_shape[3] / 4) + (eval_shape[3] // 2) * ( + j - 1) + if i == 0: + valid_h = math.floor(3 * eval_shape[2] / 4) + start_h = 0 + pred_h = 0 + elif i == 2: + valid_h = math.ceil(3 * eval_shape[2] / 4) + + if j == 0: + valid_w = math.floor(3 * eval_shape[3] / 4) + start_w = 0 + pred_w = 0 + elif j == 2: + valid_w = math.ceil(3 * eval_shape[3] / 4) + pred_m[:, :, pred_h:pred_h + valid_h, pred_w:pred_w + + valid_w] += eval_p[i * 3 + j:i * 3 + j + 1, :, + start_h:start_h + valid_h, + start_w:start_w + valid_w] + return pred_m + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + img = self.resize(img) + img_ori_tensor = transforms.ToTensor()(img) + img_shape = img_ori_tensor.shape + img = transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225))( + img_ori_tensor) + patch_height, patch_width = (img_shape[1]) // 2, (img_shape[2]) // 2 + imgs = [] + for i in range(3): + for j in range(3): + start_h, start_w = (patch_height // 2) * i, (patch_width + // 2) * j + imgs.append(img[:, start_h:start_h + patch_height, + start_w:start_w + patch_width]) + + imgs = torch.stack(imgs) + eval_img = imgs.to(self.device) + eval_patchs = torch.squeeze(eval_img) + prediction_map = torch.zeros( + (1, 1, img_shape[1] // 2, img_shape[2] // 2)).to(self.device) + result = { + 'img': eval_patchs, + 'map': prediction_map, + } + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + counts, img_data = self.perform_inference(input) + return {OutputKeys.SCORES: counts, OutputKeys.OUTPUT_IMG: img_data} + + @torch.no_grad() + def perform_inference(self, data): + eval_patchs = data['img'] + prediction_map = data['map'] + eval_prediction, _, _ = self.infer_model(eval_patchs) + eval_patchs_shape = eval_prediction.shape + prediction_map = 
+        prediction_map = self.merge_crops(eval_patchs_shape, eval_prediction,
+                                          prediction_map)
+
+        return torch.sum(
+            prediction_map, dim=(
+                1, 2,
+                3)).data.cpu().numpy(), prediction_map.data.cpu().numpy()[0][0]
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index be077551..927eafbd 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -60,6 +60,7 @@ class CVTasks(object):
     video_category = 'video-category'
     video_embedding = 'video-embedding'
     virtual_try_on = 'virtual-try-on'
+    crowd_counting = 'crowd-counting'
 
 
 class NLPTasks(object):
diff --git a/modelscope/utils/file_utils.py b/modelscope/utils/file_utils.py
index a04d890f..6d4fcc59 100644
--- a/modelscope/utils/file_utils.py
+++ b/modelscope/utils/file_utils.py
@@ -3,6 +3,9 @@
 import inspect
 import os
 
+import cv2
+import numpy as np
+
 
 # TODO: remove this api, unify to flattened args
 def func_receive_dict_inputs(func):
@@ -36,3 +39,19 @@ def get_default_cache_dir():
     default_cache_dir = os.path.expanduser(
         os.path.join('~/.cache', 'modelscope'))
     return default_cache_dir
+
+
+def numpy_to_cv2img(vis_img):
+    """Convert a np.array heatmap with shape (h, w) to a cv2 image.
+
+    Args:
+        vis_img (np.array): input data
+
+    Returns:
+        cv2 img
+    """
+    vis_img = (vis_img - vis_img.min()) / (
+        vis_img.max() - vis_img.min() + 1e-5)
+    vis_img = (vis_img * 255).astype(np.uint8)
+    vis_img = cv2.applyColorMap(vis_img, cv2.COLORMAP_JET)
+    return vis_img
diff --git a/tests/pipelines/test_crowd_counting.py b/tests/pipelines/test_crowd_counting.py
new file mode 100644
index 00000000..a3c59378
--- /dev/null
+++ b/tests/pipelines/test_crowd_counting.py
@@ -0,0 +1,60 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.file_utils import numpy_to_cv2img
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+
+class CrowdCountingTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.input_location = 'data/test/images/crowd_counting.jpg'
+        self.model_id = 'damo/cv_hrnet_crowd-counting_dcanet'
+
+    def save_result(self, result):
+        print('scores:', result[OutputKeys.SCORES])
+        vis_img = result[OutputKeys.OUTPUT_IMG]
+        vis_img = numpy_to_cv2img(vis_img)
+        cv2.imwrite('result.jpg', vis_img)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_crowd_counting(self):
+        crowd_counting = pipeline(Tasks.crowd_counting, model=self.model_id)
+        result = crowd_counting(self.input_location)
+        if result:
+            self.save_result(result)
+        else:
+            raise ValueError('process error')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_crowd_counting_with_image(self):
+        crowd_counting = pipeline(Tasks.crowd_counting, model=self.model_id)
+        img = Image.open(self.input_location)
+        result = crowd_counting(img)
+        if result:
+            self.save_result(result)
+        else:
+            raise ValueError('process error')
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_crowd_counting_with_default_task(self):
+        crowd_counting = pipeline(Tasks.crowd_counting)
+        result = crowd_counting(self.input_location)
+        if result:
+            self.save_result(result)
+        else:
+            raise ValueError('process error')
+
+
+if __name__ == '__main__':
+    unittest.main()
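A minimal usage sketch of the crowd counting pipeline registered above, mirroring what `tests/pipelines/test_crowd_counting.py` exercises; it assumes the `damo/cv_hrnet_crowd-counting_dcanet` model can be fetched from the ModelScope hub and that a local test image exists at the path used by the test.

```python
# Minimal usage sketch for the new crowd counting pipeline.
# Assumes 'damo/cv_hrnet_crowd-counting_dcanet' is downloadable from the hub
# and that a test image exists at the given path.
import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.file_utils import numpy_to_cv2img

crowd_counting = pipeline(
    Tasks.crowd_counting, model='damo/cv_hrnet_crowd-counting_dcanet')
result = crowd_counting('data/test/images/crowd_counting.jpg')

# 'scores' holds the estimated count; 'output_img' is the (h, w) density map.
print('estimated count:', result[OutputKeys.SCORES])
vis_img = numpy_to_cv2img(result[OutputKeys.OUTPUT_IMG])
cv2.imwrite('crowd_counting_vis.jpg', vis_img)
```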