| @@ -77,6 +77,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil | |||||
| │ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn | │ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn | ||||
| ├── train.py # training script | ├── train.py # training script | ||||
| ├── eval.py # evaluation script | ├── eval.py # evaluation script | ||||
| ├── mindspore_hub_conf.py # mindspore hub interface | |||||
| ``` | ``` | ||||
| ## [Training process](#contents) | ## [Training process](#contents) | ||||
| @@ -119,7 +119,7 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): | |||||
| for param in network.get_parameters(): | for param in network.get_parameters(): | ||||
| param.requires_grad = False | param.requires_grad = False | ||||
| def define_net(config, is_training): | |||||
| def define_net(config, is_training=True): | |||||
| backbone_net = MobileNetV2Backbone() | backbone_net = MobileNetV2Backbone() | ||||
| activation = config.activation if not is_training else "None" | activation = config.activation if not is_training else "None" | ||||
| head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, | head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, | ||||
| @@ -69,6 +69,7 @@ Dataset used: [imagenet](http://www.image-net.org/) | |||||
| │ ├──mobilenetV3.py # MobileNetV3 architecture | │ ├──mobilenetV3.py # MobileNetV3 architecture | ||||
| ├── train.py # training script | ├── train.py # training script | ||||
| ├── eval.py # evaluation script | ├── eval.py # evaluation script | ||||
| ├── mindspore_hub_conf.py # mindspore hub interface | |||||
| ``` | ``` | ||||
| ## [Training process](#contents) | ## [Training process](#contents) | ||||
| @@ -42,7 +42,7 @@ if __name__ == '__main__': | |||||
| raise ValueError("Unsupported device_target.") | raise ValueError("Unsupported device_target.") | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | ||||
| net = mobilenet_v3_large(num_classes=config.num_classes) | |||||
| net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax") | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | dataset = create_dataset(dataset_path=args_opt.dataset_path, | ||||
| do_train=False, | do_train=False, | ||||
| @@ -0,0 +1,25 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.mobilenetV3 import mobilenet_v3_large, mobilenet_v3_small | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == "mobilenetv3_large": | |||||
| net = mobilenet_v3_large(*args, **kwargs) | |||||
| elif name == "mobilenetv3_small": | |||||
| net = mobilenet_v3_small(*args, **kwargs) | |||||
| else: | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| return net | |||||
| @@ -246,7 +246,8 @@ class MobileNetV3(nn.Cell): | |||||
| >>> MobileNetV3(num_classes=1000) | >>> MobileNetV3(num_classes=1000) | ||||
| """ | """ | ||||
| def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8): | |||||
| def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., | |||||
| round_nearest=8, include_top=True, activation="None"): | |||||
| super(MobileNetV3, self).__init__() | super(MobileNetV3, self).__init__() | ||||
| self.cfgs = model_cfgs['cfg'] | self.cfgs = model_cfgs['cfg'] | ||||
| self.inplanes = 16 | self.inplanes = 16 | ||||
| @@ -285,19 +286,34 @@ class MobileNetV3(nn.Cell): | |||||
| # make it nn.CellList | # make it nn.CellList | ||||
| self.features = nn.SequentialCell(self.features) | self.features = nn.SequentialCell(self.features) | ||||
| self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], | |||||
| out_channels=num_classes, | |||||
| kernel_size=1, has_bias=True, pad_mode='pad') | |||||
| self.squeeze = P.Squeeze(axis=(2, 3)) | |||||
| self.include_top = include_top | |||||
| self.need_activation = False | |||||
| if self.include_top: | |||||
| self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], | |||||
| out_channels=num_classes, | |||||
| kernel_size=1, has_bias=True, pad_mode='pad') | |||||
| self.squeeze = P.Squeeze(axis=(2, 3)) | |||||
| if activation != "None": | |||||
| self.need_activation = True | |||||
| if activation == "Sigmoid": | |||||
| self.activation = P.Sigmoid() | |||||
| elif activation == "Softmax": | |||||
| self.activation = P.Softmax() | |||||
| else: | |||||
| raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].") | |||||
| self._initialize_weights() | self._initialize_weights() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.features(x) | x = self.features(x) | ||||
| x = self.output(x) | |||||
| x = self.squeeze(x) | |||||
| if self.include_top: | |||||
| x = self.output(x) | |||||
| x = self.squeeze(x) | |||||
| if self.need_activation: | |||||
| x = self.activation(x) | |||||
| return x | return x | ||||
| def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1): | def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1): | ||||
| mid_planes = exp_ch | mid_planes = exp_ch | ||||
| out_planes = out_channel | out_planes = out_channel | ||||
| @@ -96,7 +96,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil | |||||
| ├─warmup_cosine_annealing.py # learning rate each step | ├─warmup_cosine_annealing.py # learning rate each step | ||||
| ├─warmup_step_lr.py # warmup step learning rate | ├─warmup_step_lr.py # warmup step learning rate | ||||
| ├─eval.py # eval net | ├─eval.py # eval net | ||||
| └─train.py # train net | |||||
| ├──train.py # train net | |||||
| ├──mindspore_hub_conf.py # mindspore hub interface | |||||
| ``` | ``` | ||||
| @@ -201,7 +201,7 @@ def test(cloud_args=None): | |||||
| max_epoch=1, rank=args.rank, group_size=args.group_size, | max_epoch=1, rank=args.rank, group_size=args.group_size, | ||||
| mode='eval') | mode='eval') | ||||
| eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True) | eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True) | ||||
| network = get_network(args.backbone, args.num_classes, platform=args.platform) | |||||
| network = get_network(args.backbone, num_classes=args.num_classes, platform=args.platform) | |||||
| if network is None: | if network is None: | ||||
| raise NotImplementedError('not implement {}'.format(args.backbone)) | raise NotImplementedError('not implement {}'.format(args.backbone)) | ||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """hub config.""" | |||||
| from src.image_classification import get_network | |||||
| def create_network(name, *args, **kwargs): | |||||
| if name == "renext50": | |||||
| get_network("renext50", *args, **kwargs) | |||||
| return net | |||||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||||
| @@ -31,31 +31,46 @@ class ImageClassificationNetwork(nn.Cell): | |||||
| Returns: | Returns: | ||||
| Tensor, output tensor. | Tensor, output tensor. | ||||
| """ | """ | ||||
| def __init__(self, backbone, head): | |||||
| def __init__(self, backbone, head, include_top=True, activation="None"): | |||||
| super(ImageClassificationNetwork, self).__init__() | super(ImageClassificationNetwork, self).__init__() | ||||
| self.backbone = backbone | self.backbone = backbone | ||||
| self.head = head | |||||
| self.include_top = include_top | |||||
| self.need_activation = False | |||||
| if self.include_top: | |||||
| self.head = head | |||||
| if activation != "None": | |||||
| self.need_activation = True | |||||
| if activation == "Sigmoid": | |||||
| self.activation = P.Sigmoid() | |||||
| elif activation == "Softmax": | |||||
| self.activation = P.Softmax() | |||||
| else: | |||||
| raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].") | |||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.backbone(x) | x = self.backbone(x) | ||||
| x = self.head(x) | |||||
| if self.include_top: | |||||
| x = self.head(x) | |||||
| if self.need_activation: | |||||
| x = self.activation(x) | |||||
| return x | return x | ||||
| class Resnet(ImageClassificationNetwork): | class Resnet(ImageClassificationNetwork): | ||||
| """ | """ | ||||
| Resnet architecture. | Resnet architecture. | ||||
| Args: | Args: | ||||
| backbone_name (string): backbone. | backbone_name (string): backbone. | ||||
| num_classes (int): number of classes. | |||||
| num_classes (int): number of classes, Default is 1000. | |||||
| Returns: | Returns: | ||||
| Resnet. | Resnet. | ||||
| """ | """ | ||||
| def __init__(self, backbone_name, num_classes, platform="Ascend"): | |||||
| def __init__(self, backbone_name, num_classes=1000, platform="Ascend", include_top=True, activation="None"): | |||||
| self.backbone_name = backbone_name | self.backbone_name = backbone_name | ||||
| backbone = backbones.__dict__[self.backbone_name](platform=platform) | backbone = backbones.__dict__[self.backbone_name](platform=platform) | ||||
| out_channels = backbone.get_out_channels() | out_channels = backbone.get_out_channels() | ||||
| head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels) | head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels) | ||||
| super(Resnet, self).__init__(backbone, head) | |||||
| super(Resnet, self).__init__(backbone, head, include_top, activation) | |||||
| default_recurisive_init(self) | default_recurisive_init(self) | ||||
| @@ -79,7 +94,7 @@ class Resnet(ImageClassificationNetwork): | |||||
| def get_network(backbone_name, num_classes, platform="Ascend"): | |||||
| def get_network(backbone_name, **kwargs): | |||||
| if backbone_name in ['resnext50']: | if backbone_name in ['resnext50']: | ||||
| return Resnet(backbone_name, num_classes, platform) | |||||
| return Resnet(backbone_name, **kwargs) | |||||
| return None | return None | ||||
| @@ -213,7 +213,7 @@ def train(cloud_args=None): | |||||
| # network | # network | ||||
| args.logger.important_info('start create network') | args.logger.important_info('start create network') | ||||
| # get network and init | # get network and init | ||||
| network = get_network(args.backbone, args.num_classes, platform=args.platform) | |||||
| network = get_network(args.backbone, num_classes=args.num_classes, platform=args.platform) | |||||
| if network is None: | if network is None: | ||||
| raise NotImplementedError('not implement {}'.format(args.backbone)) | raise NotImplementedError('not implement {}'.format(args.backbone)) | ||||
| @@ -114,7 +114,8 @@ sh run_eval.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] | |||||
| ├─ lr_schedule.py ## learning ratio generator | ├─ lr_schedule.py ## learning ratio generator | ||||
| └─ ssd.py ## ssd architecture | └─ ssd.py ## ssd architecture | ||||
| ├─ eval.py ## eval scripts | ├─ eval.py ## eval scripts | ||||
| └─ train.py ## train scripts | |||||
| ├─ train.py ## train scripts | |||||
| ├── mindspore_hub_conf.py # mindspore hub interface | |||||
| ``` | ``` | ||||
| ## [Script Parameters](#contents) | ## [Script Parameters](#contents) | ||||