diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 1fccb46e..7e66f792 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -185,6 +185,7 @@ class Pipelines(object): live_category = 'live-category' general_image_classification = 'vit-base_image-classification_ImageNet-labels' daily_image_classification = 'vit-base_image-classification_Dailylife-labels' + nextvit_small_daily_image_classification = 'nextvit-small_image-classification_Dailylife-labels' image_color_enhance = 'csrnet-image-color-enhance' virtual_try_on = 'virtual-try-on' image_colorization = 'unet-image-colorization' @@ -330,6 +331,7 @@ class Trainers(object): image_inpainting = 'image-inpainting' referring_video_object_segmentation = 'referring-video-object-segmentation' image_classification_team = 'image-classification-team' + image_classification = 'image-classification' # nlp trainers bert_sentiment_analysis = 'bert-sentiment-analysis' @@ -365,6 +367,7 @@ class Preprocessors(object): image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor' video_summarization_preprocessor = 'video-summarization-preprocessor' movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor' + image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' diff --git a/modelscope/models/cv/image_classification/backbones/__init__.py b/modelscope/models/cv/image_classification/backbones/__init__.py new file mode 100644 index 00000000..79a3a4ed --- /dev/null +++ b/modelscope/models/cv/image_classification/backbones/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from .nextvit import NextViT diff --git a/modelscope/models/cv/image_classification/backbones/nextvit.py b/modelscope/models/cv/image_classification/backbones/nextvit.py new file mode 100644 index 00000000..ecf0d15e --- /dev/null +++ b/modelscope/models/cv/image_classification/backbones/nextvit.py @@ -0,0 +1,541 @@ +# Part of the implementation is borrowed and modified from Next-ViT, +# publicly available at https://github.com/bytedance/Next-ViT +import collections.abc +import itertools +import math +import os +import warnings +from functools import partial +from typing import Dict, Sequence + +import torch +import torch.nn as nn +from einops import rearrange +from mmcls.models.backbones.base_backbone import BaseBackbone +from mmcls.models.builder import BACKBONES +from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer +from mmcv.runner import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm + +NORM_EPS = 1e-5 + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. 
+ # Get upper and lower cdf values + ll = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [ll, u], then translate to + # [2ll-1, 2u-1]. + tensor.uniform_(2 * ll - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class ConvBNReLU(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + groups=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=1, + groups=groups, + bias=False) + self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS) + self.act = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + x = self.act(x) + return x + + +def _make_divisible(v, divisor, min_value=None): + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class PatchEmbed(nn.Module): + + def __init__(self, in_channels, out_channels, stride=1): + super(PatchEmbed, self).__init__() + norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS) + if stride == 2: + self.avgpool = nn.AvgPool2d((2, 2), + stride=2, + ceil_mode=True, + count_include_pad=False) + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, bias=False) + self.norm = norm_layer(out_channels) + elif in_channels != out_channels: + self.avgpool = nn.Identity() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, bias=False) + self.norm = norm_layer(out_channels) + else: + self.avgpool = nn.Identity() + self.conv = nn.Identity() + self.norm = nn.Identity() + + def forward(self, x): + return self.norm(self.conv(self.avgpool(x))) + + +class MHCA(nn.Module): + """ + Multi-Head Convolutional Attention + """ + + def __init__(self, out_channels, head_dim): + super(MHCA, self).__init__() + norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS) + self.group_conv3x3 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + groups=out_channels // head_dim, + bias=False) + self.norm = norm_layer(out_channels) + self.act = nn.ReLU(inplace=True) + self.projection = nn.Conv2d( + out_channels, out_channels, kernel_size=1, bias=False) + + def forward(self, x): + out = self.group_conv3x3(x) + out = self.norm(out) + out = self.act(out) + out = self.projection(out) + return out + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + out_features=None, + mlp_ratio=None, + drop=0., + bias=True): + super().__init__() + out_features = out_features or in_features + hidden_dim = _make_divisible(in_features * mlp_ratio, 32) + self.conv1 = nn.Conv2d( + in_features, hidden_dim, kernel_size=1, bias=bias) + self.act = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + hidden_dim, out_features, kernel_size=1, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.conv1(x) + x = self.act(x) + x = self.drop(x) + x = self.conv2(x) + x = self.drop(x) + return x + + +class NCB(nn.Module): + """ + Next 
Convolution Block + """ + + def __init__(self, + in_channels, + out_channels, + stride=1, + path_dropout=0, + drop=0, + head_dim=32, + mlp_ratio=3): + super(NCB, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS) + assert out_channels % head_dim == 0 + + self.patch_embed = PatchEmbed(in_channels, out_channels, stride) + self.mhca = MHCA(out_channels, head_dim) + self.attention_path_dropout = DropPath(path_dropout) + + self.norm = norm_layer(out_channels) + self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True) + self.mlp_path_dropout = DropPath(path_dropout) + self.is_bn_merged = False + + def forward(self, x): + x = self.patch_embed(x) + x = x + self.attention_path_dropout(self.mhca(x)) + if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged: + out = self.norm(x) + else: + out = x + x = x + self.mlp_path_dropout(self.mlp(out)) + return x + + +class E_MHSA(nn.Module): + """ + Efficient Multi-Head Self Attention + """ + + def __init__(self, + dim, + out_dim=None, + head_dim=32, + qkv_bias=True, + qk_scale=None, + attn_drop=0, + proj_drop=0., + sr_ratio=1): + super().__init__() + self.dim = dim + self.out_dim = out_dim if out_dim is not None else dim + self.num_heads = self.dim // head_dim + self.scale = qk_scale or head_dim**-0.5 + self.q = nn.Linear(dim, self.dim, bias=qkv_bias) + self.k = nn.Linear(dim, self.dim, bias=qkv_bias) + self.v = nn.Linear(dim, self.dim, bias=qkv_bias) + self.proj = nn.Linear(self.dim, self.out_dim) + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + self.N_ratio = sr_ratio**2 + if sr_ratio > 1: + self.sr = nn.AvgPool1d( + kernel_size=self.N_ratio, stride=self.N_ratio) + self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS) + self.is_bn_merge = False + + def forward(self, x): + B, N, C = x.shape + q = self.q(x) + q = q.reshape(B, N, self.num_heads, + int(C // self.num_heads)).permute(0, 2, 1, 3) + + if self.sr_ratio > 1: + x_ = x.transpose(1, 2) + x_ = self.sr(x_) + if not torch.onnx.is_in_onnx_export() and not self.is_bn_merge: + x_ = self.norm(x_) + x_ = x_.transpose(1, 2) + k = self.k(x_) + k = k.reshape(B, -1, self.num_heads, + int(C // self.num_heads)).permute(0, 2, 3, 1) + v = self.v(x_) + v = v.reshape(B, -1, self.num_heads, + int(C // self.num_heads)).permute(0, 2, 1, 3) + else: + k = self.k(x) + k = k.reshape(B, -1, self.num_heads, + int(C // self.num_heads)).permute(0, 2, 3, 1) + v = self.v(x) + v = v.reshape(B, -1, self.num_heads, + int(C // self.num_heads)).permute(0, 2, 1, 3) + attn = (q @ k) * self.scale + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class NTB(nn.Module): + """ + Next Transformer Block + """ + + def __init__( + self, + in_channels, + out_channels, + path_dropout, + stride=1, + sr_ratio=1, + mlp_ratio=2, + head_dim=32, + mix_block_ratio=0.75, + attn_drop=0, + drop=0, + ): + super(NTB, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.mix_block_ratio = mix_block_ratio + norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS) + + self.mhsa_out_channels = _make_divisible( + int(out_channels * mix_block_ratio), 32) + self.mhca_out_channels = out_channels - self.mhsa_out_channels + + self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels, + stride) + self.norm1 = norm_func(self.mhsa_out_channels) + 
self.e_mhsa = E_MHSA( + self.mhsa_out_channels, + head_dim=head_dim, + sr_ratio=sr_ratio, + attn_drop=attn_drop, + proj_drop=drop) + self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio) + + self.projection = PatchEmbed( + self.mhsa_out_channels, self.mhca_out_channels, stride=1) + self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim) + self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio)) + + self.norm2 = norm_func(out_channels) + self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop) + self.mlp_path_dropout = DropPath(path_dropout) + + self.is_bn_merged = False + + def forward(self, x): + x = self.patch_embed(x) + B, C, H, W = x.shape + if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged: + out = self.norm1(x) + else: + out = x + out = rearrange(out, 'b c h w -> b (h w) c') # b n c + out = self.mhsa_path_dropout(self.e_mhsa(out)) + x = x + rearrange(out, 'b (h w) c -> b c h w', h=H) + + out = self.projection(x) + out = out + self.mhca_path_dropout(self.mhca(out)) + x = torch.cat([x, out], dim=1) + + if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged: + out = self.norm2(x) + else: + out = x + x = x + self.mlp_path_dropout(self.mlp(out)) + return x + + +@BACKBONES.register_module() +class NextViT(BaseBackbone): + stem_chs = { + 'x_small': [64, 32, 64], + 'small': [64, 32, 64], + 'base': [64, 32, 64], + 'large': [64, 32, 64], + } + depths = { + 'x_small': [1, 1, 5, 1], + 'small': [3, 4, 10, 3], + 'base': [3, 4, 20, 3], + 'large': [3, 4, 30, 3], + } + + def __init__(self, + arch='small', + path_dropout=0.2, + attn_drop=0, + drop=0, + strides=[1, 2, 2, 2], + sr_ratios=[8, 4, 2, 1], + head_dim=32, + mix_block_ratio=0.75, + resume='', + with_extra_norm=True, + norm_eval=False, + norm_cfg=None, + out_indices=-1, + frozen_stages=-1, + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + stem_chs = self.stem_chs[arch] + depths = self.depths[arch] + + self.frozen_stages = frozen_stages + self.with_extra_norm = with_extra_norm + self.norm_eval = norm_eval + self.stage1_out_channels = [96] * (depths[0]) + self.stage2_out_channels = [192] * (depths[1] - 1) + [256] + self.stage3_out_channels = [384, 384, 384, 384, 512] * (depths[2] // 5) + self.stage4_out_channels = [768] * (depths[3] - 1) + [1024] + self.stage_out_channels = [ + self.stage1_out_channels, self.stage2_out_channels, + self.stage3_out_channels, self.stage4_out_channels + ] + + # Next Hybrid Strategy + self.stage1_block_types = [NCB] * depths[0] + self.stage2_block_types = [NCB] * (depths[1] - 1) + [NTB] + self.stage3_block_types = [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5) + self.stage4_block_types = [NCB] * (depths[3] - 1) + [NTB] + self.stage_block_types = [ + self.stage1_block_types, self.stage2_block_types, + self.stage3_block_types, self.stage4_block_types + ] + + self.stem = nn.Sequential( + ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2), + ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1), + ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1), + ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2), + ) + input_channel = stem_chs[-1] + features = [] + idx = 0 + dpr = [x.item() for x in torch.linspace(0, path_dropout, sum(depths)) + ] # stochastic depth decay rule + for stage_id in range(len(depths)): + numrepeat = depths[stage_id] + output_channels = self.stage_out_channels[stage_id] + block_types = self.stage_block_types[stage_id] + for block_id in range(numrepeat): + if strides[stage_id] == 2 and block_id == 0: + stride = 
2 + else: + stride = 1 + output_channel = output_channels[block_id] + block_type = block_types[block_id] + if block_type is NCB: + layer = NCB( + input_channel, + output_channel, + stride=stride, + path_dropout=dpr[idx + block_id], + drop=drop, + head_dim=head_dim) + features.append(layer) + elif block_type is NTB: + layer = NTB( + input_channel, + output_channel, + path_dropout=dpr[idx + block_id], + stride=stride, + sr_ratio=sr_ratios[stage_id], + head_dim=head_dim, + mix_block_ratio=mix_block_ratio, + attn_drop=attn_drop, + drop=drop) + features.append(layer) + input_channel = output_channel + idx += numrepeat + self.features = nn.Sequential(*features) + self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS) + + if isinstance(out_indices, int): + out_indices = [out_indices] + assert isinstance(out_indices, Sequence), \ + f'"out_indices" must by a sequence or int, ' \ + f'get {type(out_indices)} instead.' + for i, index in enumerate(out_indices): + if index < 0: + out_indices[i] = sum(depths) + index + assert out_indices[i] >= 0, f'Invalid out_indices {index}' + self.stage_out_idx = out_indices + + if norm_cfg is not None: + self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self) + + def init_weights(self): + super(NextViT, self).init_weights() + if (isinstance(self.init_cfg, dict) + and self.init_cfg['type'] == 'Pretrained'): + # Suppress default init if use pretrained model. + return + + self._initialize_weights() + + def _initialize_weights(self): + for n, m in self.named_modules(): + if isinstance(m, (nn.BatchNorm2d, + nn.BatchNorm1d)): # nn.GroupNorm, nn.LayerNorm, + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=.02) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + outputs = list() + x = self.stem(x) + stage_id = 0 + for idx, layer in enumerate(self.features): + x = layer(x) + if idx == self.stage_out_idx[stage_id]: + if self.with_extra_norm: + x = self.norm(x) + outputs.append(x) + stage_id += 1 + return tuple(outputs) + + def _freeze_stages(self): + if self.frozen_stages > 0: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + for idx, layer in enumerate(self.features): + if idx <= self.stage_out_idx[self.frozen_stages - 1]: + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(NextViT, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() diff --git a/modelscope/models/cv/image_classification/mmcls_model.py b/modelscope/models/cv/image_classification/mmcls_model.py index a6789d0b..bd37d3de 100644 --- a/modelscope/models/cv/image_classification/mmcls_model.py +++ b/modelscope/models/cv/image_classification/mmcls_model.py @@ -1,9 +1,10 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
import os from modelscope.metainfo import Models from modelscope.models.base.base_torch_model import TorchModel from modelscope.models.builder import MODELS -from modelscope.utils.constant import Tasks +from modelscope.utils.constant import ModelFile, Tasks @MODELS.register_module( @@ -13,16 +14,25 @@ class ClassificationModel(TorchModel): def __init__(self, model_dir: str, **kwargs): import mmcv from mmcls.models import build_classifier + import modelscope.models.cv.image_classification.backbones + from modelscope.utils.hub import read_config super().__init__(model_dir) - config = os.path.join(model_dir, 'config.py') - - cfg = mmcv.Config.fromfile(config) - cfg.model.pretrained = None - self.cls_model = build_classifier(cfg.model) - + self.config_type = 'ms_config' + mm_config = os.path.join(model_dir, 'config.py') + if os.path.exists(mm_config): + cfg = mmcv.Config.fromfile(mm_config) + cfg.model.pretrained = None + self.cls_model = build_classifier(cfg.model) + self.config_type = 'mmcv_config' + else: + cfg = read_config(model_dir) + cfg.model.mm_model.pretrained = None + self.cls_model = build_classifier(cfg.model.mm_model) + self.config_type = 'ms_config' self.cfg = cfg + self.ms_model_dir = model_dir self.load_pretrained_checkpoint() @@ -33,7 +43,13 @@ class ClassificationModel(TorchModel): def load_pretrained_checkpoint(self): import mmcv - checkpoint_path = os.path.join(self.ms_model_dir, 'checkpoints.pth') + if os.path.exists( + os.path.join(self.ms_model_dir, ModelFile.TORCH_MODEL_FILE)): + checkpoint_path = os.path.join(self.ms_model_dir, + ModelFile.TORCH_MODEL_FILE) + else: + checkpoint_path = os.path.join(self.ms_model_dir, + 'checkpoints.pth') if os.path.exists(checkpoint_path): checkpoint = mmcv.runner.load_checkpoint( self.cls_model, checkpoint_path, map_location='cpu') diff --git a/modelscope/models/cv/image_classification/utils.py b/modelscope/models/cv/image_classification/utils.py new file mode 100644 index 00000000..32777b9b --- /dev/null +++ b/modelscope/models/cv/image_classification/utils.py @@ -0,0 +1,100 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os.path as osp
+
+import numpy as np
+from mmcls.datasets.base_dataset import BaseDataset
+
+
+def get_trained_checkpoints_name(work_path):
+    import os
+    file_list = os.listdir(work_path)
+    last = 0
+    model_name = None
+    # find the best model
+    if model_name is None:
+        for f_name in file_list:
+            if 'best_' in f_name and f_name.endswith('.pth'):
+                best_epoch = f_name.replace('.pth', '').split('_')[-1]
+                if best_epoch.isdigit():
+                    last = int(best_epoch)
+                    model_name = f_name
+    # or fall back to the latest epoch checkpoint
+    if model_name is None:
+        for f_name in file_list:
+            if 'epoch_' in f_name and f_name.endswith('.pth'):
+                epoch_num = f_name.replace('epoch_', '').replace('.pth', '')
+                if not epoch_num.isdigit():
+                    continue
+                ind = int(epoch_num)
+                if ind > last:
+                    last = ind
+                    model_name = f_name
+    return model_name
+
+
+def preprocess_transform(cfgs):
+    if cfgs is None:
+        return None
+    for i, cfg in enumerate(cfgs):
+        if cfg.type == 'Resize':
+            if isinstance(cfg.size, list):
+                cfgs[i].size = tuple(cfg.size)
+    return cfgs
+
+
+def get_ms_dataset_root(ms_dataset):
+    if ms_dataset is None or len(ms_dataset) < 1:
+        return None
+    try:
+        data_root = ms_dataset[0]['image:FILE'].split('extracted')[0]
+        path_post = ms_dataset[0]['image:FILE'].split('extracted')[1].split(
+            '/')
+        extracted_data_root = osp.join(data_root, 'extracted', path_post[1],
+                                       path_post[2])
+        return extracted_data_root
+    except Exception as e:
+        raise ValueError(f'Dataset Error: {e}')
+    return None
+
+
+def get_classes(classes=None):
+    import mmcv
+    if isinstance(classes, str):
+        # take it as a file path
+        class_names = mmcv.list_from_file(classes)
+    elif isinstance(classes, (tuple, list)):
+        class_names = classes
+    else:
+        raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+    return class_names
+
+
+class MmDataset(BaseDataset):
+
+    def __init__(self, ms_dataset, pipeline, classes=None, test_mode=False):
+        self.ms_dataset = ms_dataset
+        if len(self.ms_dataset) < 1:
+            raise ValueError('Dataset Error: dataset is empty')
+        super(MmDataset, self).__init__(
+            data_prefix='',
+            pipeline=pipeline,
+            classes=classes,
+            test_mode=test_mode)
+
+    def load_annotations(self):
+        if self.CLASSES is None:
+            raise ValueError(
+                f'Dataset Error: classname.txt not found: {self.CLASSES}')
+
+        data_infos = []
+        for data_info in self.ms_dataset:
+            filename = data_info['image:FILE']
+            gt_label = data_info['category']
+            info = {'img_prefix': self.data_prefix}
+            info['img_info'] = {'filename': filename}
+            info['gt_label'] = np.array(gt_label, dtype=np.int64)
+            data_infos.append(info)
+
+        return data_infos
diff --git a/modelscope/pipelines/cv/image_classification_pipeline.py b/modelscope/pipelines/cv/image_classification_pipeline.py
index 8d4f7694..b9d7376b 100644
--- a/modelscope/pipelines/cv/image_classification_pipeline.py
+++ b/modelscope/pipelines/cv/image_classification_pipeline.py
@@ -45,6 +45,9 @@ class ImageClassificationPipeline(Pipeline):
 @PIPELINES.register_module(
     Tasks.image_classification,
     module_name=Pipelines.daily_image_classification)
+@PIPELINES.register_module(
+    Tasks.image_classification,
+    module_name=Pipelines.nextvit_small_daily_image_classification)
 class GeneralImageClassificationPipeline(Pipeline):
 
     def __init__(self, model: str, **kwargs):
@@ -60,6 +63,7 @@ class GeneralImageClassificationPipeline(Pipeline):
     def preprocess(self, input: Input) -> Dict[str, Any]:
         from mmcls.datasets.pipelines import Compose
         from mmcv.parallel import collate, scatter
+        from 
modelscope.models.cv.image_classification.utils import preprocess_transform if isinstance(input, str): img = np.array(load_image(input)) elif isinstance(input, PIL.Image.Image): @@ -72,12 +76,20 @@ class GeneralImageClassificationPipeline(Pipeline): raise TypeError(f'input should be either str, PIL.Image,' f' np.array, but got {type(input)}') - mmcls_cfg = self.model.cfg - # build the data pipeline - if mmcls_cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile': - mmcls_cfg.data.test.pipeline.pop(0) - data = dict(img=img) - test_pipeline = Compose(mmcls_cfg.data.test.pipeline) + cfg = self.model.cfg + + if self.model.config_type == 'mmcv_config': + if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile': + cfg.data.test.pipeline.pop(0) + data = dict(img=img) + test_pipeline = Compose(cfg.data.test.pipeline) + else: + if cfg.preprocessor.val[0]['type'] == 'LoadImageFromFile': + cfg.preprocessor.val.pop(0) + data = dict(img=img) + data_pipeline = preprocess_transform(cfg.preprocessor.val) + test_pipeline = Compose(data_pipeline) + data = test_pipeline(data) data = collate([data], samples_per_gpu=1) if next(self.model.parameters()).is_cuda: diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py index 60f6e0eb..f0401f16 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -289,3 +289,37 @@ class VideoSummarizationPreprocessor(Preprocessor): Dict[str, Any]: the preprocessed data """ return data + + +@PREPROCESSORS.register_module( + Fields.cv, + module_name=Preprocessors.image_classification_bypass_preprocessor) +class ImageClassificationBypassPreprocessor(Preprocessor): + + def __init__(self, *args, **kwargs): + """image classification bypass preprocessor in the fine-tune scenario + """ + super().__init__(*args, **kwargs) + + self.training = kwargs.pop('training', True) + self.preprocessor_train_cfg = kwargs.pop('train', None) + self.preprocessor_val_cfg = kwargs.pop('val', None) + + def train(self): + self.training = True + return + + def eval(self): + self.training = False + return + + def __call__(self, results: Dict[str, Any]): + """process the raw input data + + Args: + results (dict): Result dict from loading pipeline. 
+ + Returns: + Dict[str, Any] | None: the preprocessed data + """ + pass diff --git a/modelscope/trainers/cv/image_classifition_trainer.py b/modelscope/trainers/cv/image_classifition_trainer.py new file mode 100644 index 00000000..21e98910 --- /dev/null +++ b/modelscope/trainers/cv/image_classifition_trainer.py @@ -0,0 +1,502 @@ +# Part of the implementation is borrowed and modified from mmclassification, +# publicly available at https://github.com/open-mmlab/mmclassification +import copy +import os +import os.path as osp +import time +from typing import Callable, Dict, Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torch.utils.data import Dataset + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.models.base import TorchModel +from modelscope.msdatasets.ms_dataset import MsDataset +from modelscope.preprocessors.base import Preprocessor +from modelscope.trainers.base import BaseTrainer +from modelscope.trainers.builder import TRAINERS +from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile +from modelscope.utils.logger import get_logger + + +def train_model(model, + dataset, + cfg, + distributed=False, + val_dataset=None, + timestamp=None, + device=None, + meta=None): + import torch + import warnings + from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook, + build_optimizer, build_runner, get_dist_info) + from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook + from mmcls.datasets import build_dataloader + from mmcls.utils import (wrap_distributed_model, + wrap_non_distributed_model) + from mmcv.parallel import MMDataParallel, MMDistributedDataParallel + + logger = get_logger() + + # prepare data loaders + dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] + sampler_cfg = cfg.train.get('sampler', None) + + data_loaders = [ + build_dataloader( + ds, + cfg.train.dataloader.batch_size_per_gpu, + cfg.train.dataloader.workers_per_gpu, + # cfg.gpus will be ignored if distributed + num_gpus=len(cfg.gpu_ids), + dist=distributed, + round_up=True, + seed=cfg.seed, + sampler_cfg=sampler_cfg) for ds in dataset + ] + + # put model on gpus + if distributed: + find_unused_parameters = cfg.get('find_unused_parameters', False) + # Sets the `find_unused_parameters` parameter in + # torch.nn.parallel.DistributedDataParallel + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False, + find_unused_parameters=find_unused_parameters) + else: + if device == 'cpu': + logger.warning( + 'The argument `device` is deprecated. 
To use cpu to train, ' + 'please refers to https://mmclassification.readthedocs.io/en' + '/latest/getting_started.html#train-a-model') + model = model.cpu() + else: + model = MMDataParallel(model, device_ids=cfg.gpu_ids) + if not model.device_ids: + from mmcv import __version__, digit_version + assert digit_version(__version__) >= (1, 4, 4), \ + 'To train with CPU, please confirm your mmcv version ' \ + 'is not lower than v1.4.4' + + # build runner + optimizer = build_optimizer(model, cfg.train.optimizer) + + if cfg.train.get('runner') is None: + cfg.train.runner = { + 'type': 'EpochBasedRunner', + 'max_epochs': cfg.train.max_epochs + } + logger.warning( + 'config is now expected to have a `runner` section, ' + 'please set `runner` in your config.', UserWarning) + + runner = build_runner( + cfg.train.runner, + default_args=dict( + model=model, + batch_processor=None, + optimizer=optimizer, + work_dir=cfg.work_dir, + logger=logger, + meta=meta)) + + # an ugly walkaround to make the .log and .log.json filenames the same + runner.timestamp = timestamp + + # fp16 setting + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + optimizer_config = Fp16OptimizerHook( + **cfg.train.optimizer_config, **fp16_cfg, distributed=distributed) + elif distributed and 'type' not in cfg.train.optimizer_config: + optimizer_config = DistOptimizerHook(**cfg.train.optimizer_config) + else: + optimizer_config = cfg.train.optimizer_config + + # register hooks + runner.register_training_hooks( + cfg.train.lr_config, + optimizer_config, + cfg.train.checkpoint_config, + cfg.train.log_config, + cfg.train.get('momentum_config', None), + custom_hooks_config=cfg.train.get('custom_hooks', None)) + if distributed and cfg.train.runner['type'] == 'EpochBasedRunner': + runner.register_hook(DistSamplerSeedHook()) + + # register eval hooks + if val_dataset is not None: + val_dataloader = build_dataloader( + val_dataset, + samples_per_gpu=cfg.evaluation.dataloader.batch_size_per_gpu, + workers_per_gpu=cfg.evaluation.dataloader.workers_per_gpu, + dist=distributed, + shuffle=False, + round_up=True) + eval_cfg = cfg.train.get('evaluation', {}) + eval_cfg['by_epoch'] = cfg.train.runner['type'] != 'IterBasedRunner' + eval_hook = DistEvalHook if distributed else EvalHook + # `EvalHook` needs to be executed after `IterTimerHook`. + # Otherwise, it will cause a bug if use `IterBasedRunner`. 
+ # Refers to https://github.com/open-mmlab/mmcv/issues/1261 + runner.register_hook( + eval_hook(val_dataloader, **eval_cfg), priority='LOW') + + if cfg.train.resume_from: + runner.resume(cfg.train.resume_from, map_location='cpu') + elif cfg.train.load_from: + runner.load_checkpoint(cfg.train.load_from) + + cfg.train.workflow = [tuple(flow) for flow in cfg.train.workflow] + runner.run(data_loaders, cfg.train.workflow) + + +@TRAINERS.register_module(module_name=Trainers.image_classification) +class ImageClassifitionTrainer(BaseTrainer): + + def __init__( + self, + model: Optional[Union[TorchModel, nn.Module, str]] = None, + cfg_file: Optional[str] = None, + arg_parse_fn: Optional[Callable] = None, + data_collator: Optional[Union[Callable, Dict[str, + Callable]]] = None, + train_dataset: Optional[Union[MsDataset, Dataset]] = None, + eval_dataset: Optional[Union[MsDataset, Dataset]] = None, + preprocessor: Optional[Union[Preprocessor, + Dict[str, Preprocessor]]] = None, + optimizers: Tuple[torch.optim.Optimizer, + torch.optim.lr_scheduler._LRScheduler] = (None, + None), + model_revision: Optional[str] = DEFAULT_MODEL_REVISION, + seed: int = 0, + cfg_modify_fn: Optional[Callable] = None, + **kwargs): + """ High-level finetune api for Image Classifition. + + Args: + model: model id + model_version: model version, default is None. + cfg_modify_fn: An input fn which is used to modify the cfg read out of the file. + """ + import torch + import mmcv + from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, get_classes + from mmcls.models import build_classifier + from mmcv.runner import get_dist_info, init_dist + from mmcls.apis import set_random_seed + from mmcls.utils import collect_env + import modelscope.models.cv.image_classification.backbones + + self._seed = seed + set_random_seed(self._seed) + if isinstance(model, str): + if os.path.exists(model): + self.model_dir = model if os.path.isdir( + model) else os.path.dirname(model) + else: + self.model_dir = snapshot_download( + model, revision=model_revision) + if cfg_file is None: + cfg_file = os.path.join(self.model_dir, + ModelFile.CONFIGURATION) + else: + assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!' 
+ self.model_dir = os.path.dirname(cfg_file) + + super().__init__(cfg_file, arg_parse_fn) + cfg = self.cfg + + if 'work_dir' in kwargs: + self.work_dir = kwargs['work_dir'] + else: + self.work_dir = self.cfg.train.get('work_dir', './work_dir') + mmcv.mkdir_or_exist(osp.abspath(self.work_dir)) + cfg.work_dir = self.work_dir + + # evaluate config seting + self.eval_checkpoint_path = os.path.join(self.model_dir, + ModelFile.TORCH_MODEL_FILE) + + # train config seting + if 'resume_from' in kwargs: + cfg.train.resume_from = kwargs['resume_from'] + else: + cfg.train.resume_from = cfg.train.get('resume_from', None) + + if 'load_from' in kwargs: + cfg.train.load_from = kwargs['load_from'] + else: + if cfg.train.get('resume_from', None) is None: + cfg.train.load_from = os.path.join(self.model_dir, + ModelFile.TORCH_MODEL_FILE) + + if 'device' in kwargs: + cfg.device = kwargs['device'] + else: + cfg.device = cfg.get('device', 'cuda') + + if 'gpu_ids' in kwargs: + cfg.gpu_ids = kwargs['gpu_ids'][0:1] + else: + cfg.gpu_ids = [0] + + if 'fp16' in kwargs: + cfg.fp16 = None if kwargs['fp16'] is None else kwargs['fp16'] + else: + cfg.fp16 = None + + # no_validate=True will not evaluate checkpoint during training + cfg.no_validate = kwargs.get('no_validate', False) + + if cfg_modify_fn is not None: + cfg = cfg_modify_fn(cfg) + + if 'max_epochs' not in kwargs: + assert hasattr( + self.cfg.train, + 'max_epochs'), 'max_epochs is missing in configuration file' + self.max_epochs = self.cfg.train.max_epochs + else: + self.max_epochs = kwargs['max_epochs'] + cfg.train.max_epochs = self.max_epochs + if cfg.train.get('runner', None) is not None: + cfg.train.runner.max_epochs = self.max_epochs + + if 'launcher' in kwargs: + distributed = True + dist_params = kwargs['dist_params'] \ + if 'dist_params' in kwargs else {'backend': 'nccl'} + init_dist(kwargs['launcher'], **dist_params) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = list(range(world_size)) + else: + distributed = False + + # init the logger before other steps + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + log_file = osp.join(self.work_dir, f'{timestamp}.log') + logger = get_logger(log_file=log_file) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) + dash_line = '-' * 60 + '\n' + logger.info('Environment info:\n' + dash_line + env_info + '\n' + + dash_line) + meta['env_info'] = env_info + meta['config'] = cfg.pretty_text + # log some basic info + logger.info(f'Distributed training: {distributed}') + logger.info(f'Config:\n{cfg.pretty_text}') + + # set random seeds + cfg.seed = self._seed + _deterministic = kwargs.get('deterministic', False) + logger.info(f'Set random seed to {cfg.seed}, ' + f'deterministic: {_deterministic}') + set_random_seed(cfg.seed, deterministic=_deterministic) + + meta['seed'] = cfg.seed + meta['exp_name'] = osp.basename(cfg_file) + + # dataset + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + + # model + model = build_classifier(self.cfg.model.mm_model) + model.init_weights() + + self.cfg = cfg + self.device = cfg.device + self.cfg_file = cfg_file + self.model = model + self.distributed = distributed + self.timestamp = timestamp + self.meta = meta + self.logger = logger + + def train(self, *args, **kwargs): + from mmcls import 
__version__
+        from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, MmDataset, preprocess_transform
+        from mmcls.utils import setup_multi_processes
+
+        if self.train_dataset is None:
+            raise ValueError(
+                "Not found train dataset, please set the 'train_dataset' parameter!"
+            )
+
+        self.cfg.model.mm_model.pretrained = None
+
+        # dump config
+        self.cfg.dump(osp.join(self.work_dir, osp.basename(self.cfg_file)))
+
+        # build the dataloader
+        if self.cfg.dataset.classes is None:
+            data_root = get_ms_dataset_root(self.train_dataset)
+            classname_path = osp.join(data_root, 'classname.txt')
+            classes = classname_path if osp.exists(classname_path) else None
+        else:
+            classes = self.cfg.dataset.classes
+
+        datasets = [
+            MmDataset(
+                self.train_dataset,
+                pipeline=self.cfg.preprocessor.train,
+                classes=classes)
+        ]
+
+        if len(self.cfg.train.workflow) == 2:
+            if self.eval_dataset is None:
+                raise ValueError(
+                    "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
+                )
+            val_data_pipeline = self.cfg.preprocessor.train
+            val_dataset = MmDataset(
+                self.eval_dataset, pipeline=val_data_pipeline, classes=classes)
+            datasets.append(val_dataset)
+
+        # save mmcls version, config file content and class names in
+        # checkpoints as meta data
+        self.meta.update(
+            dict(
+                mmcls_version=__version__,
+                config=self.cfg.pretty_text,
+                CLASSES=datasets[0].CLASSES))
+
+        val_dataset = None
+        if not self.cfg.no_validate:
+            val_dataset = MmDataset(
+                self.eval_dataset,
+                pipeline=preprocess_transform(self.cfg.preprocessor.val),
+                classes=classes)
+
+        # start training
+        train_model(
+            self.model,
+            datasets,
+            self.cfg,
+            distributed=self.distributed,
+            val_dataset=val_dataset,
+            timestamp=self.timestamp,
+            device='cpu' if self.device == 'cpu' else 'cuda',
+            meta=self.meta)
+
+    def evaluate(self,
+                 checkpoint_path: str = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        import warnings
+        import torch
+        from modelscope.models.cv.image_classification.utils import (
+            get_ms_dataset_root, MmDataset, preprocess_transform,
+            get_trained_checkpoints_name)
+        from mmcls.datasets import build_dataloader
+        from mmcv.runner import get_dist_info, load_checkpoint, wrap_fp16_model
+        from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+        from mmcls.apis import multi_gpu_test, single_gpu_test
+        from mmcls.utils import setup_multi_processes
+
+        if self.eval_dataset is None:
+            raise ValueError(
+                "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
+ ) + + self.cfg.model.mm_model.pretrained = None + + # build the dataloader + if self.cfg.dataset.classes is None: + data_root = get_ms_dataset_root(self.eval_dataset) + classname_path = osp.join(data_root, 'classname.txt') + classes = classname_path if osp.exists(classname_path) else None + else: + classes = cfg.dataset.classes + dataset = MmDataset( + self.eval_dataset, + pipeline=preprocess_transform(self.cfg.preprocessor.val), + classes=classes) + # the extra round_up data will be removed during gpu/cpu collect + data_loader = build_dataloader( + dataset, + samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu, + workers_per_gpu=self.cfg.evaluation.dataloader.workers_per_gpu, + dist=self.distributed, + shuffle=False, + round_up=True) + + model = copy.deepcopy(self.model) + fp16_cfg = self.cfg.get('fp16', None) + if fp16_cfg is not None: + wrap_fp16_model(model) + if checkpoint_path is None: + trained_checkpoints = get_trained_checkpoints_name(self.work_dir) + if trained_checkpoints is not None: + checkpoint = load_checkpoint( + model, + os.path.join(self.work_dir, trained_checkpoints), + map_location='cpu') + else: + checkpoint = load_checkpoint( + model, self.eval_checkpoint_path, map_location='cpu') + else: + checkpoint = load_checkpoint( + model, checkpoint_path, map_location='cpu') + + if 'CLASSES' in checkpoint.get('meta', {}): + CLASSES = checkpoint['meta']['CLASSES'] + else: + from mmcls.datasets import ImageNet + self.logger.warning( + 'Class names are not saved in the checkpoint\'s ' + 'meta data, use imagenet by default.') + CLASSES = ImageNet.CLASSES + + if not self.distributed: + if self.device == 'cpu': + model = model.cpu() + else: + model = MMDataParallel(model, device_ids=self.cfg.gpu_ids) + if not model.device_ids: + assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \ + 'To test with CPU, please confirm your mmcv version ' \ + 'is not lower than v1.4.4' + model.CLASSES = CLASSES + show_kwargs = {} + outputs = single_gpu_test(model, data_loader, False, None, + **show_kwargs) + else: + model = MMDistributedDataParallel( + model.cuda(), + device_ids=[torch.cuda.current_device()], + broadcast_buffers=False) + outputs = multi_gpu_test(model, data_loader, None, True) + + rank, _ = get_dist_info() + if rank == 0: + results = {} + logger = get_logger() + metric_options = self.cfg.evaluation.get('metric_options', {}) + if 'topk' in metric_options.keys(): + metric_options['topk'] = tuple(metric_options['topk']) + if self.cfg.evaluation.metrics: + eval_results = dataset.evaluate( + results=outputs, + metric=self.cfg.evaluation.metrics, + metric_options=metric_options, + logger=logger) + results.update(eval_results) + + return results + + return None diff --git a/tests/pipelines/test_general_image_classification.py b/tests/pipelines/test_general_image_classification.py index d5357f02..7798c399 100644 --- a/tests/pipelines/test_general_image_classification.py +++ b/tests/pipelines/test_general_image_classification.py @@ -31,6 +31,15 @@ class GeneralImageClassificationTest(unittest.TestCase, result = general_image_classification('data/test/images/bird.JPEG') print(result) + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_nextvit(self): + nexit_image_classification = pipeline( + Tasks.image_classification, + model='damo/cv_nextvit-small_image-classification_Dailylife-labels' + ) + result = nexit_image_classification('data/test/images/bird.JPEG') + print(result) + @unittest.skipUnless(test_level() >= 2, 'skip test in 
current test level') def test_run_Dailylife_default(self): general_image_classification = pipeline(Tasks.image_classification) diff --git a/tests/trainers/test_general_image_classification_trainer.py b/tests/trainers/test_general_image_classification_trainer.py new file mode 100644 index 00000000..e91bde18 --- /dev/null +++ b/tests/trainers/test_general_image_classification_trainer.py @@ -0,0 +1,96 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import shutil +import tempfile +import unittest +import zipfile +from functools import partial + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.metainfo import Trainers +from modelscope.msdatasets import MsDataset +from modelscope.trainers import build_trainer +from modelscope.utils.config import Config, ConfigDict +from modelscope.utils.constant import DownloadMode, ModelFile +from modelscope.utils.test_utils import test_level + + +class TestGeneralImageClassificationTestTrainer(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + try: + self.train_dataset = MsDataset.load( + 'cats_and_dogs', + namespace='tany0699', + subset_name='default', + split='train') + + self.eval_dataset = MsDataset.load( + 'cats_and_dogs', + namespace='tany0699', + subset_name='default', + split='validation') + except Exception as e: + print(f'Download dataset error: {e}') + + self.max_epochs = 1 + + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_nextvit_dailylife_train(self): + model_id = 'damo/cv_nextvit-small_image-classification_Dailylife-labels' + + def cfg_modify_fn(cfg): + cfg.train.dataloader.batch_size_per_gpu = 32 + cfg.train.dataloader.workers_per_gpu = 1 + cfg.train.max_epochs = self.max_epochs + cfg.model.mm_model.head.num_classes = 2 + cfg.train.optimizer.lr = 1e-4 + cfg.train.lr_config.warmup_iters = 1 + cfg.train.evaluation.metric_options = {'topk': (1, )} + cfg.evaluation.metric_options = {'topk': (1, )} + return cfg + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + cfg_modify_fn=cfg_modify_fn) + + trainer = build_trainer( + name=Trainers.image_classification, default_args=kwargs) + trainer.train() + + results_files = os.listdir(self.tmp_dir) + self.assertIn(f'{trainer.timestamp}.log.json', results_files) + for i in range(self.max_epochs): + self.assertIn(f'epoch_{i+1}.pth', results_files) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_nextvit_dailylife_eval(self): + model_id = 'damo/cv_nextvit-small_image-classification_Dailylife-labels' + + kwargs = dict( + model=model_id, + work_dir=self.tmp_dir, + train_dataset=None, + eval_dataset=self.eval_dataset) + + trainer = build_trainer( + name=Trainers.image_classification, default_args=kwargs) + result = trainer.evaluate() + print(result) + + +if __name__ == '__main__': + unittest.main()
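
Usage sketch: the snippet below mirrors the new tests and shows the intended end-to-end flow for these additions, first inference through the image-classification pipeline (now also registered under Pipelines.nextvit_small_daily_image_classification), then fine-tuning and evaluation through the trainer registered as Trainers.image_classification, which reads the ModelScope configuration (ModelFile.CONFIGURATION) from the model directory. The model id, dataset name and namespace come from the tests above; the work_dir value and the cfg_modify_fn overrides are illustrative placeholders.

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.utils.constant import Tasks

model_id = 'damo/cv_nextvit-small_image-classification_Dailylife-labels'

# Inference via GeneralImageClassificationPipeline with the NextViT model.
nextvit_classifier = pipeline(Tasks.image_classification, model=model_id)
print(nextvit_classifier('data/test/images/bird.JPEG'))

# Fine-tuning via the newly registered image-classification trainer.
train_dataset = MsDataset.load(
    'cats_and_dogs', namespace='tany0699', subset_name='default', split='train')
eval_dataset = MsDataset.load(
    'cats_and_dogs',
    namespace='tany0699',
    subset_name='default',
    split='validation')


def cfg_modify_fn(cfg):
    # illustrative overrides; any field of the model configuration can be edited here
    cfg.train.max_epochs = 1
    cfg.model.mm_model.head.num_classes = 2
    return cfg


trainer = build_trainer(
    name=Trainers.image_classification,
    default_args=dict(
        model=model_id,
        work_dir='./work_dir',  # illustrative output directory
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        cfg_modify_fn=cfg_modify_fn))
trainer.train()
print(trainer.evaluate())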