
add nextvit-small_image-classification_Dailylife-labels model

Support the model newly launched on 11/30.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10886253
master^2
Authored by ziyuan.tw, committed by yingda.chen, 3 years ago
Commit 31316b8d29
10 changed files with 1329 additions and 14 deletions
  1. modelscope/metainfo.py (+3, -0)
  2. modelscope/models/cv/image_classification/backbones/__init__.py (+2, -0)
  3. modelscope/models/cv/image_classification/backbones/nextvit.py (+541, -0)
  4. modelscope/models/cv/image_classification/mmcls_model.py (+24, -8)
  5. modelscope/models/cv/image_classification/utils.py (+100, -0)
  6. modelscope/pipelines/cv/image_classification_pipeline.py (+18, -6)
  7. modelscope/preprocessors/image.py (+34, -0)
  8. modelscope/trainers/cv/image_classifition_trainer.py (+502, -0)
  9. tests/pipelines/test_general_image_classification.py (+9, -0)
  10. tests/trainers/test_general_image_classification_trainer.py (+96, -0)

modelscope/metainfo.py (+3, -0)

@@ -185,6 +185,7 @@ class Pipelines(object):
     live_category = 'live-category'
     general_image_classification = 'vit-base_image-classification_ImageNet-labels'
     daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
+    nextvit_small_daily_image_classification = 'nextvit-small_image-classification_Dailylife-labels'
     image_color_enhance = 'csrnet-image-color-enhance'
     virtual_try_on = 'virtual-try-on'
     image_colorization = 'unet-image-colorization'
@@ -330,6 +331,7 @@ class Trainers(object):
     image_inpainting = 'image-inpainting'
     referring_video_object_segmentation = 'referring-video-object-segmentation'
     image_classification_team = 'image-classification-team'
+    image_classification = 'image-classification'

     # nlp trainers
     bert_sentiment_analysis = 'bert-sentiment-analysis'
@@ -365,6 +367,7 @@ class Preprocessors(object):
     image_portrait_enhancement_preprocessor = 'image-portrait-enhancement-preprocessor'
     video_summarization_preprocessor = 'video-summarization-preprocessor'
     movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
+    image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor'

     # nlp preprocessor
     sen_sim_tokenizer = 'sen-sim-tokenizer'
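
Note: with these three keys registered, the new model is reachable through the standard APIs. A minimal usage sketch (reviewer note, not part of the diff; the model id is the one exercised by the tests added below):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

classifier = pipeline(
    Tasks.image_classification,
    model='damo/cv_nextvit-small_image-classification_Dailylife-labels')
result = classifier('data/test/images/bird.JPEG')
print(result)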


modelscope/models/cv/image_classification/backbones/__init__.py (+2, -0)

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .nextvit import NextViT
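
Note: this import is what makes the backbone visible to mmcls's registry. A sketch of building a classifier around it (the neck/head settings here are illustrative assumptions; the real ones ship in the model repo's config):

from mmcls.models import build_classifier
# importing the package triggers @BACKBONES.register_module() for NextViT
import modelscope.models.cv.image_classification.backbones  # noqa

model_cfg = dict(
    type='ImageClassifier',
    backbone=dict(type='NextViT', arch='small'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,  # final NextViT stage outputs 1024 channels
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0)))
classifier = build_classifier(model_cfg)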

modelscope/models/cv/image_classification/backbones/nextvit.py (+541, -0)

@@ -0,0 +1,541 @@
# Part of the implementation is borrowed and modified from Next-ViT,
# publicly available at https://github.com/bytedance/Next-ViT
import collections.abc
import itertools
import math
import os
import warnings
from functools import partial
from typing import Dict, Sequence

import torch
import torch.nn as nn
from einops import rearrange
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.builder import BACKBONES
from mmcv.cnn.bricks import DropPath, build_activation_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

NORM_EPS = 1e-5


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        ll = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [ll, u], then translate to
        # [2ll-1, 2u-1].
        tensor.uniform_(2 * ll - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class ConvBNReLU(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=1,
            groups=groups,
            bias=False)
        self.norm = nn.BatchNorm2d(out_channels, eps=NORM_EPS)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class PatchEmbed(nn.Module):

    def __init__(self, in_channels, out_channels, stride=1):
        super(PatchEmbed, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        if stride == 2:
            self.avgpool = nn.AvgPool2d((2, 2),
                                        stride=2,
                                        ceil_mode=True,
                                        count_include_pad=False)
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        elif in_channels != out_channels:
            self.avgpool = nn.Identity()
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, bias=False)
            self.norm = norm_layer(out_channels)
        else:
            self.avgpool = nn.Identity()
            self.conv = nn.Identity()
            self.norm = nn.Identity()

    def forward(self, x):
        return self.norm(self.conv(self.avgpool(x)))


class MHCA(nn.Module):
    """
    Multi-Head Convolutional Attention
    """

    def __init__(self, out_channels, head_dim):
        super(MHCA, self).__init__()
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        self.group_conv3x3 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=out_channels // head_dim,
            bias=False)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU(inplace=True)
        self.projection = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.group_conv3x3(x)
        out = self.norm(out)
        out = self.act(out)
        out = self.projection(out)
        return out


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 out_features=None,
                 mlp_ratio=None,
                 drop=0.,
                 bias=True):
        super().__init__()
        out_features = out_features or in_features
        hidden_dim = _make_divisible(in_features * mlp_ratio, 32)
        self.conv1 = nn.Conv2d(
            in_features, hidden_dim, kernel_size=1, bias=bias)
        self.act = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            hidden_dim, out_features, kernel_size=1, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.conv1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.conv2(x)
        x = self.drop(x)
        return x


class NCB(nn.Module):
    """
    Next Convolution Block
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 path_dropout=0,
                 drop=0,
                 head_dim=32,
                 mlp_ratio=3):
        super(NCB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        norm_layer = partial(nn.BatchNorm2d, eps=NORM_EPS)
        assert out_channels % head_dim == 0

        self.patch_embed = PatchEmbed(in_channels, out_channels, stride)
        self.mhca = MHCA(out_channels, head_dim)
        self.attention_path_dropout = DropPath(path_dropout)

        self.norm = norm_layer(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop, bias=True)
        self.mlp_path_dropout = DropPath(path_dropout)
        self.is_bn_merged = False

    def forward(self, x):
        x = self.patch_embed(x)
        x = x + self.attention_path_dropout(self.mhca(x))
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


class E_MHSA(nn.Module):
    """
    Efficient Multi-Head Self Attention
    """

    def __init__(self,
                 dim,
                 out_dim=None,
                 head_dim=32,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0,
                 proj_drop=0.,
                 sr_ratio=1):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim if out_dim is not None else dim
        self.num_heads = self.dim // head_dim
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.dim, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.out_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        self.N_ratio = sr_ratio**2
        if sr_ratio > 1:
            self.sr = nn.AvgPool1d(
                kernel_size=self.N_ratio, stride=self.N_ratio)
            self.norm = nn.BatchNorm1d(dim, eps=NORM_EPS)
        self.is_bn_merge = False

    def forward(self, x):
        B, N, C = x.shape
        q = self.q(x)
        q = q.reshape(B, N, self.num_heads,
                      int(C // self.num_heads)).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.transpose(1, 2)
            x_ = self.sr(x_)
            if not torch.onnx.is_in_onnx_export() and not self.is_bn_merge:
                x_ = self.norm(x_)
            x_ = x_.transpose(1, 2)
            k = self.k(x_)
            k = k.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x_)
            v = v.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 1, 3)
        else:
            k = self.k(x)
            k = k.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 3, 1)
            v = self.v(x)
            v = v.reshape(B, -1, self.num_heads,
                          int(C // self.num_heads)).permute(0, 2, 1, 3)
        attn = (q @ k) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class NTB(nn.Module):
    """
    Next Transformer Block
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            path_dropout,
            stride=1,
            sr_ratio=1,
            mlp_ratio=2,
            head_dim=32,
            mix_block_ratio=0.75,
            attn_drop=0,
            drop=0,
    ):
        super(NTB, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mix_block_ratio = mix_block_ratio
        norm_func = partial(nn.BatchNorm2d, eps=NORM_EPS)

        self.mhsa_out_channels = _make_divisible(
            int(out_channels * mix_block_ratio), 32)
        self.mhca_out_channels = out_channels - self.mhsa_out_channels

        self.patch_embed = PatchEmbed(in_channels, self.mhsa_out_channels,
                                      stride)
        self.norm1 = norm_func(self.mhsa_out_channels)
        self.e_mhsa = E_MHSA(
            self.mhsa_out_channels,
            head_dim=head_dim,
            sr_ratio=sr_ratio,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.mhsa_path_dropout = DropPath(path_dropout * mix_block_ratio)

        self.projection = PatchEmbed(
            self.mhsa_out_channels, self.mhca_out_channels, stride=1)
        self.mhca = MHCA(self.mhca_out_channels, head_dim=head_dim)
        self.mhca_path_dropout = DropPath(path_dropout * (1 - mix_block_ratio))

        self.norm2 = norm_func(out_channels)
        self.mlp = Mlp(out_channels, mlp_ratio=mlp_ratio, drop=drop)
        self.mlp_path_dropout = DropPath(path_dropout)

        self.is_bn_merged = False

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm1(x)
        else:
            out = x
        out = rearrange(out, 'b c h w -> b (h w) c')  # b n c
        out = self.mhsa_path_dropout(self.e_mhsa(out))
        x = x + rearrange(out, 'b (h w) c -> b c h w', h=H)

        out = self.projection(x)
        out = out + self.mhca_path_dropout(self.mhca(out))
        x = torch.cat([x, out], dim=1)

        if not torch.onnx.is_in_onnx_export() and not self.is_bn_merged:
            out = self.norm2(x)
        else:
            out = x
        x = x + self.mlp_path_dropout(self.mlp(out))
        return x


@BACKBONES.register_module()
class NextViT(BaseBackbone):
    stem_chs = {
        'x_small': [64, 32, 64],
        'small': [64, 32, 64],
        'base': [64, 32, 64],
        'large': [64, 32, 64],
    }
    depths = {
        'x_small': [1, 1, 5, 1],
        'small': [3, 4, 10, 3],
        'base': [3, 4, 20, 3],
        'large': [3, 4, 30, 3],
    }

    def __init__(self,
                 arch='small',
                 path_dropout=0.2,
                 attn_drop=0,
                 drop=0,
                 strides=[1, 2, 2, 2],
                 sr_ratios=[8, 4, 2, 1],
                 head_dim=32,
                 mix_block_ratio=0.75,
                 resume='',
                 with_extra_norm=True,
                 norm_eval=False,
                 norm_cfg=None,
                 out_indices=-1,
                 frozen_stages=-1,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        stem_chs = self.stem_chs[arch]
        depths = self.depths[arch]

        self.frozen_stages = frozen_stages
        self.with_extra_norm = with_extra_norm
        self.norm_eval = norm_eval
        self.stage1_out_channels = [96] * (depths[0])
        self.stage2_out_channels = [192] * (depths[1] - 1) + [256]
        self.stage3_out_channels = [384, 384, 384, 384, 512] * (depths[2] // 5)
        self.stage4_out_channels = [768] * (depths[3] - 1) + [1024]
        self.stage_out_channels = [
            self.stage1_out_channels, self.stage2_out_channels,
            self.stage3_out_channels, self.stage4_out_channels
        ]

        # Next Hybrid Strategy
        self.stage1_block_types = [NCB] * depths[0]
        self.stage2_block_types = [NCB] * (depths[1] - 1) + [NTB]
        self.stage3_block_types = [NCB, NCB, NCB, NCB, NTB] * (depths[2] // 5)
        self.stage4_block_types = [NCB] * (depths[3] - 1) + [NTB]
        self.stage_block_types = [
            self.stage1_block_types, self.stage2_block_types,
            self.stage3_block_types, self.stage4_block_types
        ]

        self.stem = nn.Sequential(
            ConvBNReLU(3, stem_chs[0], kernel_size=3, stride=2),
            ConvBNReLU(stem_chs[0], stem_chs[1], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[1], stem_chs[2], kernel_size=3, stride=1),
            ConvBNReLU(stem_chs[2], stem_chs[2], kernel_size=3, stride=2),
        )
        input_channel = stem_chs[-1]
        features = []
        idx = 0
        # stochastic depth decay rule
        dpr = [
            x.item() for x in torch.linspace(0, path_dropout, sum(depths))
        ]
        for stage_id in range(len(depths)):
            numrepeat = depths[stage_id]
            output_channels = self.stage_out_channels[stage_id]
            block_types = self.stage_block_types[stage_id]
            for block_id in range(numrepeat):
                if strides[stage_id] == 2 and block_id == 0:
                    stride = 2
                else:
                    stride = 1
                output_channel = output_channels[block_id]
                block_type = block_types[block_id]
                if block_type is NCB:
                    layer = NCB(
                        input_channel,
                        output_channel,
                        stride=stride,
                        path_dropout=dpr[idx + block_id],
                        drop=drop,
                        head_dim=head_dim)
                    features.append(layer)
                elif block_type is NTB:
                    layer = NTB(
                        input_channel,
                        output_channel,
                        path_dropout=dpr[idx + block_id],
                        stride=stride,
                        sr_ratio=sr_ratios[stage_id],
                        head_dim=head_dim,
                        mix_block_ratio=mix_block_ratio,
                        attn_drop=attn_drop,
                        drop=drop)
                    features.append(layer)
                input_channel = output_channel
            idx += numrepeat
        self.features = nn.Sequential(*features)
        self.norm = nn.BatchNorm2d(output_channel, eps=NORM_EPS)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = sum(depths) + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.stage_out_idx = out_indices

        if norm_cfg is not None:
            self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def init_weights(self):
        super(NextViT, self).init_weights()
        if (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            # Suppress default init if use pretrained model.
            return

        self._initialize_weights()

    def _initialize_weights(self):
        for n, m in self.named_modules():
            if isinstance(m, (nn.BatchNorm2d,
                              nn.BatchNorm1d)):  # nn.GroupNorm, nn.LayerNorm,
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                trunc_normal_(m.weight, std=.02)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        outputs = list()
        x = self.stem(x)
        stage_id = 0
        for idx, layer in enumerate(self.features):
            x = layer(x)
            if idx == self.stage_out_idx[stage_id]:
                if self.with_extra_norm:
                    x = self.norm(x)
                outputs.append(x)
                stage_id += 1
        return tuple(outputs)

    def _freeze_stages(self):
        if self.frozen_stages > 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
            for idx, layer in enumerate(self.features):
                if idx <= self.stage_out_idx[self.frozen_stages - 1]:
                    layer.eval()
                    for param in layer.parameters():
                        param.requires_grad = False

    def train(self, mode=True):
        super(NextViT, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval has effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
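
Note: a quick standalone smoke test of the backbone (a sketch, not part of the diff). With the 'small' depths [3, 4, 10, 3] there are 20 blocks, so the default out_indices=-1 resolves to the last block; the stem downsamples 4x and stages 2-4 each downsample 2x, so a 224x224 input yields a 7x7 map with 1024 channels:

import torch
from modelscope.models.cv.image_classification.backbones import NextViT

backbone = NextViT(arch='small', path_dropout=0.0)
backbone.init_weights()
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print(feats[0].shape)  # expected: torch.Size([1, 1024, 7, 7])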

modelscope/models/cv/image_classification/mmcls_model.py (+24, -8)

@@ -1,9 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os

 from modelscope.metainfo import Models
 from modelscope.models.base.base_torch_model import TorchModel
 from modelscope.models.builder import MODELS
-from modelscope.utils.constant import Tasks
+from modelscope.utils.constant import ModelFile, Tasks


 @MODELS.register_module(
@@ -13,16 +14,25 @@ class ClassificationModel(TorchModel):
     def __init__(self, model_dir: str, **kwargs):
         import mmcv
         from mmcls.models import build_classifier
+        import modelscope.models.cv.image_classification.backbones
+        from modelscope.utils.hub import read_config

         super().__init__(model_dir)

-        config = os.path.join(model_dir, 'config.py')
-
-        cfg = mmcv.Config.fromfile(config)
-        cfg.model.pretrained = None
-        self.cls_model = build_classifier(cfg.model)
-
+        self.config_type = 'ms_config'
+        mm_config = os.path.join(model_dir, 'config.py')
+        if os.path.exists(mm_config):
+            cfg = mmcv.Config.fromfile(mm_config)
+            cfg.model.pretrained = None
+            self.cls_model = build_classifier(cfg.model)
+            self.config_type = 'mmcv_config'
+        else:
+            cfg = read_config(model_dir)
+            cfg.model.mm_model.pretrained = None
+            self.cls_model = build_classifier(cfg.model.mm_model)
+            self.config_type = 'ms_config'
+        self.cfg = cfg

         self.ms_model_dir = model_dir

         self.load_pretrained_checkpoint()
@@ -33,7 +43,13 @@ class ClassificationModel(TorchModel):

     def load_pretrained_checkpoint(self):
         import mmcv
-        checkpoint_path = os.path.join(self.ms_model_dir, 'checkpoints.pth')
+        if os.path.exists(
+                os.path.join(self.ms_model_dir, ModelFile.TORCH_MODEL_FILE)):
+            checkpoint_path = os.path.join(self.ms_model_dir,
+                                           ModelFile.TORCH_MODEL_FILE)
+        else:
+            checkpoint_path = os.path.join(self.ms_model_dir,
+                                           'checkpoints.pth')
         if os.path.exists(checkpoint_path):
            checkpoint = mmcv.runner.load_checkpoint(
                self.cls_model, checkpoint_path, map_location='cpu')
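
Note: the new lookup prefers the standard ModelScope weight file and only falls back to the legacy file name. The same logic spelled out (illustrative sketch, not from the diff):

import os
from modelscope.utils.constant import ModelFile

def resolve_checkpoint(model_dir):
    preferred = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
    # prefer the standard ModelScope weight file; 'checkpoints.pth' is
    # the legacy name kept for older model repos
    if os.path.exists(preferred):
        return preferred
    return os.path.join(model_dir, 'checkpoints.pth')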


modelscope/models/cv/image_classification/utils.py (+100, -0)

@@ -0,0 +1,100 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp

import numpy as np
from mmcls.datasets.base_dataset import BaseDataset


def get_trained_checkpoints_name(work_path):
    import os
    file_list = os.listdir(work_path)
    last = 0
    model_name = None
    # find the best model
    if model_name is None:
        for f_name in file_list:
            if 'best_' in f_name and f_name.endswith('.pth'):
                best_epoch = f_name.replace('.pth', '').split('_')[-1]
                if best_epoch.isdigit():
                    last = int(best_epoch)
                    model_name = f_name
                    return model_name
    # or find the latest model
    if model_name is None:
        for f_name in file_list:
            if 'epoch_' in f_name and f_name.endswith('.pth'):
                epoch_num = f_name.replace('epoch_', '').replace('.pth', '')
                if not epoch_num.isdigit():
                    continue
                ind = int(epoch_num)
                if ind > last:
                    last = ind
                    model_name = f_name
    return model_name


def preprocess_transform(cfgs):
    if cfgs is None:
        return None
    for i, cfg in enumerate(cfgs):
        if cfg.type == 'Resize':
            if isinstance(cfg.size, list):
                cfgs[i].size = tuple(cfg.size)
    return cfgs


def get_ms_dataset_root(ms_dataset):
    if ms_dataset is None or len(ms_dataset) < 1:
        return None
    try:
        data_root = ms_dataset[0]['image:FILE'].split('extracted')[0]
        path_post = ms_dataset[0]['image:FILE'].split('extracted')[1].split(
            '/')
        extracted_data_root = osp.join(data_root, 'extracted', path_post[1],
                                       path_post[2])
        return extracted_data_root
    except Exception as e:
        raise ValueError(f'Dataset Error: {e}')
    return None


def get_classes(classes=None):
    import mmcv
    if isinstance(classes, str):
        # take it as a file path
        class_names = mmcv.list_from_file(classes)
    elif isinstance(classes, (tuple, list)):
        class_names = classes
    else:
        raise ValueError(f'Unsupported type {type(classes)} of classes.')

    return class_names


class MmDataset(BaseDataset):

    def __init__(self, ms_dataset, pipeline, classes=None, test_mode=False):
        self.ms_dataset = ms_dataset
        if len(self.ms_dataset) < 1:
            raise ValueError('Dataset Error: dataset is empty')
        super(MmDataset, self).__init__(
            data_prefix='',
            pipeline=pipeline,
            classes=classes,
            test_mode=test_mode)

    def load_annotations(self):
        if self.CLASSES is None:
            raise ValueError(
                f'Dataset Error: Not found classesname.txt: {self.CLASSES}')

        data_infos = []
        for data_info in self.ms_dataset:
            filename = data_info['image:FILE']
            gt_label = data_info['category']
            info = {'img_prefix': self.data_prefix}
            info['img_info'] = {'filename': filename}
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
            data_infos.append(info)

        return data_infos
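
Note: MmDataset adapts a ModelScope MsDataset (rows carrying 'image:FILE' and 'category') to mmcls's BaseDataset. A usage sketch against the dataset used by the new trainer tests (assumes the extracted dataset root contains classname.txt):

import os.path as osp
from modelscope.msdatasets import MsDataset
from modelscope.models.cv.image_classification.utils import (
    MmDataset, get_ms_dataset_root)

ms_train = MsDataset.load(
    'cats_and_dogs', namespace='tany0699', subset_name='default',
    split='train')
data_root = get_ms_dataset_root(ms_train)
dataset = MmDataset(
    ms_train,
    pipeline=[],  # normally cfg.preprocessor.train from the model config
    classes=osp.join(data_root, 'classname.txt'))
print(len(dataset), dataset.CLASSES[:3])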

modelscope/pipelines/cv/image_classification_pipeline.py (+18, -6)

@@ -45,6 +45,9 @@ class ImageClassificationPipeline(Pipeline):
 @PIPELINES.register_module(
     Tasks.image_classification,
     module_name=Pipelines.daily_image_classification)
+@PIPELINES.register_module(
+    Tasks.image_classification,
+    module_name=Pipelines.nextvit_small_daily_image_classification)
 class GeneralImageClassificationPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
@@ -60,6 +63,7 @@ class GeneralImageClassificationPipeline(Pipeline):
     def preprocess(self, input: Input) -> Dict[str, Any]:
         from mmcls.datasets.pipelines import Compose
         from mmcv.parallel import collate, scatter
+        from modelscope.models.cv.image_classification.utils import preprocess_transform
         if isinstance(input, str):
             img = np.array(load_image(input))
         elif isinstance(input, PIL.Image.Image):
@@ -72,12 +76,20 @@ class GeneralImageClassificationPipeline(Pipeline):
             raise TypeError(f'input should be either str, PIL.Image,'
                             f' np.array, but got {type(input)}')

-        mmcls_cfg = self.model.cfg
-        # build the data pipeline
-        if mmcls_cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
-            mmcls_cfg.data.test.pipeline.pop(0)
-        data = dict(img=img)
-        test_pipeline = Compose(mmcls_cfg.data.test.pipeline)
+        cfg = self.model.cfg
+
+        if self.model.config_type == 'mmcv_config':
+            if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
+                cfg.data.test.pipeline.pop(0)
+            data = dict(img=img)
+            test_pipeline = Compose(cfg.data.test.pipeline)
+        else:
+            if cfg.preprocessor.val[0]['type'] == 'LoadImageFromFile':
+                cfg.preprocessor.val.pop(0)
+            data = dict(img=img)
+            data_pipeline = preprocess_transform(cfg.preprocessor.val)
+            test_pipeline = Compose(data_pipeline)
+
         data = test_pipeline(data)
         data = collate([data], samples_per_gpu=1)
         if next(self.model.parameters()).is_cuda:
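
Note: the new else-branch routes ModelScope-style configs through preprocess_transform, which turns list-valued Resize sizes (as deserialized from JSON) into the tuples that mmcls transforms accept; that appears to be the helper's only job. A small illustration (not from the diff):

from mmcv import ConfigDict
from modelscope.models.cv.image_classification.utils import preprocess_transform

val_cfg = [
    ConfigDict(type='Resize', size=[256, 256]),
    ConfigDict(type='CenterCrop', crop_size=224),
]
val_cfg = preprocess_transform(val_cfg)
assert val_cfg[0].size == (256, 256)  # list converted to tuple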


modelscope/preprocessors/image.py (+34, -0)

@@ -289,3 +289,37 @@ class VideoSummarizationPreprocessor(Preprocessor):
             Dict[str, Any]: the preprocessed data
         """
         return data
+
+
+@PREPROCESSORS.register_module(
+    Fields.cv,
+    module_name=Preprocessors.image_classification_bypass_preprocessor)
+class ImageClassificationBypassPreprocessor(Preprocessor):
+
+    def __init__(self, *args, **kwargs):
+        """image classification bypass preprocessor in the fine-tune scenario
+        """
+        super().__init__(*args, **kwargs)
+
+        self.training = kwargs.pop('training', True)
+        self.preprocessor_train_cfg = kwargs.pop('train', None)
+        self.preprocessor_val_cfg = kwargs.pop('val', None)
+
+    def train(self):
+        self.training = True
+        return
+
+    def eval(self):
+        self.training = False
+        return
+
+    def __call__(self, results: Dict[str, Any]):
+        """process the raw input data
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            Dict[str, Any] | None: the preprocessed data
+        """
+        pass
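
Note: this preprocessor is deliberately inert at call time; it only carries the train/val transform configs, and the actual mmcls pipeline is built inside the trainer from cfg.preprocessor.train/val. A sketch of its observable behavior (illustrative, not from the diff):

from modelscope.preprocessors.image import ImageClassificationBypassPreprocessor

p = ImageClassificationBypassPreprocessor(training=True)
p.eval()
assert p.training is False
# __call__ is a no-op by design and returns None
assert p({'filename': 'demo.jpg'}) is None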

modelscope/trainers/cv/image_classifition_trainer.py (+502, -0)

@@ -0,0 +1,502 @@
# Part of the implementation is borrowed and modified from mmclassification,
# publicly available at https://github.com/open-mmlab/mmclassification
import copy
import os
import os.path as osp
import time
from typing import Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.base import TorchModel
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.preprocessors.base import Preprocessor
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.logger import get_logger


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                val_dataset=None,
                timestamp=None,
                device=None,
                meta=None):
    import torch
    import warnings
    from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                             build_optimizer, build_runner, get_dist_info)
    from mmcls.core import DistEvalHook, DistOptimizerHook, EvalHook
    from mmcls.datasets import build_dataloader
    from mmcls.utils import (wrap_distributed_model,
                             wrap_non_distributed_model)
    from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

    logger = get_logger()

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    sampler_cfg = cfg.train.get('sampler', None)

    data_loaders = [
        build_dataloader(
            ds,
            cfg.train.dataloader.batch_size_per_gpu,
            cfg.train.dataloader.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            round_up=True,
            seed=cfg.seed,
            sampler_cfg=sampler_cfg) for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        if device == 'cpu':
            logger.warning(
                'The argument `device` is deprecated. To use cpu to train, '
                'please refers to https://mmclassification.readthedocs.io/en'
                '/latest/getting_started.html#train-a-model')
            model = model.cpu()
        else:
            model = MMDataParallel(model, device_ids=cfg.gpu_ids)
            if not model.device_ids:
                from mmcv import __version__, digit_version
                assert digit_version(__version__) >= (1, 4, 4), \
                    'To train with CPU, please confirm your mmcv version ' \
                    'is not lower than v1.4.4'

    # build runner
    optimizer = build_optimizer(model, cfg.train.optimizer)

    if cfg.train.get('runner') is None:
        cfg.train.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.train.max_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner = build_runner(
        cfg.train.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.train.optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in cfg.train.optimizer_config:
        optimizer_config = DistOptimizerHook(**cfg.train.optimizer_config)
    else:
        optimizer_config = cfg.train.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.train.lr_config,
        optimizer_config,
        cfg.train.checkpoint_config,
        cfg.train.log_config,
        cfg.train.get('momentum_config', None),
        custom_hooks_config=cfg.train.get('custom_hooks', None))
    if distributed and cfg.train.runner['type'] == 'EpochBasedRunner':
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if val_dataset is not None:
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=cfg.evaluation.dataloader.batch_size_per_gpu,
            workers_per_gpu=cfg.evaluation.dataloader.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            round_up=True)
        eval_cfg = cfg.train.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.train.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # `EvalHook` needs to be executed after `IterTimerHook`.
        # Otherwise, it will cause a bug if use `IterBasedRunner`.
        # Refers to https://github.com/open-mmlab/mmcv/issues/1261
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    if cfg.train.resume_from:
        runner.resume(cfg.train.resume_from, map_location='cpu')
    elif cfg.train.load_from:
        runner.load_checkpoint(cfg.train.load_from)

    cfg.train.workflow = [tuple(flow) for flow in cfg.train.workflow]
    runner.run(data_loaders, cfg.train.workflow)


@TRAINERS.register_module(module_name=Trainers.image_classification)
class ImageClassifitionTrainer(BaseTrainer):

    def __init__(
            self,
            model: Optional[Union[TorchModel, nn.Module, str]] = None,
            cfg_file: Optional[str] = None,
            arg_parse_fn: Optional[Callable] = None,
            data_collator: Optional[Union[Callable, Dict[str,
                                                         Callable]]] = None,
            train_dataset: Optional[Union[MsDataset, Dataset]] = None,
            eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
            preprocessor: Optional[Union[Preprocessor,
                                         Dict[str, Preprocessor]]] = None,
            optimizers: Tuple[torch.optim.Optimizer,
                              torch.optim.lr_scheduler._LRScheduler] = (None,
                                                                        None),
            model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
            seed: int = 0,
            cfg_modify_fn: Optional[Callable] = None,
            **kwargs):
        """High-level finetune api for image classification.

        Args:
            model: model id
            model_revision: model revision, default is DEFAULT_MODEL_REVISION.
            cfg_modify_fn: An input fn which is used to modify the cfg read out of the file.
        """
        import torch
        import mmcv
        from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, get_classes
        from mmcls.models import build_classifier
        from mmcv.runner import get_dist_info, init_dist
        from mmcls.apis import set_random_seed
        from mmcls.utils import collect_env
        import modelscope.models.cv.image_classification.backbones

        self._seed = seed
        set_random_seed(self._seed)
        if isinstance(model, str):
            if os.path.exists(model):
                self.model_dir = model if os.path.isdir(
                    model) else os.path.dirname(model)
            else:
                self.model_dir = snapshot_download(
                    model, revision=model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        super().__init__(cfg_file, arg_parse_fn)
        cfg = self.cfg

        if 'work_dir' in kwargs:
            self.work_dir = kwargs['work_dir']
        else:
            self.work_dir = self.cfg.train.get('work_dir', './work_dir')
        mmcv.mkdir_or_exist(osp.abspath(self.work_dir))
        cfg.work_dir = self.work_dir

        # evaluate config setting
        self.eval_checkpoint_path = os.path.join(self.model_dir,
                                                 ModelFile.TORCH_MODEL_FILE)

        # train config setting
        if 'resume_from' in kwargs:
            cfg.train.resume_from = kwargs['resume_from']
        else:
            cfg.train.resume_from = cfg.train.get('resume_from', None)

        if 'load_from' in kwargs:
            cfg.train.load_from = kwargs['load_from']
        else:
            if cfg.train.get('resume_from', None) is None:
                cfg.train.load_from = os.path.join(self.model_dir,
                                                   ModelFile.TORCH_MODEL_FILE)

        if 'device' in kwargs:
            cfg.device = kwargs['device']
        else:
            cfg.device = cfg.get('device', 'cuda')

        if 'gpu_ids' in kwargs:
            cfg.gpu_ids = kwargs['gpu_ids'][0:1]
        else:
            cfg.gpu_ids = [0]

        if 'fp16' in kwargs:
            cfg.fp16 = None if kwargs['fp16'] is None else kwargs['fp16']
        else:
            cfg.fp16 = None

        # no_validate=True will not evaluate checkpoint during training
        cfg.no_validate = kwargs.get('no_validate', False)

        if cfg_modify_fn is not None:
            cfg = cfg_modify_fn(cfg)

        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train,
                'max_epochs'), 'max_epochs is missing in configuration file'
            self.max_epochs = self.cfg.train.max_epochs
        else:
            self.max_epochs = kwargs['max_epochs']
        cfg.train.max_epochs = self.max_epochs
        if cfg.train.get('runner', None) is not None:
            cfg.train.runner.max_epochs = self.max_epochs

        if 'launcher' in kwargs:
            distributed = True
            dist_params = kwargs['dist_params'] \
                if 'dist_params' in kwargs else {'backend': 'nccl'}
            init_dist(kwargs['launcher'], **dist_params)
            # re-set gpu_ids with distributed training mode
            _, world_size = get_dist_info()
            cfg.gpu_ids = list(range(world_size))
        else:
            distributed = False

        # init the logger before other steps
        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
        log_file = osp.join(self.work_dir, f'{timestamp}.log')
        logger = get_logger(log_file=log_file)

        # init the meta dict to record some important information such as
        # environment info and seed, which will be logged
        meta = dict()
        # log env info
        env_info_dict = collect_env()
        env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
        dash_line = '-' * 60 + '\n'
        logger.info('Environment info:\n' + dash_line + env_info + '\n'
                    + dash_line)
        meta['env_info'] = env_info
        meta['config'] = cfg.pretty_text
        # log some basic info
        logger.info(f'Distributed training: {distributed}')
        logger.info(f'Config:\n{cfg.pretty_text}')

        # set random seeds
        cfg.seed = self._seed
        _deterministic = kwargs.get('deterministic', False)
        logger.info(f'Set random seed to {cfg.seed}, '
                    f'deterministic: {_deterministic}')
        set_random_seed(cfg.seed, deterministic=_deterministic)

        meta['seed'] = cfg.seed
        meta['exp_name'] = osp.basename(cfg_file)

        # dataset
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # model
        model = build_classifier(self.cfg.model.mm_model)
        model.init_weights()

        self.cfg = cfg
        self.device = cfg.device
        self.cfg_file = cfg_file
        self.model = model
        self.distributed = distributed
        self.timestamp = timestamp
        self.meta = meta
        self.logger = logger

    def train(self, *args, **kwargs):
        from mmcls import __version__
        from modelscope.models.cv.image_classification.utils import get_ms_dataset_root, MmDataset, preprocess_transform
        from mmcls.utils import setup_multi_processes

        if self.train_dataset is None:
            raise ValueError(
                "Not found train dataset, please set the 'train_dataset' parameter!"
            )

        self.cfg.model.mm_model.pretrained = None

        # dump config
        self.cfg.dump(osp.join(self.work_dir, osp.basename(self.cfg_file)))

        # build the dataloader
        if self.cfg.dataset.classes is None:
            data_root = get_ms_dataset_root(self.train_dataset)
            classname_path = osp.join(data_root, 'classname.txt')
            classes = classname_path if osp.exists(classname_path) else None
        else:
            classes = self.cfg.dataset.classes

        datasets = [
            MmDataset(
                self.train_dataset,
                pipeline=self.cfg.preprocessor.train,
                classes=classes)
        ]

        if len(self.cfg.train.workflow) == 2:
            if self.eval_dataset is None:
                raise ValueError(
                    "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
                )
            val_data_pipeline = self.cfg.preprocessor.train
            val_dataset = MmDataset(
                self.eval_dataset, pipeline=val_data_pipeline, classes=classes)
            datasets.append(val_dataset)

        # save mmcls version, config file content and class names in
        # checkpoints as meta data
        self.meta.update(
            dict(
                mmcls_version=__version__,
                config=self.cfg.pretty_text,
                CLASSES=datasets[0].CLASSES))

        val_dataset = None
        if not self.cfg.no_validate:
            val_dataset = MmDataset(
                self.eval_dataset,
                pipeline=preprocess_transform(self.cfg.preprocessor.val),
                classes=classes)

        train_model(
            self.model,
            datasets,
            self.cfg,
            distributed=self.distributed,
            val_dataset=val_dataset,
            timestamp=self.timestamp,
            device='cpu' if self.device == 'cpu' else 'cuda',
            meta=self.meta)

    def evaluate(self,
                 checkpoint_path: str = None,
                 *args,
                 **kwargs) -> Dict[str, float]:
        import warnings
        import mmcv
        import torch
        from modelscope.models.cv.image_classification.utils import (
            get_ms_dataset_root, MmDataset, preprocess_transform,
            get_trained_checkpoints_name)
        from mmcls.datasets import build_dataloader
        from mmcv.runner import get_dist_info, load_checkpoint, wrap_fp16_model
        from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
        from mmcls.apis import multi_gpu_test, single_gpu_test
        from mmcls.utils import setup_multi_processes

        if self.eval_dataset is None:
            raise ValueError(
                "Not found evaluate dataset, please set the 'eval_dataset' parameter!"
            )

        self.cfg.model.mm_model.pretrained = None

        # build the dataloader
        if self.cfg.dataset.classes is None:
            data_root = get_ms_dataset_root(self.eval_dataset)
            classname_path = osp.join(data_root, 'classname.txt')
            classes = classname_path if osp.exists(classname_path) else None
        else:
            classes = self.cfg.dataset.classes
        dataset = MmDataset(
            self.eval_dataset,
            pipeline=preprocess_transform(self.cfg.preprocessor.val),
            classes=classes)
        # the extra round_up data will be removed during gpu/cpu collect
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=self.cfg.evaluation.dataloader.batch_size_per_gpu,
            workers_per_gpu=self.cfg.evaluation.dataloader.workers_per_gpu,
            dist=self.distributed,
            shuffle=False,
            round_up=True)

        model = copy.deepcopy(self.model)
        fp16_cfg = self.cfg.get('fp16', None)
        if fp16_cfg is not None:
            wrap_fp16_model(model)
        if checkpoint_path is None:
            trained_checkpoints = get_trained_checkpoints_name(self.work_dir)
            if trained_checkpoints is not None:
                checkpoint = load_checkpoint(
                    model,
                    os.path.join(self.work_dir, trained_checkpoints),
                    map_location='cpu')
            else:
                checkpoint = load_checkpoint(
                    model, self.eval_checkpoint_path, map_location='cpu')
        else:
            checkpoint = load_checkpoint(
                model, checkpoint_path, map_location='cpu')

        if 'CLASSES' in checkpoint.get('meta', {}):
            CLASSES = checkpoint['meta']['CLASSES']
        else:
            from mmcls.datasets import ImageNet
            self.logger.warning(
                'Class names are not saved in the checkpoint\'s '
                'meta data, use imagenet by default.')
            CLASSES = ImageNet.CLASSES

        if not self.distributed:
            if self.device == 'cpu':
                model = model.cpu()
            else:
                model = MMDataParallel(model, device_ids=self.cfg.gpu_ids)
                if not model.device_ids:
                    assert mmcv.digit_version(mmcv.__version__) >= (1, 4, 4), \
                        'To test with CPU, please confirm your mmcv version ' \
                        'is not lower than v1.4.4'
            model.CLASSES = CLASSES
            show_kwargs = {}
            outputs = single_gpu_test(model, data_loader, False, None,
                                      **show_kwargs)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False)
            outputs = multi_gpu_test(model, data_loader, None, True)

        rank, _ = get_dist_info()
        if rank == 0:
            results = {}
            logger = get_logger()
            metric_options = self.cfg.evaluation.get('metric_options', {})
            if 'topk' in metric_options.keys():
                metric_options['topk'] = tuple(metric_options['topk'])
            if self.cfg.evaluation.metrics:
                eval_results = dataset.evaluate(
                    results=outputs,
                    metric=self.cfg.evaluation.metrics,
                    metric_options=metric_options,
                    logger=logger)
                results.update(eval_results)

            return results

        return None
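
Note: an end-to-end fine-tune driver, condensed from the tests added below (same model id and dataset; work_dir is an arbitrary local path):

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

train_ds = MsDataset.load(
    'cats_and_dogs', namespace='tany0699', subset_name='default',
    split='train')
val_ds = MsDataset.load(
    'cats_and_dogs', namespace='tany0699', subset_name='default',
    split='validation')
trainer = build_trainer(
    name=Trainers.image_classification,
    default_args=dict(
        model='damo/cv_nextvit-small_image-classification_Dailylife-labels',
        work_dir='./work_dir',
        train_dataset=train_ds,
        eval_dataset=val_ds,
        max_epochs=1))
trainer.train()
print(trainer.evaluate())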

tests/pipelines/test_general_image_classification.py (+9, -0)

@@ -31,6 +31,15 @@ class GeneralImageClassificationTest(unittest.TestCase,
         result = general_image_classification('data/test/images/bird.JPEG')
         print(result)

+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_nextvit(self):
+        nextvit_image_classification = pipeline(
+            Tasks.image_classification,
+            model='damo/cv_nextvit-small_image-classification_Dailylife-labels'
+        )
+        result = nextvit_image_classification('data/test/images/bird.JPEG')
+        print(result)
+
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_Dailylife_default(self):
         general_image_classification = pipeline(Tasks.image_classification)


tests/trainers/test_general_image_classification_trainer.py (+96, -0)

@@ -0,0 +1,96 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile
from functools import partial

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.test_utils import test_level


class TestGeneralImageClassificationTestTrainer(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

        try:
            self.train_dataset = MsDataset.load(
                'cats_and_dogs',
                namespace='tany0699',
                subset_name='default',
                split='train')

            self.eval_dataset = MsDataset.load(
                'cats_and_dogs',
                namespace='tany0699',
                subset_name='default',
                split='validation')
        except Exception as e:
            print(f'Download dataset error: {e}')

        self.max_epochs = 1

        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_nextvit_dailylife_train(self):
        model_id = 'damo/cv_nextvit-small_image-classification_Dailylife-labels'

        def cfg_modify_fn(cfg):
            cfg.train.dataloader.batch_size_per_gpu = 32
            cfg.train.dataloader.workers_per_gpu = 1
            cfg.train.max_epochs = self.max_epochs
            cfg.model.mm_model.head.num_classes = 2
            cfg.train.optimizer.lr = 1e-4
            cfg.train.lr_config.warmup_iters = 1
            cfg.train.evaluation.metric_options = {'topk': (1, )}
            cfg.evaluation.metric_options = {'topk': (1, )}
            return cfg

        kwargs = dict(
            model=model_id,
            work_dir=self.tmp_dir,
            train_dataset=self.train_dataset,
            eval_dataset=self.eval_dataset,
            cfg_modify_fn=cfg_modify_fn)

        trainer = build_trainer(
            name=Trainers.image_classification, default_args=kwargs)
        trainer.train()

        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(self.max_epochs):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_nextvit_dailylife_eval(self):
        model_id = 'damo/cv_nextvit-small_image-classification_Dailylife-labels'

        kwargs = dict(
            model=model_id,
            work_dir=self.tmp_dir,
            train_dataset=None,
            eval_dataset=self.eval_dataset)

        trainer = build_trainer(
            name=Trainers.image_classification, default_args=kwargs)
        result = trainer.evaluate()
        print(result)


if __name__ == '__main__':
    unittest.main()
