Support the new model going live on 11/30.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10886253
master^2
@@ -185,6 +185,7 @@ class Pipelines(object):
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
+    nextvit_small_daily_image_classification = 'nextvit-small_image-classification_Dailylife-labels'
     image_color_enhance = 'csrnet-image-color-enhance'
     virtual_try_on = 'virtual-try-on'
     image_colorization = 'unet-image-colorization'
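The new Pipelines key is consumed by the pipeline factory registration further down in this change; a minimal usage sketch, with the model id taken from the new unit tests at the end of this review:

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Builds GeneralImageClassificationPipeline via the new nextvit registration.
    classifier = pipeline(
        Tasks.image_classification,
        model='damo/cv_nextvit-small_image-classification_Dailylife-labels')
    print(classifier('data/test/images/bird.JPEG'))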
@@ -330,6 +331,7 @@ class Trainers(object):
     image_inpainting = 'image-inpainting'
     referring_video_object_segmentation = 'referring-video-object-segmentation'
     image_classification_team = 'image-classification-team'
+    image_classification = 'image-classification'

     # nlp trainers
     bert_sentiment_analysis = 'bert-sentiment-analysis'
@@ -365,6 +367,7 @@ class Preprocessors(object):
     image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
     video_summarization_preprocessor = 'video-summarization-preprocessor'
     movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
+    image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor'

     # nlp preprocessor
     sen_sim_tokenizer = 'sen-sim-tokenizer'
@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .nextvit import NextViT
@@ -0,0 +1,541 @@
# Part of the implementation is borrowed and modified from Next-ViT,
# publicly available at https://github.com/bytedance/Next-ViT
import collections.abc
import itertools
import math
import os
import warnings
from functools import partial
from typing import Dict, Sequence

import torch
import torch.nn as nn
from einops import rearrange
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.builder import BACKBONES
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

NORM_EPS = 1e-5


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        ll = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [ll, u], then translate to
        # [2ll-1, 2u-1].
        tensor.uniform_(2 * ll - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
class ConvBNReLU(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            groups=groups,
            bias=False)
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
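# Worked examples: _make_divisible(192, 32) -> 192 (already divisible),
# _make_divisible(200, 32) -> 192 (rounds down, still >= 0.9 * 200 = 180),
# _make_divisible(100, 32) -> 96 (96 >= 0.9 * 100, so the 10% guard does not fire).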
class PatchEmbed(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        if stride == 2:
            self.avgpool = nn.AvgPool2d((2, 2),
                                        stride=2,
                                        ceil_mode=True,
                                        count_include_pad=False)
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif in_channels != out_channels:
            self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.avgpool(x)))


class MHCA(nn.Module):
    """
    Multi-Head Convolutional Attention
    """

    def __init__(self, out_channels, head_dim):
        super(MHCA, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        self.group_conv3x3 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=out_channels // head_dim,
            bias=False)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.projection = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.group_conv3x3(x)
        out = self.norm(out)
        out = self.act(out)
        out = self.projection(out)
        return out


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 out_features=None,
                 mlp_ratio=None,
                 drop=0.,
                 bias=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
        self.conv1 = nn.Conv2d(
            in_features, hidden_dim, kernel_size=1, bias=bias)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            hidden_dim, out_features, kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.conv2(x)
        x = self.drop(x)
        return x


class NCB(nn.Module):
    """
    Next Convolution Block
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 path_dropout=0,
                 drop=0,
                 head_dim=32,
                 mlp_ratio=3):
        super(NCB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        assert out_channels % head_dim == 0

        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
        self.mhca = MHCA(out_channels, head_dim)
        self.attention_path_dropout = DropPath(path_dropout)

        self.norm = norm_layer(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
        self.mlp_path_dropout = DropPath(path_dropout)
        self.is_bn_merged = False

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.attention_path_dropout(self.mhca(x))
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention
    """

    def __init__(self,
                 dim,
                 out_dim=None,
                 head_dim=32,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0,
                 proj_drop=0.,
                 sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio**2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(
                kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merge = False

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = q.reshape(B, N, self.num_heads,
                      int(C // self.num_heads)).permute(0, 2, 1, 3)
        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2)
            x_ = self.sr(x_)
            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merge:
                x_ = self.norm(x_)
            x_ = x_.transpose(1, 2)
            k = self.k(x_)
            k = k.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x_)
            v = v.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 1, 3)
        else:
            k = self.k(x)
            k = k.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x)
            v = v.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 1, 3)
        attn = (q @ k) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
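# e.g. with sr_ratio=8 the key/value sequence is average-pooled with kernel and
# stride 64 (= 8**2), so per-head attention cost drops from O(N^2) to O(N * N/64).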
class NTB(nn.Module):
    """
    Next Transformer Block
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        path_dropout,
        stride=1,
        sr_ratio=1,
        mlp_ratio=2,
        head_dim=32,
        mix_block_ratio=0.75,
        attn_drop=0,
        drop=0,
    ):
        super(NTB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mix_block_ratio = mix_block_ratio
        norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS)

        self.mhsa_out_channels = _make_divisible(
            int(out_channels * mix_block_ratio), 32)
        self.mhca_out_channels = out_channels - self.mhsa_out_channels
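        # e.g. out_channels=256 with mix_block_ratio=0.75 splits into
        # mhsa_out_channels = _make_divisible(192, 32) = 192 and mhca_out_channels = 64.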
        self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels,
                                      stride)
        self.norm1 = norm_func(self.mhsa_out_channels)
        self.e_mhsa = E_MHSA(
            self.mhsa_out_channels,
            head_dim=head_dim,
            sr_ratio=sr_ratio,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)

        self.projection = PatchEmbed(
            self.mhsa_out_channels, self.mhca_out_channels, stride=1)
        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))

        self.norm2 = norm_func(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = DropPath(path_dropout)
        self.is_bn_merged = False

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm1(x)
        else:
            out = x
        out = rearrange(out, 'b c h w -> b (h w) c')  # b n c
        out = self.mhsa_path_dropout(self.e_mhsa(out))
        x = x + rearrange(out, 'b (h w) c -> b c h w', h=H)

        out = self.projection(x)
        out = out + self.mhca_path_dropout(self.mhca(out))
        x = torch.cat([x, out], dim=1)

        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm2(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


@BACKBONES.register_module()
class NextViT(BaseBackbone):
    stem_chs = {
        'x_small': [64, 32, 64],
        'small': [64, 32, 64],
        'base': [64, 32, 64],
        'large': [64, 32, 64],
    }
    depths = {
        'x_small': [1, 1, 5, 1],
        'small': [3, 4, 10, 3],
        'base': [3, 4, 20, 3],
        'large': [3, 4, 30, 3],
    }

    def __init__(self,
                 arch='small',
                 path_dropout=0.2,
                 attn_drop=0,
                 drop=0,
                 strides=[1, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 head_dim=32,
                 mix_block_ratio=0.75,
                 resume='',
                 with_extra_norm=True,
                 norm_eval=False,
                 norm_cfg=None,
                 out_indices=-1,
                 frozen_stages=-1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        stem_chs = self.stem_chs[arch]
        depths = self.depths[arch]
        self.frozen_stages = frozen_stages
        self.with_extra_norm = with_extra_norm
        self.norm_eval = norm_eval
        self.stage1_out_channels = [96] * (depths[0])
        self.stage2_out_channels = [192] * (depths[1] - 1) + [256]
        self.stage3_out_channels = [384, 384, 384, 384, 512] * (depths[2] // 5)
        self.stage4_out_channels = [768] * (depths[3] - 1) + [1024]
        self.stage_out_channels = [
            self.stage1_out_channels, self.stage2_out_channels,
            self.stage3_out_channels, self.stage4_out_channels
        ]

        # Next Hybrid Strategy
        self.stage1_block_types = [NCB] * depths[0]
        self.stage2_block_types = [NCB] * (depths[1] - 1) + [NTB]
        self.stage3_block_types = [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5)
        self.stage4_block_types = [NCB] * (depths[3] - 1) + [NTB]
        self.stage_block_types = [
            self.stage1_block_types, self.stage2_block_types,
            self.stage3_block_types, self.stage4_block_types
        ]

        self.stem = nn.Sequential(
            ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
            ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
        )
        input_channel = stem_chs[-1]
        features = []
        idx = 0
        dpr = [x.item()
               for x in torch.linspace(0, path_dropout, sum(depths))
               ]  # stochastic depth decay rule
        for stage_id in range(len(depths)):
            numrepeat = depths[stage_id]
            output_channels = self.stage_out_channels[stage_id]
            block_types = self.stage_block_types[stage_id]
            for block_id in range(numrepeat):
                if strides[stage_id] == 2 and block_id == 0:
                    stride = 2
                else:
                    stride = 1
                output_channel = output_channels[block_id]
                block_type = block_types[block_id]
                if block_type is NCB:
                    layer = NCB(
                        input_channel,
                        output_channel,
                        stride=stride,
                        path_dropout=dpr[idx + block_id],
                        drop=drop,
                        head_dim=head_dim)
                    features.append(layer)
                elif block_type is NTB:
                    layer = NTB(
                        input_channel,
                        output_channel,
                        path_dropout=dpr[idx + block_id],
                        stride=stride,
                        sr_ratio=sr_ratios[stage_id],
                        head_dim=head_dim,
                        mix_block_ratio=mix_block_ratio,
                        attn_drop=attn_drop,
                        drop=drop)
                    features.append(layer)
                input_channel = output_channel
            idx += numrepeat
        self.features = nn.Sequential(*features)
        self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = sum(depths) + index
            assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.stage_out_idx = out_indices
        if norm_cfg is not None:
            # convert_sync_batchnorm replaces BN children in place; for a
            # non-BN root module the returned object is self.
            self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def init_weights(self):
        super(NextViT, self).init_weights()

        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return
        self._initialize_weights()

    def _initialize_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, (nn.BatchNorm2d,
                              nn.BatchNorm1d)):  # nn.GroupNorm, nn.LayerNorm,
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        outputs = list()
        x = self.stem(x)
        stage_id = 0
        for idx, layer in enumerate(self.features):
            x = layer(x)
            # guard stage_id so the loop cannot index past stage_out_idx once
            # every requested output has been collected
            if stage_id < len(self.stage_out_idx) \
                    and idx == self.stage_out_idx[stage_id]:
                if self.with_extra_norm:
                    x = self.norm(x)
                outputs.append(x)
                stage_id += 1
        return tuple(outputs)

    def _freeze_stages(self):
        if self.frozen_stages > 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
            for idx, layer in enumerate(self.features):
                if idx <= self.stage_out_idx[self.frozen_stages - 1]:
                    layer.eval()
                    for param in layer.parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        super(NextViT, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
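For a quick sanity check of the wiring above, a minimal smoke-test sketch (illustrative only, not part of this change; assumes torch, mmcv, mmcls and einops are importable):

    import torch

    # 'small' NextViT with the default out_indices=-1: only the last stage is returned.
    model = NextViT(arch='small', path_dropout=0.0)
    model.init_weights()
    model.eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    # Stem (/4) plus three stride-2 stages: 224 -> 7; the final block widens to 1024 channels.
    print(feats[0].shape)  # expected: torch.Size([1, 1024, 7, 7])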
@@ -1,9 +1,10 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 import os

 from modelscope.metainfo import Models
 from modelscope.models.base.base_torch_model import TorchModel
 from modelscope.models.builder import MODELS
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import ModelFile, Tasks


 @MODELS.register_module(
@@ -13,16 +14,25 @@ class ClassificationModel(TorchModel):
     def __init__(self, model_dir: str, **kwargs):
         import mmcv
         from mmcls.models import build_classifier
+        import modelscope.models.cv.image_classification.backbones
+        from modelscope.utils.hub import read_config

         super().__init__(model_dir)
-        config = os.path.join(model_dir, 'config.py')
-        cfg = mmcv.Config.fromfile(config)
-        cfg.model.pretrained = None
-        self.cls_model = build_classifier(cfg.model)
+        self.config_type = 'ms_config'
+        mm_config = os.path.join(model_dir, 'config.py')
+        if os.path.exists(mm_config):
+            cfg = mmcv.Config.fromfile(mm_config)
+            cfg.model.pretrained = None
+            self.cls_model = build_classifier(cfg.model)
+            self.config_type = 'mmcv_config'
+        else:
+            cfg = read_config(model_dir)
+            cfg.model.mm_model.pretrained = None
+            self.cls_model = build_classifier(cfg.model.mm_model)
+            self.config_type = 'ms_config'
         self.cfg = cfg
         self.ms_model_dir = model_dir
         self.load_pretrained_checkpoint()
@@ -33,7 +43,13 @@ class ClassificationModel(TorchModel):
     def load_pretrained_checkpoint(self):
         import mmcv

-        checkpoint_path = os.path.join(self.ms_model_dir, 'checkpoints.pth')
+        if os.path.exists(
+                os.path.join(self.ms_model_dir, ModelFile.TORCH_MODEL_FILE)):
+            checkpoint_path = os.path.join(self.ms_model_dir,
+                                           ModelFile.TORCH_MODEL_FILE)
+        else:
+            checkpoint_path = os.path.join(self.ms_model_dir,
+                                           'checkpoints.pth')
         if os.path.exists(checkpoint_path):
             checkpoint = mmcv.runner.load_checkpoint(
                 self.cls_model, checkpoint_path, map_location='cpu')
@@ -0,0 +1,100 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp

import numpy as np
from mmcls.datasets.base_dataset import BaseDataset
def get_trained_checkpoints_name(work_path):
    import os
    file_list = os.listdir(work_path)
    last = 0
    model_name = None
    # find the best model first
    for f_name in file_list:
        if 'best_' in f_name and f_name.endswith('.pth'):
            best_epoch = f_name.replace('.pth', '').split('_')[-1]
            if best_epoch.isdigit():
                last = int(best_epoch)
                model_name = f_name
    if model_name is not None:
        return model_name

    # otherwise fall back to the latest epoch checkpoint
    for f_name in file_list:
        if 'epoch_' in f_name and f_name.endswith('.pth'):
            epoch_num = f_name.replace('epoch_', '').replace('.pth', '')
            if not epoch_num.isdigit():
                continue
            ind = int(epoch_num)
            if ind > last:
                last = ind
                model_name = f_name
    return model_name
def preprocess_transform(cfgs):
    if cfgs is None:
        return None
    for i, cfg in enumerate(cfgs):
        if cfg.type == 'Resize':
            if isinstance(cfg.size, list):
                cfgs[i].size = tuple(cfg.size)
    return cfgs


def get_ms_dataset_root(ms_dataset):
    if ms_dataset is None or len(ms_dataset) < 1:
        return None
    try:
        data_root = ms_dataset[0]['image:FILE'].split('extracted')[0]
        path_post = ms_dataset[0]['image:FILE'].split('extracted')[1].split(
            '/')
        extracted_data_root = osp.join(data_root, 'extracted', path_post[1],
                                       path_post[2])
        return extracted_data_root
    except Exception as e:
        raise ValueError(f'Dataset Error: {e}')


def get_classes(classes=None):
    import mmcv
    if isinstance(classes, str):
        # take it as a file path
        class_names = mmcv.list_from_file(classes)
    elif isinstance(classes, (tuple, list)):
        class_names = classes
    else:
        raise ValueError(f'Unsupported type {type(classes)} of classes.')
    return class_names
class MmDataset(BaseDataset):

    def __init__(self, ms_dataset, pipeline, classes=None, test_mode=False):
        self.ms_dataset = ms_dataset
        if len(self.ms_dataset) < 1:
            raise ValueError('Dataset Error: dataset is empty')
        super(MmDataset, self).__init__(
            data_prefix='',
            pipeline=pipeline,
            classes=classes,
            test_mode=test_mode)

    def load_annotations(self):
        if self.CLASSES is None:
            raise ValueError(
                f'Dataset Error: Not found classname.txt: {self.CLASSES}')

        data_infos = []
        for data_info in self.ms_dataset:
            filename = data_info['image:FILE']
            gt_label = data_info['category']
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)
        return data_infos
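For reference, a minimal sketch of the record shape MmDataset consumes; the paths and class names here are hypothetical, only the 'image:FILE' and 'category' keys are taken from the code above:

    ms_dataset = [
        {'image:FILE': '/cache/extracted/ab12cd/train/cat_001.jpg', 'category': 0},
        {'image:FILE': '/cache/extracted/ab12cd/train/dog_001.jpg', 'category': 1},
    ]
    # An empty transform list keeps the example self-contained.
    dataset = MmDataset(ms_dataset, pipeline=[], classes=['cat', 'dog'])
    print(dataset.load_annotations()[0]['gt_label'])  # 0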
@@ -45,6 +45,9 @@ class ImageClassificationPipeline(Pipeline):

 @PIPELINES.register_module(
     Tasks.image_classification,
     module_name=Pipelines.daily_image_classification)
+@PIPELINES.register_module(
+    Tasks.image_classification,
+    module_name=Pipelines.nextvit_small_daily_image_classification)
 class GeneralImageClassificationPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
@@ -60,6 +63,7 @@ class GeneralImageClassificationPipeline(Pipeline):
     def preprocess(self, input: Input) -> Dict[str, Any]:
         from mmcls.datasets.pipelines import Compose
         from mmcv.parallel import collate, scatter
+        from modelscope.models.cv.image_classification.utils import preprocess_transform

         if isinstance(input, str):
             img = np.array(load_image(input))
         elif isinstance(input, PIL.Image.Image):
@@ -72,12 +76,20 @@ class GeneralImageClassificationPipeline(Pipeline):
             raise TypeError(f'input should be either str, PIL.Image,'
                             f' np.array, but got {type(input)}')

-        mmcls_cfg = self.model.cfg
-        # build the data pipeline
-        if mmcls_cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
-            mmcls_cfg.data.test.pipeline.pop(0)
-        data = dict(img=img)
-        test_pipeline = Compose(mmcls_cfg.data.test.pipeline)
+        cfg = self.model.cfg
+        if self.model.config_type == 'mmcv_config':
+            if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
+                cfg.data.test.pipeline.pop(0)
+            data = dict(img=img)
+            test_pipeline = Compose(cfg.data.test.pipeline)
+        else:
+            if cfg.preprocessor.val[0]['type'] == 'LoadImageFromFile':
+                cfg.preprocessor.val.pop(0)
+            data = dict(img=img)
+            data_pipeline = preprocess_transform(cfg.preprocessor.val)
+            test_pipeline = Compose(data_pipeline)
         data = test_pipeline(data)
         data = collate([data], samples_per_gpu=1)
         if next(self.model.parameters()).is_cuda:
@@ -289,3 +289,37 @@ class VideoSummarizationPreprocessor(Preprocessor):
             Dict[str, Any]: the preprocessed data
         """
         return data
+
+
+@PREPROCESSORS.register_module(
+    Fields.cv,
+    module_name=Preprocessors.image_classification_bypass_preprocessor)
+class ImageClassificationBypassPreprocessor(Preprocessor):
+
+    def __init__(self, *args, **kwargs):
+        """Image classification bypass preprocessor for the fine-tuning scenario.
+        """
+        super().__init__(*args, **kwargs)
+        self.training = kwargs.pop('training', True)
+        self.preprocessor_train_cfg = kwargs.pop('train', None)
+        self.preprocessor_val_cfg = kwargs.pop('val', None)
+
+    def train(self):
+        self.training = True
+        return
+
+    def eval(self):
+        self.training = False
+        return
+
+    def __call__(self, results: Dict[str, Any]):
+        """Process the raw input data.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            Dict[str, Any] | None: the preprocessed data. This bypass
+            preprocessor intentionally returns None; the image classification
+            trainer builds its own mmcls data pipelines instead.
+        """
+        pass
@@ -0,0 +1,502 @@
# Part of the implementation is borrowed and modified from mmclassification,
# publicly available at https://github.com/open-mmlab/mmclassification
import copy
import os
import os.path as osp
import time
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.base import TorchModel
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.preprocessors.base import Preprocessor
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger
def train_model(model,
                dataset,
                cfg,
                distributed=False,
                val_dataset=None,
                timestamp=None,
                device=None,
                meta=None):
    import torch
    import warnings
    from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                             build_optimizer, build_runner, get_dist_info)
    from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
    from mmcls.datasets import build_dataloader
    from mmcls.utils import (wrap_distributed_model,
                             wrap_non_distributed_model)
    from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

    logger = get_logger()

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    sampler_cfg = cfg.train.get('sampler', None)

    data_loaders = [
        build_dataloader(
            ds,
            cfg.train.dataloader.batch_size_per_gpu,
            cfg.train.dataloader.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            round_up=True,
            seed=cfg.seed,
            sampler_cfg=sampler_cfg) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        if device == 'cpu':
            logger.warning(
                'The argument `device` is deprecated. To use cpu to train, '
                'please refer to https://mmclassification.readthedocs.io/en'
                '/latest/getting_started.html#train-a-model')
            model = model.cpu()
        else:
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
            if not model.device_ids:
                from mmcv import __version__, digit_version
                assert digit_version(__version__) >= (1, 4, 4), \
                    'To train with CPU, please confirm your mmcv version ' \
                    'is not lower than v1.4.4'

    # build runner
    optimizer = build_optimizer(model, cfg.train.optimizer)
    if cfg.train.get('runner') is None:
        cfg.train.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.train.max_epochs
        }
        # logger.warning() does not accept a warning category; use warnings.warn
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)
    runner = build_runner(
        cfg.train.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.train.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.train.optimizer_config:
        optimizer_config = DistOptimizerHook(**cfg.train.optimizer_config)
    else:
        optimizer_config = cfg.train.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.train.lr_config,
        optimizer_config,
        cfg.train.checkpoint_config,
        cfg.train.log_config,
        cfg.train.get('momentum_config', None),
        custom_hooks_config=cfg.train.get('custom_hooks', None))
    if distributed and cfg.train.runner['type'] == 'EpochBasedRunner':
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if val_dataset is not None:
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=cfg.evaluation.dataloader.batch_size_per_gpu,
            workers_per_gpu=cfg.evaluation.dataloader.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            round_up=True)
        eval_cfg = cfg.train.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.train.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # `EvalHook` needs to be executed after `IterTimerHook`.
        # Otherwise, it will cause a bug if use `IterBasedRunner`.
        # Refers to https://github.com/open-mmlab/mmcv/issues/1261
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    if cfg.train.resume_from:
        runner.resume(cfg.train.resume_from, map_location='cpu')
    elif cfg.train.load_from:
        runner.load_checkpoint(cfg.train.load_from)
    cfg.train.workflow = [tuple(flow) for flow in cfg.train.workflow]
    runner.run(data_loaders, cfg.train.workflow)
@TRAINERS.register_module(module_name=Trainers.image_classification)
class ImageClassificationTrainer(BaseTrainer):
    def __init__(
            self,
            model: Optional[Union[TorchModel, nn.Module, str]] = None,
            cfg_file: Optional[str] = None,
            arg_parse_fn: Optional[Callable] = None,
            data_collator: Optional[Union[Callable, Dict[str,
                                                         Callable]]] = None,
            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
            preprocessor: Optional[Union[Preprocessor,
                                         Dict[str, Preprocessor]]] = None,
            optimizers: Tuple[torch.optim.Optimizer,
                              torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                        None),
            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
            seed: int = 0,
            cfg_modify_fn: Optional[Callable] = None,
            **kwargs):
        """High-level finetune api for image classification.

        Args:
            model: model id
            model_revision: model revision, default is DEFAULT_MODEL_REVISION.
            cfg_modify_fn: An input fn which is used to modify the cfg read
                out of the file.
        """
        import torch
        import mmcv
        from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, get_classes
        from mmcls.models import build_classifier
        from mmcv.runner import get_dist_info, init_dist
        from mmcls.apis import set_random_seed
        from mmcls.utils import collect_env
        import modelscope.models.cv.image_classification.backbones

        self._seed = seed
        set_random_seed(self._seed)
        if isinstance(model, str):
            if os.path.exists(model):
                self.model_dir = model if os.path.isdir(
                    model) else os.path.dirname(model)
            else:
                self.model_dir = snapshot_download(
                    model, revision=model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        super().__init__(cfg_file, arg_parse_fn)
        cfg = self.cfg
        if 'work_dir' in kwargs:
            self.work_dir = kwargs['work_dir']
        else:
            self.work_dir = self.cfg.train.get('work_dir', './work_dir')
        mmcv.mkdir_or_exist(osp.abspath(self.work_dir))
        cfg.work_dir = self.work_dir

        # evaluate config setting
        self.eval_checkpoint_path = os.path.join(self.model_dir,
                                                 ModelFile.TORCH_MODEL_FILE)

        # train config setting
        if 'resume_from' in kwargs:
            cfg.train.resume_from = kwargs['resume_from']
        else:
            cfg.train.resume_from = cfg.train.get('resume_from', None)
        if 'load_from' in kwargs:
            cfg.train.load_from = kwargs['load_from']
        else:
            if cfg.train.get('resume_from', None) is None:
                cfg.train.load_from = os.path.join(self.model_dir,
                                                   ModelFile.TORCH_MODEL_FILE)
        if 'device' in kwargs:
            cfg.device = kwargs['device']
        else:
            cfg.device = cfg.get('device', 'cuda')
        if 'gpu_ids' in kwargs:
            cfg.gpu_ids = kwargs['gpu_ids'][0:1]
        else:
            cfg.gpu_ids = [0]
        if 'fp16' in kwargs:
            cfg.fp16 = None if kwargs['fp16'] is None else kwargs['fp16']
        else:
            cfg.fp16 = None
        # no_validate=True will not evaluate checkpoint during training
        cfg.no_validate = kwargs.get('no_validate', False)
        if cfg_modify_fn is not None:
            cfg = cfg_modify_fn(cfg)
        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train,
                'max_epochs'), 'max_epochs is missing in configuration file'
            self.max_epochs = self.cfg.train.max_epochs
        else:
            self.max_epochs = kwargs['max_epochs']
        cfg.train.max_epochs = self.max_epochs
        if cfg.train.get('runner', None) is not None:
            cfg.train.runner.max_epochs = self.max_epochs
        if 'launcher' in kwargs:
            distributed = True
            dist_params = kwargs['dist_params'] \
                if 'dist_params' in kwargs else {'backend': 'nccl'}
            init_dist(kwargs['launcher'], **dist_params)
            # re-set gpu_ids with distributed training mode
            _, world_size = get_dist_info()
            cfg.gpu_ids = list(range(world_size))
        else:
            distributed = False

        # init the logger before other steps
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        log_file = osp.join(self.work_dir, f'{timestamp}.log')
        logger = get_logger(log_file=log_file)

        # init the meta dict to record some important information such as
        # environment info and seed, which will be logged
        meta = dict()
        # log env info
        env_info_dict = collect_env()
        env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
        dash_line = '-' * 60 + '\n'
        logger.info('Environment info:\n' + dash_line + env_info + '\n'
                    + dash_line)
        meta['env_info'] = env_info
        meta['config'] = cfg.pretty_text

        # log some basic info
        logger.info(f'Distributed training: {distributed}')
        logger.info(f'Config:\n{cfg.pretty_text}')

        # set random seeds
        cfg.seed = self._seed
        _deterministic = kwargs.get('deterministic', False)
        logger.info(f'Set random seed to {cfg.seed}, '
                    f'deterministic: {_deterministic}')
        set_random_seed(cfg.seed, deterministic=_deterministic)
        meta['seed'] = cfg.seed
        meta['exp_name'] = osp.basename(cfg_file)

        # dataset
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # model
        model = build_classifier(self.cfg.model.mm_model)
        model.init_weights()

        self.cfg = cfg
        self.device = cfg.device
        self.cfg_file = cfg_file
        self.model = model
        self.distributed = distributed
        self.timestamp = timestamp
        self.meta = meta
        self.logger = logger
    def train(self, *args, **kwargs):
        from mmcls import __version__
        from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, MmDataset, preprocess_transform
        from mmcls.utils import setup_multi_processes

        if self.train_dataset is None:
            raise ValueError(
                "Not found train dataset, please set the 'train_dataset' parameter!"
            )

        self.cfg.model.mm_model.pretrained = None
        # dump config
        self.cfg.dump(osp.join(self.work_dir, osp.basename(self.cfg_file)))

        # build the dataloader
        if self.cfg.dataset.classes is None:
            data_root = get_ms_dataset_root(self.train_dataset)
            classname_path = osp.join(data_root, 'classname.txt')
            classes = classname_path if osp.exists(classname_path) else None
        else:
            classes = self.cfg.dataset.classes
        datasets = [
            MmDataset(
                self.train_dataset,
                pipeline=self.cfg.preprocessor.train,
                classes=classes)
        ]

        if len(self.cfg.train.workflow) == 2:
            if self.eval_dataset is None:
                raise ValueError(
                    "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
                )
            val_data_pipeline = self.cfg.preprocessor.train
            val_dataset = MmDataset(
                self.eval_dataset, pipeline=val_data_pipeline, classes=classes)
            datasets.append(val_dataset)

        # save mmcls version, config file content and class names in
        # checkpoints as meta data
        self.meta.update(
            dict(
                mmcls_version=__version__,
                config=self.cfg.pretty_text,
                CLASSES=datasets[0].CLASSES))

        val_dataset = None
        if not self.cfg.no_validate:
            val_dataset = MmDataset(
                self.eval_dataset,
                pipeline=preprocess_transform(self.cfg.preprocessor.val),
                classes=classes)

        # add an attribute for visualization convenience
        train_model(
            self.model,
            datasets,
            self.cfg,
            distributed=self.distributed,
            val_dataset=val_dataset,
            timestamp=self.timestamp,
            device='cpu' if self.device == 'cpu' else 'cuda',
            meta=self.meta)
    def evaluate(self,
                 checkpoint_path: str = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        import warnings
        import mmcv
        import torch
        from modelscope.models.cv.image_classification.utils import (
            get_ms_dataset_root, MmDataset, preprocess_transform,
            get_trained_checkpoints_name)
        from mmcls.datasets import build_dataloader
        from mmcv.runner import get_dist_info, load_checkpoint, wrap_fp16_model
        from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
        from mmcls.apis import multi_gpu_test, single_gpu_test
        from mmcls.utils import setup_multi_processes

        if self.eval_dataset is None:
            raise ValueError(
                "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
            )

        self.cfg.model.mm_model.pretrained = None

        # build the dataloader
        if self.cfg.dataset.classes is None:
            data_root = get_ms_dataset_root(self.eval_dataset)
            classname_path = osp.join(data_root, 'classname.txt')
            classes = classname_path if osp.exists(classname_path) else None
        else:
            classes = self.cfg.dataset.classes
        dataset = MmDataset(
            self.eval_dataset,
            pipeline=preprocess_transform(self.cfg.preprocessor.val),
            classes=classes)
        # the extra round_up data will be removed during gpu/cpu collect
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu,
            workers_per_gpu=self.cfg.evaluation.dataloader.workers_per_gpu,
            dist=self.distributed,
            shuffle=False,
            round_up=True)

        model = copy.deepcopy(self.model)
        fp16_cfg = self.cfg.get('fp16', None)
        if fp16_cfg is not None:
            wrap_fp16_model(model)
        if checkpoint_path is None:
            trained_checkpoints = get_trained_checkpoints_name(self.work_dir)
            if trained_checkpoints is not None:
                checkpoint = load_checkpoint(
                    model,
                    os.path.join(self.work_dir, trained_checkpoints),
                    map_location='cpu')
            else:
                checkpoint = load_checkpoint(
                    model, self.eval_checkpoint_path, map_location='cpu')
        else:
            checkpoint = load_checkpoint(
                model, checkpoint_path, map_location='cpu')

        if 'CLASSES' in checkpoint.get('meta', {}):
            CLASSES = checkpoint['meta']['CLASSES']
        else:
            from mmcls.datasets import ImageNet
            self.logger.warning(
                'Class names are not saved in the checkpoint\'s '
                'meta data, use imagenet by default.')
            CLASSES = ImageNet.CLASSES

        if not self.distributed:
            if self.device == 'cpu':
                model = model.cpu()
            else:
                model = MMDataParallel(model, device_ids=self.cfg.gpu_ids)
                if not model.device_ids:
                    assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
                        'To test with CPU, please confirm your mmcv version ' \
                        'is not lower than v1.4.4'
            model.CLASSES = CLASSES
            show_kwargs = {}
            outputs = single_gpu_test(model, data_loader, False, None,
                                      **show_kwargs)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False)
            outputs = multi_gpu_test(model, data_loader, None, True)

        rank, _ = get_dist_info()
        if rank == 0:
            results = {}
            logger = get_logger()
            metric_options = self.cfg.evaluation.get('metric_options', {})
            if 'topk' in metric_options.keys():
                metric_options['topk'] = tuple(metric_options['topk'])
            if self.cfg.evaluation.metrics:
                eval_results = dataset.evaluate(
                    results=outputs,
                    metric=self.cfg.evaluation.metrics,
                    metric_options=metric_options,
                    logger=logger)
                results.update(eval_results)
            return results
        return None
@@ -31,6 +31,15 @@ class GeneralImageClassificationTest(unittest.TestCase,
         result = general_image_classification('data/test/images/bird.JPEG')
         print(result)

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_nextvit(self):
+        nextvit_image_classification = pipeline(
+            Tasks.image_classification,
+            model='damo/cv_nextvit-small_image-classification_Dailylife-labels'
+        )
+        result = nextvit_image_classification('data/test/images/bird.JPEG')
+        print(result)
+
 @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_Dailylife_default(self):
         general_image_classification = pipeline(Tasks.image_classification)
@@ -0,0 +1,96 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile
from functools import partial

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level


class TestGeneralImageClassificationTestTrainer(unittest.TestCase):

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        try:
            self.train_dataset = MsDataset.load(
                'cats_and_dogs',
                namespace='tany0699',
                subset_name='default',
                split='train')

            self.eval_dataset = MsDataset.load(
                'cats_and_dogs',
                namespace='tany0699',
                subset_name='default',
                split='validation')
        except Exception as e:
            print(f'Download dataset error: {e}')

        self.max_epochs = 1

        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_nextvit_dailylife_train(self):
        model_id = 'damo/cv_nextvit-small_image-classification_Dailylife-labels'

        def cfg_modify_fn(cfg):
            cfg.train.dataloader.batch_size_per_gpu = 32
            cfg.train.dataloader.workers_per_gpu = 1
            cfg.train.max_epochs = self.max_epochs
            cfg.model.mm_model.head.num_classes = 2
            cfg.train.optimizer.lr = 1e-4
            cfg.train.lr_config.warmup_iters = 1
            cfg.train.evaluation.metric_options = {'topk': (1, )}
            cfg.evaluation.metric_options = {'topk': (1, )}
            return cfg

        kwargs = dict(
            model=model_id,
            work_dir=self.tmp_dir,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            cfg_modify_fn=cfg_modify_fn)

        trainer = build_trainer(
            name=Trainers.image_classification, default_args=kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(self.max_epochs):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_nextvit_dailylife_eval(self):
        model_id = 'damo/cv_nextvit-small_image-classification_Dailylife-labels'

        kwargs = dict(
            model=model_id,
            work_dir=self.tmp_dir,
            train_dataset=None,
            eval_dataset=self.eval_dataset)

        trainer = build_trainer(
            name=Trainers.image_classification, default_args=kwargs)
        result = trainer.evaluate()

        print(result)


if __name__ == '__main__':
    unittest.main()