由于 在 系列方法 命名上有调整,将 lightnas 统一改为 tinynas 之前有一次提交,对应的 code review 为:https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9578861 现将 上面的的意见已统一修复,并重新 切换到 cv/tinynas/classification 分支上 进行重新 提交审核 之前 cv/lightnas/classification 已废弃 请悉知~ Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9663300master
| @@ -3,3 +3,4 @@ | |||
| *.mp4 filter=lfs diff=lfs merge=lfs -text | |||
| *.wav filter=lfs diff=lfs merge=lfs -text | |||
| *.JPEG filter=lfs diff=lfs merge=lfs -text | |||
| *.jpeg filter=lfs diff=lfs merge=lfs -text | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:cbe3c719d25c2c90349c3c280e74f46f315a490443655ceba8b8a203af0f7259 | |||
| size 171378 | |||
| @@ -102,6 +102,7 @@ class Pipelines(object): | |||
| image_portrait_enhancement = 'gpen-image-portrait-enhancement' | |||
| image_to_image_generation = 'image-to-image-generation' | |||
| skin_retouching = 'unet-skin-retouching' | |||
| tinynas_classification = 'tinynas-classification' | |||
| # nlp tasks | |||
| sentence_similarity = 'sentence-similarity' | |||
| @@ -0,0 +1,24 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .model_zoo import get_zennet | |||
| else: | |||
| _import_structure = { | |||
| 'model_zoo': ['get_zennet'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| def smart_round(x, base=None): | |||
| if base is None: | |||
| if x > 32 * 8: | |||
| round_base = 32 | |||
| elif x > 16 * 8: | |||
| round_base = 16 | |||
| else: | |||
| round_base = 8 | |||
| else: | |||
| round_base = base | |||
| return max(round_base, round(x / float(round_base)) * round_base) | |||
| def get_right_parentheses_index(s): | |||
| left_paren_count = 0 | |||
| for index, x in enumerate(s): | |||
| if x == '(': | |||
| left_paren_count += 1 | |||
| elif x == ')': | |||
| left_paren_count -= 1 | |||
| if left_paren_count == 0: | |||
| return index | |||
| else: | |||
| pass | |||
| return None | |||
| def create_netblock_list_from_str_inner(s, | |||
| no_create=False, | |||
| netblocks_dict=None, | |||
| **kwargs): | |||
| block_list = [] | |||
| while len(s) > 0: | |||
| is_found_block_class = False | |||
| for the_block_class_name in netblocks_dict.keys(): | |||
| tmp_idx = s.find('(') | |||
| if tmp_idx > 0 and s[0:tmp_idx] == the_block_class_name: | |||
| is_found_block_class = True | |||
| the_block_class = netblocks_dict[the_block_class_name] | |||
| the_block, remaining_s = the_block_class.create_from_str( | |||
| s, no_create=no_create, **kwargs) | |||
| if the_block is not None: | |||
| block_list.append(the_block) | |||
| s = remaining_s | |||
| if len(s) > 0 and s[0] == ';': | |||
| return block_list, s[1:] | |||
| break | |||
| assert is_found_block_class | |||
| return block_list, '' | |||
| def create_netblock_list_from_str(s, | |||
| no_create=False, | |||
| netblocks_dict=None, | |||
| **kwargs): | |||
| the_list, remaining_s = create_netblock_list_from_str_inner( | |||
| s, no_create=no_create, netblocks_dict=netblocks_dict, **kwargs) | |||
| assert len(remaining_s) == 0 | |||
| return the_list | |||
| @@ -0,0 +1,94 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from torch import nn | |||
| from . import basic_blocks, plain_net_utils | |||
| class PlainNet(plain_net_utils.PlainNet): | |||
| def __init__(self, | |||
| argv=None, | |||
| opt=None, | |||
| num_classes=None, | |||
| plainnet_struct=None, | |||
| no_create=False, | |||
| no_reslink=None, | |||
| no_BN=None, | |||
| use_se=None, | |||
| dropout=None, | |||
| **kwargs): | |||
| module_opt = None | |||
| if no_BN is None: | |||
| if module_opt is not None: | |||
| no_BN = module_opt.no_BN | |||
| else: | |||
| no_BN = False | |||
| if no_reslink is None: | |||
| if module_opt is not None: | |||
| no_reslink = module_opt.no_reslink | |||
| else: | |||
| no_reslink = False | |||
| if use_se is None: | |||
| if module_opt is not None: | |||
| use_se = module_opt.use_se | |||
| else: | |||
| use_se = False | |||
| if dropout is None: | |||
| if module_opt is not None: | |||
| self.dropout = module_opt.dropout | |||
| else: | |||
| self.dropout = None | |||
| else: | |||
| self.dropout = dropout | |||
| super(PlainNet, self).__init__( | |||
| argv=argv, | |||
| opt=opt, | |||
| num_classes=num_classes, | |||
| plainnet_struct=plainnet_struct, | |||
| no_create=no_create, | |||
| no_reslink=no_reslink, | |||
| no_BN=no_BN, | |||
| use_se=use_se, | |||
| **kwargs) | |||
| self.last_channels = self.block_list[-1].out_channels | |||
| self.fc_linear = basic_blocks.Linear( | |||
| in_channels=self.last_channels, | |||
| out_channels=self.num_classes, | |||
| no_create=no_create) | |||
| self.no_create = no_create | |||
| self.no_reslink = no_reslink | |||
| self.no_BN = no_BN | |||
| self.use_se = use_se | |||
| for layer in self.modules(): | |||
| if isinstance(layer, nn.BatchNorm2d): | |||
| layer.eps = 1e-3 | |||
| def forward(self, x): | |||
| output = x | |||
| for block_id, the_block in enumerate(self.block_list): | |||
| output = the_block(output) | |||
| if self.dropout is not None: | |||
| dropout_p = float(block_id) / len( | |||
| self.block_list) * self.dropout | |||
| output = F.dropout( | |||
| output, dropout_p, training=self.training, inplace=True) | |||
| output = F.adaptive_avg_pool2d(output, output_size=1) | |||
| if self.dropout is not None: | |||
| output = F.dropout( | |||
| output, self.dropout, training=self.training, inplace=True) | |||
| output = torch.flatten(output, 1) | |||
| output = self.fc_linear(output) | |||
| return output | |||
| @@ -0,0 +1,22 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| from . import master_net | |||
| def get_zennet(): | |||
| model_plainnet_str = ( | |||
| 'SuperConvK3BNRELU(3,32,2,1)' | |||
| 'SuperResK1K5K1(32,80,2,32,1)SuperResK1K7K1(80,432,2,128,5)' | |||
| 'SuperResK1K7K1(432,640,2,192,3)SuperResK1K7K1(640,1008,1,160,5)' | |||
| 'SuperResK1K7K1(1008,976,1,160,4)SuperResK1K5K1(976,2304,2,384,5)' | |||
| 'SuperResK1K5K1(2304,2496,1,384,5)SuperConvK1BNRELU(2496,3072,1,1)') | |||
| use_SE = False | |||
| num_classes = 1000 | |||
| model = master_net.PlainNet( | |||
| num_classes=num_classes, | |||
| plainnet_struct=model_plainnet_str, | |||
| use_se=use_SE) | |||
| return model | |||
| @@ -0,0 +1,89 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| from torch import nn | |||
| from . import (basic_blocks, super_blocks, super_res_idwexkx, super_res_k1kxk1, | |||
| super_res_kxkx) | |||
| from .global_utils import create_netblock_list_from_str_inner | |||
| class PlainNet(nn.Module): | |||
| def __init__(self, | |||
| argv=None, | |||
| opt=None, | |||
| num_classes=None, | |||
| plainnet_struct=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(PlainNet, self).__init__() | |||
| self.argv = argv | |||
| self.opt = opt | |||
| self.num_classes = num_classes | |||
| self.plainnet_struct = plainnet_struct | |||
| self.module_opt = None | |||
| if self.num_classes is None: | |||
| self.num_classes = self.module_opt.num_classes | |||
| if self.plainnet_struct is None and self.module_opt.plainnet_struct is not None: | |||
| self.plainnet_struct = self.module_opt.plainnet_struct | |||
| if self.plainnet_struct is None: | |||
| if hasattr(opt, 'plainnet_struct_txt' | |||
| ) and opt.plainnet_struct_txt is not None: | |||
| plainnet_struct_txt = opt.plainnet_struct_txt | |||
| else: | |||
| plainnet_struct_txt = self.module_opt.plainnet_struct_txt | |||
| if plainnet_struct_txt is not None: | |||
| with open(plainnet_struct_txt, 'r') as fid: | |||
| the_line = fid.readlines()[0].strip() | |||
| self.plainnet_struct = the_line | |||
| pass | |||
| if self.plainnet_struct is None: | |||
| return | |||
| the_s = self.plainnet_struct | |||
| block_list, remaining_s = create_netblock_list_from_str_inner( | |||
| the_s, | |||
| netblocks_dict=_all_netblocks_dict_, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| assert len(remaining_s) == 0 | |||
| self.block_list = block_list | |||
| if not no_create: | |||
| self.module_list = nn.ModuleList(block_list) | |||
| def forward(self, x): | |||
| output = x | |||
| for the_block in self.block_list: | |||
| output = the_block(output) | |||
| return output | |||
| def __str__(self): | |||
| s = '' | |||
| for the_block in self.block_list: | |||
| s += str(the_block) | |||
| return s | |||
| def __repr__(self): | |||
| return str(self) | |||
| _all_netblocks_dict_ = {} | |||
| _all_netblocks_dict_ = basic_blocks.register_netblocks_dict( | |||
| _all_netblocks_dict_) | |||
| _all_netblocks_dict_ = super_blocks.register_netblocks_dict( | |||
| _all_netblocks_dict_) | |||
| _all_netblocks_dict_ = super_res_kxkx.register_netblocks_dict( | |||
| _all_netblocks_dict_) | |||
| _all_netblocks_dict_ = super_res_k1kxk1.register_netblocks_dict( | |||
| _all_netblocks_dict_) | |||
| _all_netblocks_dict_ = super_res_idwexkx.register_netblocks_dict( | |||
| _all_netblocks_dict_) | |||
| @@ -0,0 +1,228 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| import uuid | |||
| from torch import nn | |||
| from . import basic_blocks, global_utils | |||
| from .global_utils import get_right_parentheses_index | |||
| class PlainNetSuperBlockClass(basic_blocks.PlainNetBasicBlockClass): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(PlainNetSuperBlockClass, self).__init__() | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.stride = stride | |||
| self.sub_layers = sub_layers | |||
| self.no_create = no_create | |||
| self.block_list = None | |||
| self.module_list = None | |||
| def forward(self, x): | |||
| output = x | |||
| for block in self.block_list: | |||
| output = block(output) | |||
| return output | |||
| def __str__(self): | |||
| return type(self).__name__ + '({},{},{},{})'.format( | |||
| self.in_channels, self.out_channels, self.stride, self.sub_layers) | |||
| def __repr__(self): | |||
| return type(self).__name__ + '({}|{},{},{},{})'.format( | |||
| self.block_name, self.in_channels, self.out_channels, self.stride, | |||
| self.sub_layers) | |||
| def get_output_resolution(self, input_resolution): | |||
| resolution = input_resolution | |||
| for block in self.block_list: | |||
| resolution = block.get_output_resolution(resolution) | |||
| return resolution | |||
| @classmethod | |||
| def create_from_str(cls, s, no_create=False, **kwargs): | |||
| assert cls.is_instance_from_str(s) | |||
| idx = get_right_parentheses_index(s) | |||
| assert idx is not None | |||
| param_str = s[len(cls.__name__ + '('):idx] | |||
| tmp_idx = param_str.find('|') | |||
| if tmp_idx < 0: | |||
| tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) | |||
| else: | |||
| tmp_block_name = param_str[0:tmp_idx] | |||
| param_str = param_str[tmp_idx + 1:] | |||
| param_str_split = param_str.split(',') | |||
| in_channels = int(param_str_split[0]) | |||
| out_channels = int(param_str_split[1]) | |||
| stride = int(param_str_split[2]) | |||
| sub_layers = int(param_str_split[3]) | |||
| return cls( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| sub_layers=sub_layers, | |||
| block_name=tmp_block_name, | |||
| no_create=no_create, | |||
| **kwargs), s[idx + 1:] | |||
| class SuperConvKXBNRELU(PlainNetSuperBlockClass): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| sub_layers=None, | |||
| kernel_size=None, | |||
| no_create=False, | |||
| no_reslink=False, | |||
| no_BN=False, | |||
| **kwargs): | |||
| super(SuperConvKXBNRELU, self).__init__(**kwargs) | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.stride = stride | |||
| self.sub_layers = sub_layers | |||
| self.kernel_size = kernel_size | |||
| self.no_create = no_create | |||
| self.no_reslink = no_reslink | |||
| self.no_BN = no_BN | |||
| full_str = '' | |||
| last_channels = in_channels | |||
| current_stride = stride | |||
| for i in range(self.sub_layers): | |||
| if not self.no_BN: | |||
| inner_str = 'ConvKX({},{},{},{})BN({})RELU({})'.format( | |||
| last_channels, self.out_channels, self.kernel_size, | |||
| current_stride, self.out_channels, self.out_channels) | |||
| else: | |||
| inner_str = 'ConvKX({},{},{},{})RELU({})'.format( | |||
| last_channels, self.out_channels, self.kernel_size, | |||
| current_stride, self.out_channels) | |||
| full_str += inner_str | |||
| last_channels = out_channels | |||
| current_stride = 1 | |||
| pass | |||
| netblocks_dict = basic_blocks.register_netblocks_dict({}) | |||
| self.block_list = global_utils.create_netblock_list_from_str( | |||
| full_str, | |||
| no_create=no_create, | |||
| netblocks_dict=netblocks_dict, | |||
| no_reslink=no_reslink, | |||
| no_BN=no_BN) | |||
| if not no_create: | |||
| self.module_list = nn.ModuleList(self.block_list) | |||
| else: | |||
| self.module_list = None | |||
| def __str__(self): | |||
| return type(self).__name__ + '({},{},{},{})'.format( | |||
| self.in_channels, self.out_channels, self.stride, self.sub_layers) | |||
| def __repr__(self): | |||
| return type( | |||
| self | |||
| ).__name__ + '({}|in={},out={},stride={},sub_layers={},kernel_size={})'.format( | |||
| self.block_name, self.in_channels, self.out_channels, self.stride, | |||
| self.sub_layers, self.kernel_size) | |||
| class SuperConvK1BNRELU(SuperConvKXBNRELU): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperConvK1BNRELU, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| sub_layers=sub_layers, | |||
| kernel_size=1, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperConvK3BNRELU(SuperConvKXBNRELU): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperConvK3BNRELU, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| sub_layers=sub_layers, | |||
| kernel_size=3, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperConvK5BNRELU(SuperConvKXBNRELU): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperConvK5BNRELU, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| sub_layers=sub_layers, | |||
| kernel_size=5, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperConvK7BNRELU(SuperConvKXBNRELU): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperConvK7BNRELU, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| sub_layers=sub_layers, | |||
| kernel_size=7, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| def register_netblocks_dict(netblocks_dict: dict): | |||
| this_py_file_netblocks_dict = { | |||
| 'SuperConvK1BNRELU': SuperConvK1BNRELU, | |||
| 'SuperConvK3BNRELU': SuperConvK3BNRELU, | |||
| 'SuperConvK5BNRELU': SuperConvK5BNRELU, | |||
| 'SuperConvK7BNRELU': SuperConvK7BNRELU, | |||
| } | |||
| netblocks_dict.update(this_py_file_netblocks_dict) | |||
| return netblocks_dict | |||
| @@ -0,0 +1,451 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| import uuid | |||
| from torch import nn | |||
| from . import basic_blocks, global_utils | |||
| from .global_utils import get_right_parentheses_index | |||
| from .super_blocks import PlainNetSuperBlockClass | |||
| class SuperResIDWEXKX(PlainNetSuperBlockClass): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| kernel_size=None, | |||
| expension=None, | |||
| no_create=False, | |||
| no_reslink=False, | |||
| no_BN=False, | |||
| use_se=False, | |||
| **kwargs): | |||
| super(SuperResIDWEXKX, self).__init__(**kwargs) | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.stride = stride | |||
| self.bottleneck_channels = bottleneck_channels | |||
| self.sub_layers = sub_layers | |||
| self.kernel_size = kernel_size | |||
| self.expension = expension | |||
| self.no_create = no_create | |||
| self.no_reslink = no_reslink | |||
| self.no_BN = no_BN | |||
| self.use_se = use_se | |||
| full_str = '' | |||
| last_channels = in_channels | |||
| current_stride = stride | |||
| for i in range(self.sub_layers): | |||
| inner_str = '' | |||
| dw_channels = global_utils.smart_round( | |||
| self.bottleneck_channels * self.expension, base=8) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(last_channels, | |||
| dw_channels, 1, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(dw_channels) | |||
| inner_str += 'RELU({})'.format(dw_channels) | |||
| inner_str += 'ConvDW({},{},{})'.format(dw_channels, | |||
| self.kernel_size, | |||
| current_stride) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(dw_channels) | |||
| inner_str += 'RELU({})'.format(dw_channels) | |||
| if self.use_se: | |||
| inner_str += 'SE({})'.format(dw_channels) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, | |||
| bottleneck_channels, 1, | |||
| 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(bottleneck_channels) | |||
| if not self.no_reslink: | |||
| if i == 0: | |||
| res_str = 'ResBlockProj({})RELU({})'.format( | |||
| inner_str, self.out_channels) | |||
| else: | |||
| res_str = 'ResBlock({})RELU({})'.format( | |||
| inner_str, self.out_channels) | |||
| else: | |||
| res_str = '{}RELU({})'.format(inner_str, self.out_channels) | |||
| full_str += res_str | |||
| inner_str = '' | |||
| dw_channels = global_utils.smart_round( | |||
| self.out_channels * self.expension, base=8) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(bottleneck_channels, | |||
| dw_channels, 1, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(dw_channels) | |||
| inner_str += 'RELU({})'.format(dw_channels) | |||
| inner_str += 'ConvDW({},{},{})'.format(dw_channels, | |||
| self.kernel_size, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(dw_channels) | |||
| inner_str += 'RELU({})'.format(dw_channels) | |||
| if self.use_se: | |||
| inner_str += 'SE({})'.format(dw_channels) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(dw_channels, | |||
| self.out_channels, 1, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.out_channels) | |||
| if not self.no_reslink: | |||
| res_str = 'ResBlock({})RELU({})'.format( | |||
| inner_str, self.out_channels) | |||
| else: | |||
| res_str = '{}RELU({})'.format(inner_str, self.out_channels) | |||
| full_str += res_str | |||
| last_channels = out_channels | |||
| current_stride = 1 | |||
| pass | |||
| netblocks_dict = basic_blocks.register_netblocks_dict({}) | |||
| self.block_list = global_utils.create_netblock_list_from_str( | |||
| full_str, | |||
| netblocks_dict=netblocks_dict, | |||
| no_create=no_create, | |||
| no_reslink=no_reslink, | |||
| no_BN=no_BN, | |||
| **kwargs) | |||
| if not no_create: | |||
| self.module_list = nn.ModuleList(self.block_list) | |||
| else: | |||
| self.module_list = None | |||
| def __str__(self): | |||
| return type(self).__name__ + '({},{},{},{},{})'.format( | |||
| self.in_channels, self.out_channels, self.stride, | |||
| self.bottleneck_channels, self.sub_layers) | |||
| def __repr__(self): | |||
| return type( | |||
| self | |||
| ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( | |||
| self.block_name, self.in_channels, self.out_channels, self.stride, | |||
| self.bottleneck_channels, self.sub_layers, self.kernel_size) | |||
| @classmethod | |||
| def create_from_str(cls, s, **kwargs): | |||
| assert cls.is_instance_from_str(s) | |||
| idx = get_right_parentheses_index(s) | |||
| assert idx is not None | |||
| param_str = s[len(cls.__name__ + '('):idx] | |||
| tmp_idx = param_str.find('|') | |||
| if tmp_idx < 0: | |||
| tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) | |||
| else: | |||
| tmp_block_name = param_str[0:tmp_idx] | |||
| param_str = param_str[tmp_idx + 1:] | |||
| param_str_split = param_str.split(',') | |||
| in_channels = int(param_str_split[0]) | |||
| out_channels = int(param_str_split[1]) | |||
| stride = int(param_str_split[2]) | |||
| bottleneck_channels = int(param_str_split[3]) | |||
| sub_layers = int(param_str_split[4]) | |||
| return cls( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| block_name=tmp_block_name, | |||
| **kwargs), s[idx + 1:] | |||
| class SuperResIDWE1K3(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE1K3, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=3, | |||
| expension=1.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE2K3(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE2K3, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=3, | |||
| expension=2.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE4K3(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE4K3, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=3, | |||
| expension=4.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE6K3(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE6K3, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=3, | |||
| expension=6.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE1K5(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE1K5, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=5, | |||
| expension=1.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE2K5(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE2K5, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=5, | |||
| expension=2.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE4K5(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE4K5, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=5, | |||
| expension=4.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE6K5(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE6K5, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=5, | |||
| expension=6.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE1K7(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE1K7, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=7, | |||
| expension=1.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE2K7(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE2K7, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=7, | |||
| expension=2.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE4K7(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE4K7, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=7, | |||
| expension=4.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResIDWE6K7(SuperResIDWEXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResIDWE6K7, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=7, | |||
| expension=6.0, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| def register_netblocks_dict(netblocks_dict: dict): | |||
| this_py_file_netblocks_dict = { | |||
| 'SuperResIDWE1K3': SuperResIDWE1K3, | |||
| 'SuperResIDWE2K3': SuperResIDWE2K3, | |||
| 'SuperResIDWE4K3': SuperResIDWE4K3, | |||
| 'SuperResIDWE6K3': SuperResIDWE6K3, | |||
| 'SuperResIDWE1K5': SuperResIDWE1K5, | |||
| 'SuperResIDWE2K5': SuperResIDWE2K5, | |||
| 'SuperResIDWE4K5': SuperResIDWE4K5, | |||
| 'SuperResIDWE6K5': SuperResIDWE6K5, | |||
| 'SuperResIDWE1K7': SuperResIDWE1K7, | |||
| 'SuperResIDWE2K7': SuperResIDWE2K7, | |||
| 'SuperResIDWE4K7': SuperResIDWE4K7, | |||
| 'SuperResIDWE6K7': SuperResIDWE6K7, | |||
| } | |||
| netblocks_dict.update(this_py_file_netblocks_dict) | |||
| return netblocks_dict | |||
| @@ -0,0 +1,238 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| import uuid | |||
| from torch import nn | |||
| from . import basic_blocks, global_utils | |||
| from .global_utils import get_right_parentheses_index | |||
| from .super_blocks import PlainNetSuperBlockClass | |||
| class SuperResK1KXK1(PlainNetSuperBlockClass): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| kernel_size=None, | |||
| no_create=False, | |||
| no_reslink=False, | |||
| no_BN=False, | |||
| use_se=False, | |||
| **kwargs): | |||
| super(SuperResK1KXK1, self).__init__(**kwargs) | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.stride = stride | |||
| self.bottleneck_channels = bottleneck_channels | |||
| self.sub_layers = sub_layers | |||
| self.kernel_size = kernel_size | |||
| self.no_create = no_create | |||
| self.no_reslink = no_reslink | |||
| self.no_BN = no_BN | |||
| self.use_se = use_se | |||
| full_str = '' | |||
| last_channels = in_channels | |||
| current_stride = stride | |||
| for i in range(self.sub_layers): | |||
| inner_str = '' | |||
| inner_str += 'ConvKX({},{},{},{})'.format(last_channels, | |||
| self.bottleneck_channels, | |||
| 1, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.bottleneck_channels) | |||
| inner_str += 'RELU({})'.format(self.bottleneck_channels) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, | |||
| self.bottleneck_channels, | |||
| self.kernel_size, | |||
| current_stride) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.bottleneck_channels) | |||
| inner_str += 'RELU({})'.format(self.bottleneck_channels) | |||
| if self.use_se: | |||
| inner_str += 'SE({})'.format(bottleneck_channels) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, | |||
| self.out_channels, 1, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.out_channels) | |||
| if not self.no_reslink: | |||
| if i == 0: | |||
| res_str = 'ResBlockProj({})RELU({})'.format( | |||
| inner_str, out_channels) | |||
| else: | |||
| res_str = 'ResBlock({})RELU({})'.format( | |||
| inner_str, out_channels) | |||
| else: | |||
| res_str = '{}RELU({})'.format(inner_str, out_channels) | |||
| full_str += res_str | |||
| inner_str = '' | |||
| inner_str += 'ConvKX({},{},{},{})'.format(self.out_channels, | |||
| self.bottleneck_channels, | |||
| 1, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.bottleneck_channels) | |||
| inner_str += 'RELU({})'.format(self.bottleneck_channels) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, | |||
| self.bottleneck_channels, | |||
| self.kernel_size, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.bottleneck_channels) | |||
| inner_str += 'RELU({})'.format(self.bottleneck_channels) | |||
| if self.use_se: | |||
| inner_str += 'SE({})'.format(bottleneck_channels) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, | |||
| self.out_channels, 1, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.out_channels) | |||
| if not self.no_reslink: | |||
| res_str = 'ResBlock({})RELU({})'.format( | |||
| inner_str, out_channels) | |||
| else: | |||
| res_str = '{}RELU({})'.format(inner_str, out_channels) | |||
| full_str += res_str | |||
| last_channels = out_channels | |||
| current_stride = 1 | |||
| pass | |||
| netblocks_dict = basic_blocks.register_netblocks_dict({}) | |||
| self.block_list = global_utils.create_netblock_list_from_str( | |||
| full_str, | |||
| netblocks_dict=netblocks_dict, | |||
| no_create=no_create, | |||
| no_reslink=no_reslink, | |||
| no_BN=no_BN, | |||
| **kwargs) | |||
| if not no_create: | |||
| self.module_list = nn.ModuleList(self.block_list) | |||
| else: | |||
| self.module_list = None | |||
| def __str__(self): | |||
| return type(self).__name__ + '({},{},{},{},{})'.format( | |||
| self.in_channels, self.out_channels, self.stride, | |||
| self.bottleneck_channels, self.sub_layers) | |||
| def __repr__(self): | |||
| return type( | |||
| self | |||
| ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( | |||
| self.block_name, self.in_channels, self.out_channels, self.stride, | |||
| self.bottleneck_channels, self.sub_layers, self.kernel_size) | |||
| @classmethod | |||
| def create_from_str(cls, s, **kwargs): | |||
| assert cls.is_instance_from_str(s) | |||
| idx = get_right_parentheses_index(s) | |||
| assert idx is not None | |||
| param_str = s[len(cls.__name__ + '('):idx] | |||
| tmp_idx = param_str.find('|') | |||
| if tmp_idx < 0: | |||
| tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) | |||
| else: | |||
| tmp_block_name = param_str[0:tmp_idx] | |||
| param_str = param_str[tmp_idx + 1:] | |||
| param_str_split = param_str.split(',') | |||
| in_channels = int(param_str_split[0]) | |||
| out_channels = int(param_str_split[1]) | |||
| stride = int(param_str_split[2]) | |||
| bottleneck_channels = int(param_str_split[3]) | |||
| sub_layers = int(param_str_split[4]) | |||
| return cls( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| block_name=tmp_block_name, | |||
| **kwargs), s[idx + 1:] | |||
| class SuperResK1K3K1(SuperResK1KXK1): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResK1K3K1, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=3, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResK1K5K1(SuperResK1KXK1): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResK1K5K1, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=5, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResK1K7K1(SuperResK1KXK1): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResK1K7K1, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=7, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| def register_netblocks_dict(netblocks_dict: dict): | |||
| this_py_file_netblocks_dict = { | |||
| 'SuperResK1K3K1': SuperResK1K3K1, | |||
| 'SuperResK1K5K1': SuperResK1K5K1, | |||
| 'SuperResK1K7K1': SuperResK1K7K1, | |||
| } | |||
| netblocks_dict.update(this_py_file_netblocks_dict) | |||
| return netblocks_dict | |||
| @@ -0,0 +1,202 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The ZenNAS implementation is also open-sourced by the authors, and available at https://github.com/idstcv/ZenNAS. | |||
| import uuid | |||
| from torch import nn | |||
| from . import basic_blocks, global_utils | |||
| from .global_utils import get_right_parentheses_index | |||
| from .super_blocks import PlainNetSuperBlockClass | |||
| class SuperResKXKX(PlainNetSuperBlockClass): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| kernel_size=None, | |||
| no_create=False, | |||
| no_reslink=False, | |||
| no_BN=False, | |||
| use_se=False, | |||
| **kwargs): | |||
| super(SuperResKXKX, self).__init__(**kwargs) | |||
| self.in_channels = in_channels | |||
| self.out_channels = out_channels | |||
| self.stride = stride | |||
| self.bottleneck_channels = bottleneck_channels | |||
| self.sub_layers = sub_layers | |||
| self.kernel_size = kernel_size | |||
| self.no_create = no_create | |||
| self.no_reslink = no_reslink | |||
| self.no_BN = no_BN | |||
| self.use_se = use_se | |||
| full_str = '' | |||
| last_channels = in_channels | |||
| current_stride = stride | |||
| for i in range(self.sub_layers): | |||
| inner_str = '' | |||
| inner_str += 'ConvKX({},{},{},{})'.format(last_channels, | |||
| self.bottleneck_channels, | |||
| self.kernel_size, | |||
| current_stride) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.bottleneck_channels) | |||
| inner_str += 'RELU({})'.format(self.bottleneck_channels) | |||
| if self.use_se: | |||
| inner_str += 'SE({})'.format(bottleneck_channels) | |||
| inner_str += 'ConvKX({},{},{},{})'.format(self.bottleneck_channels, | |||
| self.out_channels, | |||
| self.kernel_size, 1) | |||
| if not self.no_BN: | |||
| inner_str += 'BN({})'.format(self.out_channels) | |||
| if not self.no_reslink: | |||
| if i == 0: | |||
| res_str = 'ResBlockProj({})RELU({})'.format( | |||
| inner_str, out_channels) | |||
| else: | |||
| res_str = 'ResBlock({})RELU({})'.format( | |||
| inner_str, out_channels) | |||
| else: | |||
| res_str = '{}RELU({})'.format(inner_str, out_channels) | |||
| full_str += res_str | |||
| last_channels = out_channels | |||
| current_stride = 1 | |||
| pass | |||
| netblocks_dict = basic_blocks.register_netblocks_dict({}) | |||
| self.block_list = global_utils.create_netblock_list_from_str( | |||
| full_str, | |||
| netblocks_dict=netblocks_dict, | |||
| no_create=no_create, | |||
| no_reslink=no_reslink, | |||
| no_BN=no_BN, | |||
| **kwargs) | |||
| if not no_create: | |||
| self.module_list = nn.ModuleList(self.block_list) | |||
| else: | |||
| self.module_list = None | |||
| def __str__(self): | |||
| return type(self).__name__ + '({},{},{},{},{})'.format( | |||
| self.in_channels, self.out_channels, self.stride, | |||
| self.bottleneck_channels, self.sub_layers) | |||
| def __repr__(self): | |||
| return type( | |||
| self | |||
| ).__name__ + '({}|in={},out={},stride={},btl_channels={},sub_layers={},kernel_size={})'.format( | |||
| self.block_name, self.in_channels, self.out_channels, self.stride, | |||
| self.bottleneck_channels, self.sub_layers, self.kernel_size) | |||
| @classmethod | |||
| def create_from_str(cls, s, **kwargs): | |||
| assert cls.is_instance_from_str(s) | |||
| idx = get_right_parentheses_index(s) | |||
| assert idx is not None | |||
| param_str = s[len(cls.__name__ + '('):idx] | |||
| tmp_idx = param_str.find('|') | |||
| if tmp_idx < 0: | |||
| tmp_block_name = 'uuid{}'.format(uuid.uuid4().hex) | |||
| else: | |||
| tmp_block_name = param_str[0:tmp_idx] | |||
| param_str = param_str[tmp_idx + 1:] | |||
| param_str_split = param_str.split(',') | |||
| in_channels = int(param_str_split[0]) | |||
| out_channels = int(param_str_split[1]) | |||
| stride = int(param_str_split[2]) | |||
| bottleneck_channels = int(param_str_split[3]) | |||
| sub_layers = int(param_str_split[4]) | |||
| return cls( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| block_name=tmp_block_name, | |||
| **kwargs), s[idx + 1:] | |||
| class SuperResK3K3(SuperResKXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResK3K3, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=3, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResK5K5(SuperResKXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResK5K5, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=5, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| class SuperResK7K7(SuperResKXKX): | |||
| def __init__(self, | |||
| in_channels=None, | |||
| out_channels=None, | |||
| stride=None, | |||
| bottleneck_channels=None, | |||
| sub_layers=None, | |||
| no_create=False, | |||
| **kwargs): | |||
| super(SuperResK7K7, self).__init__( | |||
| in_channels=in_channels, | |||
| out_channels=out_channels, | |||
| stride=stride, | |||
| bottleneck_channels=bottleneck_channels, | |||
| sub_layers=sub_layers, | |||
| kernel_size=7, | |||
| no_create=no_create, | |||
| **kwargs) | |||
| def register_netblocks_dict(netblocks_dict: dict): | |||
| this_py_file_netblocks_dict = { | |||
| 'SuperResK3K3': SuperResK3K3, | |||
| 'SuperResK5K5': SuperResK5K5, | |||
| 'SuperResK7K7': SuperResK7K7, | |||
| } | |||
| netblocks_dict.update(this_py_file_netblocks_dict) | |||
| return netblocks_dict | |||
| @@ -30,6 +30,7 @@ if TYPE_CHECKING: | |||
| from .live_category_pipeline import LiveCategoryPipeline | |||
| from .ocr_detection_pipeline import OCRDetectionPipeline | |||
| from .skin_retouching_pipeline import SkinRetouchingPipeline | |||
| from .tinynas_classification_pipeline import TinynasClassificationPipeline | |||
| from .video_category_pipeline import VideoCategoryPipeline | |||
| from .virtual_try_on_pipeline import VirtualTryonPipeline | |||
| else: | |||
| @@ -65,6 +66,7 @@ else: | |||
| ['Image2ImageGenerationPipeline'], | |||
| 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | |||
| 'skin_retouching_pipeline': ['SkinRetouchingPipeline'], | |||
| 'tinynas_classification_pipeline': ['TinynasClassificationPipeline'], | |||
| 'video_category_pipeline': ['VideoCategoryPipeline'], | |||
| 'virtual_try_on_pipeline': ['VirtualTryonPipeline'], | |||
| } | |||
| @@ -0,0 +1,96 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import math | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import torch | |||
| from torchvision import transforms | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.tinynas_classfication import get_zennet | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_classification, module_name=Pipelines.tinynas_classification) | |||
| class TinynasClassificationPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| use `model` to create a tinynas classification pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| self.path = model | |||
| self.model = get_zennet() | |||
| model_pth_path = osp.join(self.path, ModelFile.TORCH_MODEL_FILE) | |||
| checkpoint = torch.load(model_pth_path, map_location='cpu') | |||
| if 'state_dict' in checkpoint: | |||
| state_dict = checkpoint['state_dict'] | |||
| else: | |||
| state_dict = checkpoint | |||
| self.model.load_state_dict(state_dict, strict=True) | |||
| logger.info('load model done') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| img = LoadImage.convert_to_img(input) | |||
| input_image_size = 224 | |||
| crop_image_size = 380 | |||
| input_image_crop = 0.875 | |||
| resize_image_size = int(math.ceil(crop_image_size / input_image_crop)) | |||
| transforms_normalize = transforms.Normalize( | |||
| mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||
| transform_list = [ | |||
| transforms.Resize( | |||
| resize_image_size, | |||
| interpolation=transforms.InterpolationMode.BICUBIC), | |||
| transforms.CenterCrop(crop_image_size), | |||
| transforms.ToTensor(), transforms_normalize | |||
| ] | |||
| transformer = transforms.Compose(transform_list) | |||
| img = transformer(img) | |||
| img = torch.unsqueeze(img, 0) | |||
| img = torch.nn.functional.interpolate( | |||
| img, input_image_size, mode='bilinear') | |||
| result = {'img': img} | |||
| return result | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| is_train = False | |||
| if is_train: | |||
| self.model.train() | |||
| else: | |||
| self.model.eval() | |||
| outputs = self.model(input['img']) | |||
| return {'outputs': outputs} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| label_mapping_path = osp.join(self.path, 'label_map.txt') | |||
| f = open(label_mapping_path) | |||
| content = f.read() | |||
| f.close() | |||
| label_dict = eval(content) | |||
| output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1) | |||
| score = torch.max(output_prob) | |||
| output_dict = { | |||
| OutputKeys.SCORES: score.item(), | |||
| OutputKeys.LABELS: label_dict[inputs['outputs'].argmax().item()] | |||
| } | |||
| return output_dict | |||
| @@ -0,0 +1,19 @@ | |||
| import unittest | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class TinyNASClassificationTest(unittest.TestCase): | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run(self): | |||
| tinynas_classification = pipeline( | |||
| Tasks.image_classification, model='damo/cv_tinynas_classification') | |||
| result = tinynas_classification('data/test/images/image_wolf.jpeg') | |||
| print(result) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||