diff --git a/.gitattributes b/.gitattributes index 0d4c368e..88ef2f44 100644 --- a/.gitattributes +++ b/.gitattributes @@ -3,3 +3,4 @@ *.mp4 filter=lfs diff=lfs merge=lfs -text *.wav filter=lfs diff=lfs merge=lfs -text *.JPEG filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text diff --git a/data/test/images/image_wolf.jpeg b/data/test/images/image_wolf.jpeg new file mode 100644 index 00000000..32d0c567 --- /dev/null +++ b/data/test/images/image_wolf.jpeg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cbe3c719d25c2c90349c3c280e74f46f315a490443655ceba8b8a203af0f7259 +size 171378 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index b32fed0d..7cb064ae 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -102,6 +102,7 @@ class Pipelines(object): image_portrait_enhancement = 'gpen-image-portrait-enhancement' image_to_image_generation = 'image-to-image-generation' skin_retouching = 'unet-skin-retouching' + tinynas_classification = 'tinynas-classification' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/tinynas_classfication/__init__.py b/modelscope/models/cv/tinynas_classfication/__init__.py new file mode 100644 index 00000000..6c2f89ee --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .model_zoo import get_zennet + +else: + _import_structure = { + 'model_zoo': ['get_zennet'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/tinynas_classfication/basic_blocks.py b/modelscope/models/cv/tinynas_classfication/basic_blocks.py new file mode 100644 index 00000000..50548dcc --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/basic_blocks.py @@ -0,0 +1,1309 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from .global_utils import (create_netblock_list_from_str_inner, + get_right_parentheses_index) + + +class PlainNetBasicBlockClass(nn.Module): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=1, + no_create=False, + block_name=None, + **kwargs): + super(PlainNetBasicBlockClass, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.no_create = no_create + self.block_name = block_name + if self.block_name is None: + self.block_name = 'uuid{}'.format(uuid.uuid4().hex) + + def forward(self, x): + raise RuntimeError('Not implemented') + + def __str__(self): + return type(self).__name__ + '({},{},{})'.format( + self.in_channels, self.out_channels, self.stride) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride) + + def get_output_resolution(self, input_resolution): + raise RuntimeError('Not implemented') + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert PlainNetBasicBlockClass.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + @classmethod + def is_instance_from_str(cls, s): + if s.startswith(cls.__name__ + '(') and s[-1] == ')': + return True + else: + return False + + +class AdaptiveAvgPool(PlainNetBasicBlockClass): + + def __init__(self, out_channels, output_size, no_create=False, **kwargs): + super(AdaptiveAvgPool, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.output_size = output_size + self.no_create = no_create + if not no_create: + self.netblock = nn.AdaptiveAvgPool2d( + output_size=(self.output_size, self.output_size)) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return type(self).__name__ + '({},{})'.format( + self.out_channels // self.output_size**2, self.output_size) + + def __repr__(self): + return type(self).__name__ + '({}|{},{})'.format( + self.block_name, self.out_channels // self.output_size**2, + self.output_size) + + def get_output_resolution(self, input_resolution): + return self.output_size + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert AdaptiveAvgPool.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('AdaptiveAvgPool('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + out_channels = int(param_str_split[0]) + output_size = int(param_str_split[1]) + return AdaptiveAvgPool( + out_channels=out_channels, + output_size=output_size, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class BN(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + copy_from=None, + no_create=False, + **kwargs): + super(BN, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.BatchNorm2d) + self.in_channels = copy_from.weight.shape[0] + self.out_channels = copy_from.weight.shape[0] + assert out_channels is None or out_channels == self.out_channels + self.netblock = copy_from + + else: + self.in_channels = out_channels + self.out_channels = out_channels + if no_create: + return + else: + self.netblock = nn.BatchNorm2d(num_features=self.out_channels) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'BN({})'.format(self.out_channels) + + def __repr__(self): + return 'BN({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert BN.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('BN('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + out_channels = int(param_str) + return BN( + out_channels=out_channels, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class ConvKX(PlainNetBasicBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + groups=1, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKX, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Conv2d) + self.in_channels = copy_from.in_channels + self.out_channels = copy_from.out_channels + self.kernel_size = copy_from.kernel_size[0] + self.stride = copy_from.stride[0] + self.groups = copy_from.groups + assert in_channels is None or in_channels == self.in_channels + assert out_channels is None or out_channels == self.out_channels + assert kernel_size is None or kernel_size == self.kernel_size + assert stride is None or stride == self.stride + self.netblock = copy_from + else: + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.groups = groups + self.kernel_size = kernel_size + self.padding = (self.kernel_size - 1) // 2 + if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \ + or self.stride == 0: + return + else: + self.netblock = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + bias=False, + groups=self.groups) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format( + self.in_channels, self.out_channels, self.kernel_size, self.stride) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{},{})'.format( + self.block_name, self.in_channels, self.out_channels, + self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + in_channels = int(split_str[0]) + out_channels = int(split_str[1]) + kernel_size = int(split_str[2]) + stride = int(split_str[3]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ConvDW(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvDW, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Conv2d) + self.in_channels = copy_from.in_channels + self.out_channels = copy_from.out_channels + self.kernel_size = copy_from.kernel_size[0] + self.stride = copy_from.stride[0] + assert self.in_channels == self.out_channels + assert out_channels is None or out_channels == self.out_channels + assert kernel_size is None or kernel_size == self.kernel_size + assert stride is None or stride == self.stride + + self.netblock = copy_from + else: + + self.in_channels = out_channels + self.out_channels = out_channels + self.stride = stride + self.kernel_size = kernel_size + + self.padding = (self.kernel_size - 1) // 2 + if no_create or self.in_channels == 0 or self.out_channels == 0 or self.kernel_size == 0 \ + or self.stride == 0: + return + else: + self.netblock = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + bias=False, + groups=self.in_channels) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'ConvDW({},{},{})'.format(self.out_channels, self.kernel_size, + self.stride) + + def __repr__(self): + return 'ConvDW({}|{},{},{})'.format(self.block_name, self.out_channels, + self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert ConvDW.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('ConvDW('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + out_channels = int(split_str[0]) + kernel_size = int(split_str[1]) + stride = int(split_str[2]) + return ConvDW( + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ConvKXG2(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG2, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=2, + **kwargs) + + +class ConvKXG4(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG4, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=4, + **kwargs) + + +class ConvKXG8(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG8, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=8, + **kwargs) + + +class ConvKXG16(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG16, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=16, + **kwargs) + + +class ConvKXG32(ConvKX): + + def __init__(self, + in_channels=None, + out_channels=None, + kernel_size=None, + stride=None, + copy_from=None, + no_create=False, + **kwargs): + super(ConvKXG32, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + copy_from=copy_from, + no_create=no_create, + groups=32, + **kwargs) + + +class Flatten(PlainNetBasicBlockClass): + + def __init__(self, out_channels, no_create=False, **kwargs): + super(Flatten, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.no_create = no_create + + def forward(self, x): + return torch.flatten(x, 1) + + def __str__(self): + return 'Flatten({})'.format(self.out_channels) + + def __repr__(self): + return 'Flatten({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return 1 + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Flatten.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('Flatten('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return Flatten( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class Linear(PlainNetBasicBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + bias=True, + copy_from=None, + no_create=False, + **kwargs): + super(Linear, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + assert isinstance(copy_from, nn.Linear) + self.in_channels = copy_from.weight.shape[1] + self.out_channels = copy_from.weight.shape[0] + self.use_bias = copy_from.bias is not None + assert in_channels is None or in_channels == self.in_channels + assert out_channels is None or out_channels == self.out_channels + + self.netblock = copy_from + else: + + self.in_channels = in_channels + self.out_channels = out_channels + self.use_bias = bias + if not no_create: + self.netblock = nn.Linear( + self.in_channels, self.out_channels, bias=self.use_bias) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'Linear({},{},{})'.format(self.in_channels, self.out_channels, + int(self.use_bias)) + + def __repr__(self): + return 'Linear({}|{},{},{})'.format(self.block_name, self.in_channels, + self.out_channels, + int(self.use_bias)) + + def get_output_resolution(self, input_resolution): + assert input_resolution == 1 + return 1 + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Linear.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('Linear('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + split_str = param_str.split(',') + in_channels = int(split_str[0]) + out_channels = int(split_str[1]) + use_bias = int(split_str[2]) + + return Linear( + in_channels=in_channels, + out_channels=out_channels, + bias=use_bias == 1, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class MaxPool(PlainNetBasicBlockClass): + + def __init__(self, + out_channels, + kernel_size, + stride, + no_create=False, + **kwargs): + super(MaxPool, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = (kernel_size - 1) // 2 + self.no_create = no_create + if not no_create: + self.netblock = nn.MaxPool2d( + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding) + + def forward(self, x): + return self.netblock(x) + + def __str__(self): + return 'MaxPool({},{},{})'.format(self.out_channels, self.kernel_size, + self.stride) + + def __repr__(self): + return 'MaxPool({}|{},{},{})'.format(self.block_name, + self.out_channels, + self.kernel_size, self.stride) + + def get_output_resolution(self, input_resolution): + return input_resolution // self.stride + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MaxPool.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('MaxPool('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + out_channels = int(param_str_split[0]) + kernel_size = int(param_str_split[1]) + stride = int(param_str_split[2]) + return MaxPool( + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class Sequential(PlainNetBasicBlockClass): + + def __init__(self, block_list, no_create=False, **kwargs): + super(Sequential, self).__init__(**kwargs) + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + self.in_channels = block_list[0].in_channels + self.out_channels = block_list[-1].out_channels + self.no_create = no_create + res = 1024 + for block in self.block_list: + res = block.get_output_resolution(res) + self.stride = 1024 // res + + def forward(self, x): + output = x + for inner_block in self.block_list: + output = inner_block(output) + return output + + def __str__(self): + s = 'Sequential(' + for inner_block in self.block_list: + s += str(inner_block) + s += ')' + return s + + def __repr__(self): + return str(self) + + def get_output_resolution(self, input_resolution): + the_res = input_resolution + for the_block in self.block_list: + the_res = the_block.get_output_resolution(the_res) + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Sequential.is_instance_from_str(s) + the_right_paraen_idx = get_right_parentheses_index(s) + param_str = s[len('Sequential(') + 1:the_right_paraen_idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, netblocks_dict=bottom_basic_dict, no_create=no_create) + assert len(remaining_s) == 0 + if the_block_list is None or len(the_block_list) == 0: + return None, '' + return Sequential( + block_list=the_block_list, + no_create=no_create, + block_name=tmp_block_name), '' + + +class MultiSumBlock(PlainNetBasicBlockClass): + + def __init__(self, block_list, no_create=False, **kwargs): + super(MultiSumBlock, self).__init__(**kwargs) + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + self.in_channels = np.max([x.in_channels for x in block_list]) + self.out_channels = np.max([x.out_channels for x in block_list]) + self.no_create = no_create + + res = 1024 + res = self.block_list[0].get_output_resolution(res) + self.stride = 1024 // res + + def forward(self, x): + output = self.block_list[0](x) + for inner_block in self.block_list[1:]: + output2 = inner_block(x) + output = output + output2 + return output + + def __str__(self): + s = 'MultiSumBlock({}|'.format(self.block_name) + for inner_block in self.block_list: + s += str(inner_block) + ';' + s = s[:-1] + s += ')' + return s + + def __repr__(self): + return str(self) + + def get_output_resolution(self, input_resolution): + the_res = self.block_list[0].get_output_resolution(input_resolution) + for the_block in self.block_list: + assert the_res == the_block.get_output_resolution(input_resolution) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MultiSumBlock.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('MultiSumBlock('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + the_s = param_str + + the_block_list = [] + while len(the_s) > 0: + tmp_block_list, remaining_s = create_netblock_list_from_str_inner( + the_s, netblocks_dict=bottom_basic_dict, no_create=no_create) + the_s = remaining_s + if tmp_block_list is None: + pass + elif len(tmp_block_list) == 1: + the_block_list.append(tmp_block_list[0]) + else: + the_block_list.append( + Sequential(block_list=tmp_block_list, no_create=no_create)) + pass + + if len(the_block_list) == 0: + return None, s[idx + 1:] + + return MultiSumBlock( + block_list=the_block_list, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class MultiCatBlock(PlainNetBasicBlockClass): + + def __init__(self, block_list, no_create=False, **kwargs): + super(MultiCatBlock, self).__init__(**kwargs) + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + self.in_channels = np.max([x.in_channels for x in block_list]) + self.out_channels = np.sum([x.out_channels for x in block_list]) + self.no_create = no_create + + res = 1024 + res = self.block_list[0].get_output_resolution(res) + self.stride = 1024 // res + + def forward(self, x): + output_list = [] + for inner_block in self.block_list: + output = inner_block(x) + output_list.append(output) + + return torch.cat(output_list, dim=1) + + def __str__(self): + s = 'MultiCatBlock({}|'.format(self.block_name) + for inner_block in self.block_list: + s += str(inner_block) + ';' + + s = s[:-1] + s += ')' + return s + + def __repr__(self): + return str(self) + + def get_output_resolution(self, input_resolution): + the_res = self.block_list[0].get_output_resolution(input_resolution) + for the_block in self.block_list: + assert the_res == the_block.get_output_resolution(input_resolution) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert MultiCatBlock.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('MultiCatBlock('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + the_s = param_str + + the_block_list = [] + while len(the_s) > 0: + tmp_block_list, remaining_s = create_netblock_list_from_str_inner( + the_s, netblocks_dict=bottom_basic_dict, no_create=no_create) + the_s = remaining_s + if tmp_block_list is None: + pass + elif len(tmp_block_list) == 1: + the_block_list.append(tmp_block_list[0]) + else: + the_block_list.append( + Sequential(block_list=tmp_block_list, no_create=no_create)) + + if len(the_block_list) == 0: + return None, s[idx + 1:] + + return MultiCatBlock( + block_list=the_block_list, + block_name=tmp_block_name, + no_create=no_create), s[idx + 1:] + + +class RELU(PlainNetBasicBlockClass): + + def __init__(self, out_channels, no_create=False, **kwargs): + super(RELU, self).__init__(**kwargs) + self.in_channels = out_channels + self.out_channels = out_channels + self.no_create = no_create + + def forward(self, x): + return F.relu(x) + + def __str__(self): + return 'RELU({})'.format(self.out_channels) + + def __repr__(self): + return 'RELU({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert RELU.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('RELU('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return RELU( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ResBlock(PlainNetBasicBlockClass): + """ + ResBlock(in_channles, inner_blocks_str). If in_channels is missing, use block_list[0].in_channels as in_channels + """ + + def __init__(self, + block_list, + in_channels=None, + stride=None, + no_create=False, + **kwargs): + super(ResBlock, self).__init__(**kwargs) + self.block_list = block_list + self.stride = stride + self.no_create = no_create + if not no_create: + self.module_list = nn.ModuleList(block_list) + + if in_channels is None: + self.in_channels = block_list[0].in_channels + else: + self.in_channels = in_channels + self.out_channels = block_list[-1].out_channels + + if self.stride is None: + tmp_input_res = 1024 + tmp_output_res = self.get_output_resolution(tmp_input_res) + self.stride = tmp_input_res // tmp_output_res + + self.proj = None + if self.stride > 1 or self.in_channels != self.out_channels: + self.proj = nn.Sequential( + nn.Conv2d(self.in_channels, self.out_channels, 1, self.stride), + nn.BatchNorm2d(self.out_channels), + ) + + def forward(self, x): + if len(self.block_list) == 0: + return x + + output = x + for inner_block in self.block_list: + output = inner_block(output) + + if self.proj is not None: + output = output + self.proj(x) + else: + output = output + x + + return output + + def __str__(self): + s = 'ResBlock({},{},'.format(self.in_channels, self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def __repr__(self): + s = 'ResBlock({}|{},{},'.format(self.block_name, self.in_channels, + self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def get_output_resolution(self, input_resolution): + the_res = input_resolution + for the_block in self.block_list: + the_res = the_block.get_output_resolution(the_res) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert ResBlock.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + the_stride = None + param_str = s[len('ResBlock('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + first_comma_index = param_str.find(',') + if first_comma_index < 0 or not param_str[0:first_comma_index].isdigit( + ): + in_channels = None + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + in_channels = int(param_str[0:first_comma_index]) + param_str = param_str[first_comma_index + 1:] + second_comma_index = param_str.find(',') + if second_comma_index < 0 or not param_str[ + 0:second_comma_index].isdigit(): + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + the_stride = int(param_str[0:second_comma_index]) + param_str = param_str[second_comma_index + 1:] + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + pass + pass + + assert len(remaining_s) == 0 + if the_block_list is None or len(the_block_list) == 0: + return None, s[idx + 1:] + return ResBlock( + block_list=the_block_list, + in_channels=in_channels, + stride=the_stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class ResBlockProj(PlainNetBasicBlockClass): + """ + ResBlockProj(in_channles, inner_blocks_str). If in_channels is missing, use block_list[0].in_channels as in_channels + """ + + def __init__(self, + block_list, + in_channels=None, + stride=None, + no_create=False, + **kwargs): + super(ResBlockProj, self).__init__(**kwargs) + self.block_list = block_list + self.stride = stride + self.no_create = no_create + if not no_create: + self.module_list = nn.ModuleList(block_list) + + if in_channels is None: + self.in_channels = block_list[0].in_channels + else: + self.in_channels = in_channels + self.out_channels = block_list[-1].out_channels + + if self.stride is None: + tmp_input_res = 1024 + tmp_output_res = self.get_output_resolution(tmp_input_res) + self.stride = tmp_input_res // tmp_output_res + + self.proj = nn.Sequential( + nn.Conv2d(self.in_channels, self.out_channels, 1, self.stride), + nn.BatchNorm2d(self.out_channels), + ) + + def forward(self, x): + if len(self.block_list) == 0: + return x + + output = x + for inner_block in self.block_list: + output = inner_block(output) + output = output + self.proj(x) + return output + + def __str__(self): + s = 'ResBlockProj({},{},'.format(self.in_channels, self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def __repr__(self): + s = 'ResBlockProj({}|{},{},'.format(self.block_name, self.in_channels, + self.stride) + for inner_block in self.block_list: + s += str(inner_block) + + s += ')' + return s + + def get_output_resolution(self, input_resolution): + the_res = input_resolution + for the_block in self.block_list: + the_res = the_block.get_output_resolution(the_res) + + return the_res + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert ResBlockProj.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + the_stride = None + param_str = s[len('ResBlockProj('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + first_comma_index = param_str.find(',') + if first_comma_index < 0 or not param_str[0:first_comma_index].isdigit( + ): + in_channels = None + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + in_channels = int(param_str[0:first_comma_index]) + param_str = param_str[first_comma_index + 1:] + second_comma_index = param_str.find(',') + if second_comma_index < 0 or not param_str[ + 0:second_comma_index].isdigit(): + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + else: + the_stride = int(param_str[0:second_comma_index]) + param_str = param_str[second_comma_index + 1:] + the_block_list, remaining_s = create_netblock_list_from_str_inner( + param_str, + netblocks_dict=bottom_basic_dict, + no_create=no_create) + pass + pass + + assert len(remaining_s) == 0 + if the_block_list is None or len(the_block_list) == 0: + return None, s[idx + 1:] + return ResBlockProj( + block_list=the_block_list, + in_channels=in_channels, + stride=the_stride, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class SE(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + copy_from=None, + no_create=False, + **kwargs): + super(SE, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + raise RuntimeError('Not implemented') + else: + self.in_channels = out_channels + self.out_channels = out_channels + self.se_ratio = 0.25 + self.se_channels = max( + 1, int(round(self.out_channels * self.se_ratio))) + if no_create or self.out_channels == 0: + return + else: + self.netblock = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.se_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), nn.BatchNorm2d(self.se_channels), + nn.ReLU(), + nn.Conv2d( + in_channels=self.se_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), nn.BatchNorm2d(self.out_channels), + nn.Sigmoid()) + + def forward(self, x): + se_x = self.netblock(x) + return se_x * x + + def __str__(self): + return 'SE({})'.format(self.out_channels) + + def __repr__(self): + return 'SE({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert SE.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('SE('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return SE( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +class SwishImplementation(torch.autograd.Function): + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class Swish(PlainNetBasicBlockClass): + + def __init__(self, + out_channels=None, + copy_from=None, + no_create=False, + **kwargs): + super(Swish, self).__init__(**kwargs) + self.no_create = no_create + + if copy_from is not None: + raise RuntimeError('Not implemented') + else: + self.in_channels = out_channels + self.out_channels = out_channels + + def forward(self, x): + return SwishImplementation.apply(x) + + def __str__(self): + return 'Swish({})'.format(self.out_channels) + + def __repr__(self): + return 'Swish({}|{})'.format(self.block_name, self.out_channels) + + def get_output_resolution(self, input_resolution): + return input_resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert Swish.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len('Swish('):idx] + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + out_channels = int(param_str) + return Swish( + out_channels=out_channels, + no_create=no_create, + block_name=tmp_block_name), s[idx + 1:] + + +bottom_basic_dict = { + 'AdaptiveAvgPool': AdaptiveAvgPool, + 'BN': BN, + 'ConvDW': ConvDW, + 'ConvKX': ConvKX, + 'ConvKXG2': ConvKXG2, + 'ConvKXG4': ConvKXG4, + 'ConvKXG8': ConvKXG8, + 'ConvKXG16': ConvKXG16, + 'ConvKXG32': ConvKXG32, + 'Flatten': Flatten, + 'Linear': Linear, + 'MaxPool': MaxPool, + 'PlainNetBasicBlockClass': PlainNetBasicBlockClass, + 'RELU': RELU, + 'SE': SE, + 'Swish': Swish, +} + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'MultiSumBlock': MultiSumBlock, + 'MultiCatBlock': MultiCatBlock, + 'ResBlock': ResBlock, + 'ResBlockProj': ResBlockProj, + 'Sequential': Sequential, + } + netblocks_dict.update(this_py_file_netblocks_dict) + netblocks_dict.update(bottom_basic_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/global_utils.py b/modelscope/models/cv/tinynas_classfication/global_utils.py new file mode 100644 index 00000000..022c61a0 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/global_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + + +def smart_round(x, base=None): + if base is None: + if x > 32 * 8: + round_base = 32 + elif x > 16 * 8: + round_base = 16 + else: + round_base = 8 + else: + round_base = base + + return max(round_base, round(x / float(round_base)) * round_base) + + +def get_right_parentheses_index(s): + left_paren_count = 0 + for index, x in enumerate(s): + + if x == '(': + left_paren_count += 1 + elif x == ')': + left_paren_count -= 1 + if left_paren_count == 0: + return index + else: + pass + return None + + +def create_netblock_list_from_str_inner(s, + no_create=False, + netblocks_dict=None, + **kwargs): + block_list = [] + while len(s) > 0: + is_found_block_class = False + for the_block_class_name in netblocks_dict.keys(): + tmp_idx = s.find('(') + if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: + is_found_block_class = True + the_block_class = netblocks_dict[the_block_class_name] + the_block, remaining_s = the_block_class.create_from_str( + s, no_create=no_create, **kwargs) + if the_block is not None: + block_list.append(the_block) + s = remaining_s + if len(s) > 0 and s[0] == ';': + return block_list, s[1:] + break + assert is_found_block_class + return block_list, '' + + +def create_netblock_list_from_str(s, + no_create=False, + netblocks_dict=None, + **kwargs): + the_list, remaining_s = create_netblock_list_from_str_inner( + s, no_create=no_create, netblocks_dict=netblocks_dict, **kwargs) + assert len(remaining_s) == 0 + return the_list diff --git a/modelscope/models/cv/tinynas_classfication/master_net.py b/modelscope/models/cv/tinynas_classfication/master_net.py new file mode 100644 index 00000000..e2bc47e0 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/master_net.py @@ -0,0 +1,94 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import torch +import torch.nn.functional as F +from torch import nn + +from . import basic_blocks, plain_net_utils + + +class PlainNet(plain_net_utils.PlainNet): + + def __init__(self, + argv=None, + opt=None, + num_classes=None, + plainnet_struct=None, + no_create=False, + no_reslink=None, + no_BN=None, + use_se=None, + dropout=None, + **kwargs): + + module_opt = None + + if no_BN is None: + if module_opt is not None: + no_BN = module_opt.no_BN + else: + no_BN = False + + if no_reslink is None: + if module_opt is not None: + no_reslink = module_opt.no_reslink + else: + no_reslink = False + + if use_se is None: + if module_opt is not None: + use_se = module_opt.use_se + else: + use_se = False + + if dropout is None: + if module_opt is not None: + self.dropout = module_opt.dropout + else: + self.dropout = None + else: + self.dropout = dropout + + super(PlainNet, self).__init__( + argv=argv, + opt=opt, + num_classes=num_classes, + plainnet_struct=plainnet_struct, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + use_se=use_se, + **kwargs) + self.last_channels = self.block_list[-1].out_channels + self.fc_linear = basic_blocks.Linear( + in_channels=self.last_channels, + out_channels=self.num_classes, + no_create=no_create) + + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + + for layer in self.modules(): + if isinstance(layer, nn.BatchNorm2d): + layer.eps = 1e-3 + + def forward(self, x): + output = x + for block_id, the_block in enumerate(self.block_list): + output = the_block(output) + if self.dropout is not None: + dropout_p = float(block_id) / len( + self.block_list) * self.dropout + output = F.dropout( + output, dropout_p, training=self.training, inplace=True) + + output = F.adaptive_avg_pool2d(output, output_size=1) + if self.dropout is not None: + output = F.dropout( + output, self.dropout, training=self.training, inplace=True) + output = torch.flatten(output, 1) + output = self.fc_linear(output) + return output diff --git a/modelscope/models/cv/tinynas_classfication/model_zoo.py b/modelscope/models/cv/tinynas_classfication/model_zoo.py new file mode 100644 index 00000000..a49b053b --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/model_zoo.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +from . import master_net + + +def get_zennet(): + model_plainnet_str = ( + 'SuperConvK3BNRELU(3,32,2,1)' + 'SuperResK1K5K1(32,80,2,32,1)SuperResK1K7K1(80,432,2,128,5)' + 'SuperResK1K7K1(432,640,2,192,3)SuperResK1K7K1(640,1008,1,160,5)' + 'SuperResK1K7K1(1008,976,1,160,4)SuperResK1K5K1(976,2304,2,384,5)' + 'SuperResK1K5K1(2304,2496,1,384,5)SuperConvK1BNRELU(2496,3072,1,1)') + use_SE = False + num_classes = 1000 + + model = master_net.PlainNet( + num_classes=num_classes, + plainnet_struct=model_plainnet_str, + use_se=use_SE) + + return model diff --git a/modelscope/models/cv/tinynas_classfication/plain_net_utils.py b/modelscope/models/cv/tinynas_classfication/plain_net_utils.py new file mode 100644 index 00000000..844535ed --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/plain_net_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +from torch import nn + +from . import (basic_blocks, super_blocks, super_res_idwexkx, super_res_k1kxk1, + super_res_kxkx) +from .global_utils import create_netblock_list_from_str_inner + + +class PlainNet(nn.Module): + + def __init__(self, + argv=None, + opt=None, + num_classes=None, + plainnet_struct=None, + no_create=False, + **kwargs): + super(PlainNet, self).__init__() + self.argv = argv + self.opt = opt + self.num_classes = num_classes + self.plainnet_struct = plainnet_struct + + self.module_opt = None + + if self.num_classes is None: + self.num_classes = self.module_opt.num_classes + + if self.plainnet_struct is None and self.module_opt.plainnet_struct is not None: + self.plainnet_struct = self.module_opt.plainnet_struct + + if self.plainnet_struct is None: + if hasattr(opt, 'plainnet_struct_txt' + ) and opt.plainnet_struct_txt is not None: + plainnet_struct_txt = opt.plainnet_struct_txt + else: + plainnet_struct_txt = self.module_opt.plainnet_struct_txt + + if plainnet_struct_txt is not None: + with open(plainnet_struct_txt, 'r') as fid: + the_line = fid.readlines()[0].strip() + self.plainnet_struct = the_line + pass + + if self.plainnet_struct is None: + return + + the_s = self.plainnet_struct + + block_list, remaining_s = create_netblock_list_from_str_inner( + the_s, + netblocks_dict=_all_netblocks_dict_, + no_create=no_create, + **kwargs) + assert len(remaining_s) == 0 + + self.block_list = block_list + if not no_create: + self.module_list = nn.ModuleList(block_list) + + def forward(self, x): + output = x + for the_block in self.block_list: + output = the_block(output) + return output + + def __str__(self): + s = '' + for the_block in self.block_list: + s += str(the_block) + return s + + def __repr__(self): + return str(self) + + +_all_netblocks_dict_ = {} +_all_netblocks_dict_ = basic_blocks.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_blocks.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_res_kxkx.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_res_k1kxk1.register_netblocks_dict( + _all_netblocks_dict_) +_all_netblocks_dict_ = super_res_idwexkx.register_netblocks_dict( + _all_netblocks_dict_) diff --git a/modelscope/models/cv/tinynas_classfication/super_blocks.py b/modelscope/models/cv/tinynas_classfication/super_blocks.py new file mode 100644 index 00000000..25862255 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_blocks.py @@ -0,0 +1,228 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index + + +class PlainNetSuperBlockClass(basic_blocks.PlainNetBasicBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(PlainNetSuperBlockClass, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.sub_layers = sub_layers + self.no_create = no_create + self.block_list = None + self.module_list = None + + def forward(self, x): + output = x + for block in self.block_list: + output = block(output) + return output + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, self.sub_layers) + + def __repr__(self): + return type(self).__name__ + '({}|{},{},{},{})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.sub_layers) + + def get_output_resolution(self, input_resolution): + resolution = input_resolution + for block in self.block_list: + resolution = block.get_output_resolution(resolution) + return resolution + + @classmethod + def create_from_str(cls, s, no_create=False, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + sub_layers = int(param_str_split[3]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + block_name=tmp_block_name, + no_create=no_create, + **kwargs), s[idx + 1:] + + +class SuperConvKXBNRELU(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + kernel_size=None, + no_create=False, + no_reslink=False, + no_BN=False, + **kwargs): + super(SuperConvKXBNRELU, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + if not self.no_BN: + inner_str = 'ConvKX({},{},{},{})BN({})RELU({})'.format( + last_channels, self.out_channels, self.kernel_size, + current_stride, self.out_channels, self.out_channels) + else: + inner_str = 'ConvKX({},{},{},{})RELU({})'.format( + last_channels, self.out_channels, self.kernel_size, + current_stride, self.out_channels) + full_str += inner_str + + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + no_create=no_create, + netblocks_dict=netblocks_dict, + no_reslink=no_reslink, + no_BN=no_BN) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.sub_layers, self.kernel_size) + + +class SuperConvK1BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK1BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=1, + no_create=no_create, + **kwargs) + + +class SuperConvK3BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK3BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, + **kwargs) + + +class SuperConvK5BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK5BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, + **kwargs) + + +class SuperConvK7BNRELU(SuperConvKXBNRELU): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperConvK7BNRELU, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperConvK1BNRELU': SuperConvK1BNRELU, + 'SuperConvK3BNRELU': SuperConvK3BNRELU, + 'SuperConvK5BNRELU': SuperConvK5BNRELU, + 'SuperConvK7BNRELU': SuperConvK7BNRELU, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/super_res_idwexkx.py b/modelscope/models/cv/tinynas_classfication/super_res_idwexkx.py new file mode 100644 index 00000000..7d005069 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_res_idwexkx.py @@ -0,0 +1,451 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index +from .super_blocks import PlainNetSuperBlockClass + + +class SuperResIDWEXKX(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + kernel_size=None, + expension=None, + no_create=False, + no_reslink=False, + no_BN=False, + use_se=False, + **kwargs): + super(SuperResIDWEXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.expension = expension + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + + self.use_se = use_se + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + dw_channels = global_utils.smart_round( + self.bottleneck_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, + dw_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, + self.kernel_size, + current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, + bottleneck_channels, 1, + 1) + if not self.no_BN: + inner_str += 'BN({})'.format(bottleneck_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format( + inner_str, self.out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, self.out_channels) + + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + full_str += res_str + + inner_str = '' + dw_channels = global_utils.smart_round( + self.out_channels * self.expension, base=8) + inner_str += 'ConvKX({},{},{},{})'.format(bottleneck_channels, + dw_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + + inner_str += 'ConvDW({},{},{})'.format(dw_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(dw_channels) + inner_str += 'RELU({})'.format(dw_channels) + if self.use_se: + inner_str += 'SE({})'.format(dw_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, + self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, self.out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, self.out_channels) + + full_str += res_str + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + netblocks_dict=netblocks_dict, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers, self.kernel_size) + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + block_name=tmp_block_name, + **kwargs), s[idx + 1:] + + +class SuperResIDWE1K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE1K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=1.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE2K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE2K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=2.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE4K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE4K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=4.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE6K3(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE6K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + expension=6.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE1K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE1K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=1.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE2K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE2K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=2.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE4K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE4K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=4.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE6K5(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE6K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + expension=6.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE1K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE1K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=1.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE2K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE2K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=2.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE4K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE4K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=4.0, + no_create=no_create, + **kwargs) + + +class SuperResIDWE6K7(SuperResIDWEXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResIDWE6K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + expension=6.0, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResIDWE1K3': SuperResIDWE1K3, + 'SuperResIDWE2K3': SuperResIDWE2K3, + 'SuperResIDWE4K3': SuperResIDWE4K3, + 'SuperResIDWE6K3': SuperResIDWE6K3, + 'SuperResIDWE1K5': SuperResIDWE1K5, + 'SuperResIDWE2K5': SuperResIDWE2K5, + 'SuperResIDWE4K5': SuperResIDWE4K5, + 'SuperResIDWE6K5': SuperResIDWE6K5, + 'SuperResIDWE1K7': SuperResIDWE1K7, + 'SuperResIDWE2K7': SuperResIDWE2K7, + 'SuperResIDWE4K7': SuperResIDWE4K7, + 'SuperResIDWE6K7': SuperResIDWE6K7, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/super_res_k1kxk1.py b/modelscope/models/cv/tinynas_classfication/super_res_k1kxk1.py new file mode 100644 index 00000000..3ca68742 --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_res_k1kxk1.py @@ -0,0 +1,238 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index +from .super_blocks import PlainNetSuperBlockClass + + +class SuperResK1KXK1(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + kernel_size=None, + no_create=False, + no_reslink=False, + no_BN=False, + use_se=False, + **kwargs): + super(SuperResK1KXK1, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, + self.bottleneck_channels, + 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.bottleneck_channels, + self.kernel_size, + current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + inner_str = '' + inner_str += 'ConvKX({},{},{},{})'.format(self.out_channels, + self.bottleneck_channels, + 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.bottleneck_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.out_channels, 1, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + netblocks_dict=netblocks_dict, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers, self.kernel_size) + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + block_name=tmp_block_name, + **kwargs), s[idx + 1:] + + +class SuperResK1K3K1(SuperResK1KXK1): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK1K3K1, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, + **kwargs) + + +class SuperResK1K5K1(SuperResK1KXK1): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK1K5K1, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, + **kwargs) + + +class SuperResK1K7K1(SuperResK1KXK1): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK1K7K1, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResK1K3K1': SuperResK1K3K1, + 'SuperResK1K5K1': SuperResK1K5K1, + 'SuperResK1K7K1': SuperResK1K7K1, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/models/cv/tinynas_classfication/super_res_kxkx.py b/modelscope/models/cv/tinynas_classfication/super_res_kxkx.py new file mode 100644 index 00000000..a694fdbe --- /dev/null +++ b/modelscope/models/cv/tinynas_classfication/super_res_kxkx.py @@ -0,0 +1,202 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. + +import uuid + +from torch import nn + +from . import basic_blocks, global_utils +from .global_utils import get_right_parentheses_index +from .super_blocks import PlainNetSuperBlockClass + + +class SuperResKXKX(PlainNetSuperBlockClass): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + kernel_size=None, + no_create=False, + no_reslink=False, + no_BN=False, + use_se=False, + **kwargs): + super(SuperResKXKX, self).__init__(**kwargs) + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.bottleneck_channels = bottleneck_channels + self.sub_layers = sub_layers + self.kernel_size = kernel_size + self.no_create = no_create + self.no_reslink = no_reslink + self.no_BN = no_BN + self.use_se = use_se + + full_str = '' + last_channels = in_channels + current_stride = stride + for i in range(self.sub_layers): + inner_str = '' + + inner_str += 'ConvKX({},{},{},{})'.format(last_channels, + self.bottleneck_channels, + self.kernel_size, + current_stride) + if not self.no_BN: + inner_str += 'BN({})'.format(self.bottleneck_channels) + inner_str += 'RELU({})'.format(self.bottleneck_channels) + if self.use_se: + inner_str += 'SE({})'.format(bottleneck_channels) + + inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, + self.out_channels, + self.kernel_size, 1) + if not self.no_BN: + inner_str += 'BN({})'.format(self.out_channels) + + if not self.no_reslink: + if i == 0: + res_str = 'ResBlockProj({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = 'ResBlock({})RELU({})'.format( + inner_str, out_channels) + else: + res_str = '{}RELU({})'.format(inner_str, out_channels) + + full_str += res_str + + last_channels = out_channels + current_stride = 1 + pass + + netblocks_dict = basic_blocks.register_netblocks_dict({}) + self.block_list = global_utils.create_netblock_list_from_str( + full_str, + netblocks_dict=netblocks_dict, + no_create=no_create, + no_reslink=no_reslink, + no_BN=no_BN, + **kwargs) + if not no_create: + self.module_list = nn.ModuleList(self.block_list) + else: + self.module_list = None + + def __str__(self): + return type(self).__name__ + '({},{},{},{},{})'.format( + self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers) + + def __repr__(self): + return type( + self + ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( + self.block_name, self.in_channels, self.out_channels, self.stride, + self.bottleneck_channels, self.sub_layers, self.kernel_size) + + @classmethod + def create_from_str(cls, s, **kwargs): + assert cls.is_instance_from_str(s) + idx = get_right_parentheses_index(s) + assert idx is not None + param_str = s[len(cls.__name__ + '('):idx] + + tmp_idx = param_str.find('|') + if tmp_idx < 0: + tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) + else: + tmp_block_name = param_str[0:tmp_idx] + param_str = param_str[tmp_idx + 1:] + + param_str_split = param_str.split(',') + in_channels = int(param_str_split[0]) + out_channels = int(param_str_split[1]) + stride = int(param_str_split[2]) + bottleneck_channels = int(param_str_split[3]) + sub_layers = int(param_str_split[4]) + return cls( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + block_name=tmp_block_name, + **kwargs), s[idx + 1:] + + +class SuperResK3K3(SuperResKXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK3K3, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=3, + no_create=no_create, + **kwargs) + + +class SuperResK5K5(SuperResKXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK5K5, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=5, + no_create=no_create, + **kwargs) + + +class SuperResK7K7(SuperResKXKX): + + def __init__(self, + in_channels=None, + out_channels=None, + stride=None, + bottleneck_channels=None, + sub_layers=None, + no_create=False, + **kwargs): + super(SuperResK7K7, self).__init__( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + bottleneck_channels=bottleneck_channels, + sub_layers=sub_layers, + kernel_size=7, + no_create=no_create, + **kwargs) + + +def register_netblocks_dict(netblocks_dict: dict): + this_py_file_netblocks_dict = { + 'SuperResK3K3': SuperResK3K3, + 'SuperResK5K5': SuperResK5K5, + 'SuperResK7K7': SuperResK7K7, + } + netblocks_dict.update(this_py_file_netblocks_dict) + return netblocks_dict diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 6027923e..76d0d575 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from .live_category_pipeline import LiveCategoryPipeline from .ocr_detection_pipeline import OCRDetectionPipeline from .skin_retouching_pipeline import SkinRetouchingPipeline + from .tinynas_classification_pipeline import TinynasClassificationPipeline from .video_category_pipeline import VideoCategoryPipeline from .virtual_try_on_pipeline import VirtualTryonPipeline else: @@ -65,6 +66,7 @@ else: ['Image2ImageGenerationPipeline'], 'ocr_detection_pipeline': ['OCRDetectionPipeline'], 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], + 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], 'video_category_pipeline': ['VideoCategoryPipeline'], 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], } diff --git a/modelscope/pipelines/cv/tinynas_classification_pipeline.py b/modelscope/pipelines/cv/tinynas_classification_pipeline.py new file mode 100644 index 00000000..d49166d1 --- /dev/null +++ b/modelscope/pipelines/cv/tinynas_classification_pipeline.py @@ -0,0 +1,96 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import math +import os.path as osp +from typing import Any, Dict + +import torch +from torchvision import transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.tinynas_classfication import get_zennet +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_classification, module_name=Pipelines.tinynas_classification) +class TinynasClassificationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a tinynas classification pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + self.path = model + self.model = get_zennet() + + model_pth_path = osp.join(self.path, ModelFile.TORCH_MODEL_FILE) + + checkpoint = torch.load(model_pth_path, map_location='cpu') + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + self.model.load_state_dict(state_dict, strict=True) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + img = LoadImage.convert_to_img(input) + + input_image_size = 224 + crop_image_size = 380 + input_image_crop = 0.875 + resize_image_size = int(math.ceil(crop_image_size / input_image_crop)) + transforms_normalize = transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + transform_list = [ + transforms.Resize( + resize_image_size, + interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(crop_image_size), + transforms.ToTensor(), transforms_normalize + ] + transformer = transforms.Compose(transform_list) + + img = transformer(img) + img = torch.unsqueeze(img, 0) + img = torch.nn.functional.interpolate( + img, input_image_size, mode='bilinear') + result = {'img': img} + + return result + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + is_train = False + if is_train: + self.model.train() + else: + self.model.eval() + + outputs = self.model(input['img']) + return {'outputs': outputs} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + label_mapping_path = osp.join(self.path, 'label_map.txt') + f = open(label_mapping_path) + content = f.read() + f.close() + label_dict = eval(content) + + output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1) + score = torch.max(output_prob) + output_dict = { + OutputKeys.SCORES: score.item(), + OutputKeys.LABELS: label_dict[inputs['outputs'].argmax().item()] + } + return output_dict diff --git a/tests/pipelines/test_tinynas_classification.py b/tests/pipelines/test_tinynas_classification.py new file mode 100644 index 00000000..d64b5bc0 --- /dev/null +++ b/tests/pipelines/test_tinynas_classification.py @@ -0,0 +1,19 @@ +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class TinyNASClassificationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run(self): + tinynas_classification = pipeline( + Tasks.image_classification, model='damo/cv_tinynas_classification') + result = tinynas_classification('data/test/images/image_wolf.jpeg') + print(result) + + +if __name__ == '__main__': + unittest.main()