Added Readme.md Fixed PyLint Errors Fixed PyLint Errors-2 Fixed PyLint Errors-3 Fixed PyLint Errors-4 Fixed PyLint Errors-5 Fixed PyLint Errors-6 Fixed PyLint Errors-7 Update eval.py Updated ShuffleNetV2 model Fixed PyLint Error Fixed PyLint Error #2 Fixed PyLint Error #3 Applied Comments Fixed PyLint Fixed PyLint #2tags/v1.0.0
| @@ -0,0 +1,119 @@ | |||||
| # Contents | |||||
| - [ShuffleNetV2 Description](#shufflenetv2-description) | |||||
| - [Model Architecture](#model-architecture) | |||||
| - [Dataset](#dataset) | |||||
| - [Environment Requirements](#environment-requirements) | |||||
| - [Script Description](#script-description) | |||||
| - [Script and Sample Code](#script-and-sample-code) | |||||
| - [Training Process](#training-process) | |||||
| - [Evaluation Process](#evaluation-process) | |||||
| - [Evaluation](#evaluation) | |||||
| - [Model Description](#model-description) | |||||
| - [Performance](#performance) | |||||
| - [Training Performance](#evaluation-performance) | |||||
| - [Inference Performance](#evaluation-performance) | |||||
| # [ShuffleNetV2 Description](#contents) | |||||
| ShuffleNetV2 is a much faster and more accurate netowrk than the previous networks on different platforms such as Ascend or GPU. | |||||
| [Paper](https://arxiv.org/pdf/1807.11164.pdf) Ma, N., Zhang, X., Zheng, H. T., & Sun, J. (2018). Shufflenet v2: Practical guidelines for efficient cnn architecture design. In Proceedings of the European conference on computer vision (ECCV) (pp. 116-131). | |||||
| # [Model architecture](#contents) | |||||
| The overall network architecture of ShuffleNetV2 is show below: | |||||
| [Link](https://arxiv.org/pdf/1807.11164.pdf) | |||||
| # [Dataset](#contents) | |||||
| Dataset used: [imagenet](http://www.image-net.org/) | |||||
| - Dataset size: ~125G, 1.2W colorful images in 1000 classes | |||||
| - Train: 120G, 1.2W images | |||||
| - Test: 5G, 50000 images | |||||
| - Data format: RGB images. | |||||
| - Note: Data will be processed in src/dataset.py | |||||
| # [Environment Requirements](#contents) | |||||
| - Hardware(GPU) | |||||
| - Prepare hardware environment with GPU processor. | |||||
| - Framework | |||||
| - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) | |||||
| - For more information, please check the resources below: | |||||
| - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html) | |||||
| - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html) | |||||
| # [Script description](#contents) | |||||
| ## [Script and sample code](#contents) | |||||
| ```python | |||||
| +-- ShuffleNetV2 | |||||
| +-- Readme.md # descriptions about ShuffleNetV2 | |||||
| +-- scripts | |||||
| ¦ +--run_distribute_train_for_gpu.sh # shell script for distributed training | |||||
| ¦ +--run_eval_for_multi_gpu.sh # shell script for evaluation | |||||
| ¦ +--run_standalone_train_for_gpu.sh # shell script for standalone training | |||||
| +-- src | |||||
| ¦ +--config.py # parameter configuration | |||||
| ¦ +--dataset.py # creating dataset | |||||
| ¦ +--loss.py # loss function for network | |||||
| ¦ +--lr_generator.py # learning rate config | |||||
| +-- train.py # training script | |||||
| +-- eval.py # evaluation script | |||||
| +-- blocks.py # ShuffleNetV2 blocks | |||||
| +-- network.py # ShuffleNetV2 model network | |||||
| ``` | |||||
| ## [Training process](#contents) | |||||
| ### Usage | |||||
| You can start training using python or shell scripts. The usage of shell scripts as follows: | |||||
| - Ditributed training on GPU: sh run_distribute_train_for_gpu.sh [DATA_DIR] | |||||
| - Standalone training on GPU: sh run_standalone_train_for_gpu.sh [DEVICE_ID] [DATA_DIR] | |||||
| ### Launch | |||||
| ``` | |||||
| # training example | |||||
| python: | |||||
| GPU: mpirun --allow-run-as-root -n 8 python train.py --is_distributed --platform 'GPU' --dataset_path '~/imagenet/train/' > train.log 2>&1 & | |||||
| shell: | |||||
| GPU: sh run_distribute_train_for_gpu.sh ~/imagenet/train/ | |||||
| ``` | |||||
| ### Result | |||||
| Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log`. | |||||
| ## [Eval process](#contents) | |||||
| ### Usage | |||||
| You can start evaluation using python or shell scripts. The usage of shell scripts as follows: | |||||
| - GPU: sh run_eval_for_multi_gpu.sh [DEVICE_ID] [EPOCH] | |||||
| ### Launch | |||||
| ``` | |||||
| # infer example | |||||
| python: | |||||
| GPU: CUDA_VISIBLE_DEVICES=0 python eval.py --platform 'GPU' --dataset_path '~/imagenet/val/' --epoch 250 > eval.log 2>&1 & | |||||
| shell: | |||||
| GPU: sh run_eval_for_multi_gpu.sh 0 250 | |||||
| ``` | |||||
| > checkpoint can be produced in training process. | |||||
| ### Result | |||||
| Inference result will be stored in the example path, you can find result in `val.log`. | |||||
| @@ -0,0 +1,83 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import mindspore.nn as nn | |||||
| import mindspore.ops.operations as P | |||||
| class ShuffleV2Block(nn.Cell): | |||||
| def __init__(self, inp, oup, mid_channels, *, ksize, stride): | |||||
| super(ShuffleV2Block, self).__init__() | |||||
| self.stride = stride | |||||
| ##assert stride in [1, 2] | |||||
| self.mid_channels = mid_channels | |||||
| self.ksize = ksize | |||||
| pad = ksize // 2 | |||||
| self.pad = pad | |||||
| self.inp = inp | |||||
| outputs = oup - inp | |||||
| branch_main = [ | |||||
| # pw | |||||
| nn.Conv2d(in_channels=inp, out_channels=mid_channels, kernel_size=1, stride=1, | |||||
| pad_mode='pad', padding=0, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), | |||||
| nn.ReLU(), | |||||
| # dw | |||||
| nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=ksize, stride=stride, | |||||
| pad_mode='pad', padding=pad, group=mid_channels, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=mid_channels, momentum=0.9), | |||||
| # pw-linear | |||||
| nn.Conv2d(in_channels=mid_channels, out_channels=outputs, kernel_size=1, stride=1, | |||||
| pad_mode='pad', padding=0, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=outputs, momentum=0.9), | |||||
| nn.ReLU(), | |||||
| ] | |||||
| self.branch_main = nn.SequentialCell(branch_main) | |||||
| if stride == 2: | |||||
| branch_proj = [ | |||||
| # dw | |||||
| nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=ksize, stride=stride, | |||||
| pad_mode='pad', padding=pad, group=inp, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=inp, momentum=0.9), | |||||
| # pw-linear | |||||
| nn.Conv2d(in_channels=inp, out_channels=inp, kernel_size=1, stride=1, | |||||
| pad_mode='pad', padding=0, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=inp, momentum=0.9), | |||||
| nn.ReLU(), | |||||
| ] | |||||
| self.branch_proj = nn.SequentialCell(branch_proj) | |||||
| else: | |||||
| self.branch_proj = None | |||||
| def construct(self, old_x): | |||||
| if self.stride == 1: | |||||
| x_proj, x = self.channel_shuffle(old_x) | |||||
| return P.Concat(1)((x_proj, self.branch_main(x))) | |||||
| if self.stride == 2: | |||||
| x_proj = old_x | |||||
| x = old_x | |||||
| return P.Concat(1)((self.branch_proj(x_proj), self.branch_main(x))) | |||||
| return None | |||||
| def channel_shuffle(self, x): | |||||
| batchsize, num_channels, height, width = P.Shape()(x) | |||||
| ##assert (num_channels % 4 == 0) | |||||
| x = P.Reshape()(x, (batchsize * num_channels // 2, 2, height * width,)) | |||||
| x = P.Transpose()(x, (1, 0, 2,)) | |||||
| x = P.Reshape()(x, (2, -1, num_channels // 2, height, width,)) | |||||
| return x[0], x[1] | |||||
| @@ -0,0 +1,54 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """evaluate_imagenet""" | |||||
| import argparse | |||||
| import os | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.config import config_gpu as cfg | |||||
| from src.dataset import create_dataset | |||||
| from network import ShuffleNetV2 | |||||
| if __name__ == '__main__': | |||||
| parser = argparse.ArgumentParser(description='image classification evaluation') | |||||
| parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of ShuffleNetV2 (Default: None)') | |||||
| parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') | |||||
| parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') | |||||
| parser.add_argument('--epoch', type=str, default='') | |||||
| args_opt = parser.parse_args() | |||||
| if args_opt.platform == 'Ascend': | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(device_id=device_id) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, device_id=0) | |||||
| net = ShuffleNetV2(n_class=cfg.num_classes) | |||||
| ckpt = load_checkpoint(args_opt.checkpoint) | |||||
| load_param_into_net(net, ckpt) | |||||
| net.set_train(False) | |||||
| dataset = create_dataset(args_opt.dataset_path, cfg, False) | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False, | |||||
| smooth_factor=0.1, num_classes=cfg.num_classes) | |||||
| eval_metrics = {'Loss': nn.Loss(), | |||||
| 'Top1-Acc': nn.Top1CategoricalAccuracy(), | |||||
| 'Top5-Acc': nn.Top5CategoricalAccuracy()} | |||||
| model = Model(net, loss, optimizer=None, metrics=eval_metrics) | |||||
| metrics = model.eval(dataset) | |||||
| print("metric: ", metrics) | |||||
| @@ -0,0 +1,108 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| from blocks import ShuffleV2Block | |||||
| from mindspore import Tensor | |||||
| import mindspore.nn as nn | |||||
| import mindspore.ops.operations as P | |||||
| class ShuffleNetV2(nn.Cell): | |||||
| def __init__(self, input_size=224, n_class=1000, model_size='1.0x'): | |||||
| super(ShuffleNetV2, self).__init__() | |||||
| print('model size is ', model_size) | |||||
| self.stage_repeats = [4, 8, 4] | |||||
| self.model_size = model_size | |||||
| if model_size == '0.5x': | |||||
| self.stage_out_channels = [-1, 24, 48, 96, 192, 1024] | |||||
| elif model_size == '1.0x': | |||||
| self.stage_out_channels = [-1, 24, 116, 232, 464, 1024] | |||||
| elif model_size == '1.5x': | |||||
| self.stage_out_channels = [-1, 24, 176, 352, 704, 1024] | |||||
| elif model_size == '2.0x': | |||||
| self.stage_out_channels = [-1, 24, 244, 488, 976, 2048] | |||||
| else: | |||||
| raise NotImplementedError | |||||
| # building first layer | |||||
| input_channel = self.stage_out_channels[1] | |||||
| self.first_conv = nn.SequentialCell([ | |||||
| nn.Conv2d(in_channels=3, out_channels=input_channel, kernel_size=3, stride=2, | |||||
| pad_mode='pad', padding=1, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=input_channel, momentum=0.9), | |||||
| nn.ReLU(), | |||||
| ]) | |||||
| self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') | |||||
| self.features = [] | |||||
| for idxstage in range(len(self.stage_repeats)): | |||||
| numrepeat = self.stage_repeats[idxstage] | |||||
| output_channel = self.stage_out_channels[idxstage+2] | |||||
| for i in range(numrepeat): | |||||
| if i == 0: | |||||
| self.features.append(ShuffleV2Block(input_channel, output_channel, | |||||
| mid_channels=output_channel // 2, ksize=3, stride=2)) | |||||
| else: | |||||
| self.features.append(ShuffleV2Block(input_channel // 2, output_channel, | |||||
| mid_channels=output_channel // 2, ksize=3, stride=1)) | |||||
| input_channel = output_channel | |||||
| self.features = nn.SequentialCell([*self.features]) | |||||
| self.conv_last = nn.SequentialCell([ | |||||
| nn.Conv2d(in_channels=input_channel, out_channels=self.stage_out_channels[-1], kernel_size=1, stride=1, | |||||
| pad_mode='pad', padding=0, has_bias=False), | |||||
| nn.BatchNorm2d(num_features=self.stage_out_channels[-1], momentum=0.9), | |||||
| nn.ReLU() | |||||
| ]) | |||||
| self.globalpool = nn.AvgPool2d(kernel_size=7, stride=7, pad_mode='valid') | |||||
| if self.model_size == '2.0x': | |||||
| self.dropout = nn.Dropout(keep_prob=0.8) | |||||
| self.classifier = nn.SequentialCell([nn.Dense(in_channels=self.stage_out_channels[-1], | |||||
| out_channels=n_class, has_bias=False)]) | |||||
| ##TODO init weights | |||||
| self._initialize_weights() | |||||
| def construct(self, x): | |||||
| x = self.first_conv(x) | |||||
| x = self.maxpool(x) | |||||
| x = self.features(x) | |||||
| x = self.conv_last(x) | |||||
| x = self.globalpool(x) | |||||
| if self.model_size == '2.0x': | |||||
| x = self.dropout(x) | |||||
| x = P.Reshape()(x, (-1, self.stage_out_channels[-1],)) | |||||
| x = self.classifier(x) | |||||
| return x | |||||
| def _initialize_weights(self): | |||||
| for name, m in self.cells_and_names(): | |||||
| if isinstance(m, nn.Conv2d): | |||||
| if 'first' in name: | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| else: | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, 1.0/m.weight.data.shape[1], | |||||
| m.weight.data.shape).astype("float32"))) | |||||
| if isinstance(m, nn.Dense): | |||||
| m.weight.set_parameter_data(Tensor(np.random.normal(0, 0.01, m.weight.data.shape).astype("float32"))) | |||||
| @@ -0,0 +1,17 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| DATA_DIR=$1 | |||||
| mpirun --allow-run-as-root -n 8 python ./train.py --is_distributed --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & | |||||
| @@ -0,0 +1,18 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| DEVICE_ID=$1 | |||||
| EPOCH=$2 | |||||
| CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./eval.py --platform 'GPU' --dataset_path '/home/data/ImageNet_Original/val/' --epoch $EPOCH > eval.log 2>&1 & | |||||
| @@ -0,0 +1,18 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| DEVICE_ID=$1 | |||||
| DATA_DIR=$2 | |||||
| CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --platform 'GPU' --dataset_path $DATA_DIR > train.log 2>&1 & | |||||
| @@ -0,0 +1,49 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| network config setting, will be used in main.py | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
| config_gpu = edict({ | |||||
| 'random_seed': 1, | |||||
| 'rank': 0, | |||||
| 'group_size': 1, | |||||
| 'work_nums': 8, | |||||
| 'epoch_size': 250, | |||||
| 'keep_checkpoint_max': 100, | |||||
| 'ckpt_path': './checkpoint/', | |||||
| 'is_save_on_master': 0, | |||||
| ### Dataset Config | |||||
| 'batch_size': 128, | |||||
| 'num_classes': 1000, | |||||
| ### Loss Config | |||||
| 'label_smooth_factor': 0.1, | |||||
| 'aux_factor': 0.4, | |||||
| ### Learning Rate Config | |||||
| 'lr_init': 0.5, | |||||
| ### Optimization Config | |||||
| 'weight_decay': 0.00004, | |||||
| 'momentum': 0.9, | |||||
| 'opt_eps': 1.0, | |||||
| 'rmsprop_decay': 0.9, | |||||
| "loss_scale": 1, | |||||
| }) | |||||
| @@ -0,0 +1,81 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Data operations, will be used in train.py and eval.py | |||||
| """ | |||||
| import numpy as np | |||||
| from src.config import config_gpu as cfg | |||||
| import mindspore.common.dtype as mstype | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.dataset.transforms.c_transforms as C2 | |||||
| import mindspore.dataset.transforms.vision.c_transforms as C | |||||
| class toBGR(): | |||||
| def __call__(self, img): | |||||
| img = img[:, :, ::-1] | |||||
| img = np.ascontiguousarray(img) | |||||
| return img | |||||
| def create_dataset(dataset_path, do_train, rank, group_size, repeat_num=1): | |||||
| """ | |||||
| create a train or eval dataset | |||||
| Args: | |||||
| dataset_path(string): the path of dataset. | |||||
| do_train(bool): whether dataset is used for train or eval. | |||||
| rank (int): The shard ID within num_shards (default=None). | |||||
| group_size (int): Number of shards that the dataset should be divided into (default=None). | |||||
| repeat_num(int): the repeat times of dataset. Default: 1. | |||||
| Returns: | |||||
| dataset | |||||
| """ | |||||
| if group_size == 1: | |||||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True) | |||||
| else: | |||||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=True, | |||||
| num_shards=group_size, shard_id=rank) | |||||
| # define map operations | |||||
| if do_train: | |||||
| trans = [ | |||||
| C.RandomCropDecodeResize(224), | |||||
| C.RandomHorizontalFlip(prob=0.5), | |||||
| C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) | |||||
| ] | |||||
| else: | |||||
| trans = [ | |||||
| C.Decode(), | |||||
| C.Resize(256), | |||||
| C.CenterCrop(224) | |||||
| ] | |||||
| trans += [ | |||||
| toBGR(), | |||||
| C.Rescale(1.0 / 255.0, 0.0), | |||||
| # C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |||||
| C.HWC2CHW(), | |||||
| C2.TypeCast(mstype.float32) | |||||
| ] | |||||
| type_cast_op = C2.TypeCast(mstype.int32) | |||||
| ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=cfg.work_nums) | |||||
| ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=cfg.work_nums) | |||||
| # apply batch operations | |||||
| ds = ds.batch(cfg.batch_size, drop_remainder=True) | |||||
| # apply dataset repeat operation | |||||
| ds = ds.repeat(repeat_num) | |||||
| return ds | |||||
| @@ -0,0 +1,60 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """define loss function for network.""" | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn.loss.loss import _Loss | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore import Tensor | |||||
| import mindspore.nn as nn | |||||
| class CrossEntropy(_Loss): | |||||
| """the redefined loss function with SoftmaxCrossEntropyWithLogits""" | |||||
| def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4): | |||||
| super(CrossEntropy, self).__init__() | |||||
| self.factor = factor | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||||
| self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||||
| self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||||
| self.mean = P.ReduceMean(False) | |||||
| def construct(self, logits, label): | |||||
| logit, aux = logits | |||||
| one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) | |||||
| loss_logit = self.ce(logit, one_hot_label) | |||||
| loss_logit = self.mean(loss_logit, 0) | |||||
| one_hot_label_aux = self.onehot(label, F.shape(aux)[1], self.on_value, self.off_value) | |||||
| loss_aux = self.ce(aux, one_hot_label_aux) | |||||
| loss_aux = self.mean(loss_aux, 0) | |||||
| return loss_logit + self.factor*loss_aux | |||||
| class CrossEntropy_Val(_Loss): | |||||
| """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process""" | |||||
| def __init__(self, smooth_factor=0, num_classes=1000): | |||||
| super(CrossEntropy_Val, self).__init__() | |||||
| self.onehot = P.OneHot() | |||||
| self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) | |||||
| self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) | |||||
| self.ce = nn.SoftmaxCrossEntropyWithLogits() | |||||
| self.mean = P.ReduceMean(False) | |||||
| def construct(self, logits, label): | |||||
| one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value) | |||||
| loss_logit = self.ce(logits, one_hot_label) | |||||
| loss_logit = self.mean(loss_logit, 0) | |||||
| return loss_logit | |||||
| @@ -0,0 +1,64 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """learning rate exponential decay generator""" | |||||
| import math | |||||
| import numpy as np | |||||
| def get_lr(lr_init, lr_decay_rate, num_epoch_per_decay, total_epochs, steps_per_epoch, is_stair=False): | |||||
| """ | |||||
| generate learning rate array | |||||
| Args: | |||||
| lr_init(float): init learning rate | |||||
| lr_decay_rate (float): | |||||
| total_epochs(int): total epoch of training | |||||
| steps_per_epoch(int): steps of one epoch | |||||
| is_stair(bool): If `True` decay the learning rate at discrete intervals (default=False) | |||||
| Returns: | |||||
| learning_rate, learning rate numpy array | |||||
| """ | |||||
| lr_each_step = [] | |||||
| total_steps = steps_per_epoch * total_epochs | |||||
| decay_steps = steps_per_epoch * num_epoch_per_decay | |||||
| for i in range(total_steps): | |||||
| p = i/decay_steps | |||||
| if is_stair: | |||||
| p = math.floor(p) | |||||
| lr_each_step.append(lr_init * math.pow(lr_decay_rate, p)) | |||||
| learning_rate = np.array(lr_each_step).astype(np.float32) | |||||
| return learning_rate | |||||
| def get_lr_basic(lr_init, total_epochs, steps_per_epoch, is_stair=False): | |||||
| """ | |||||
| generate basic learning rate array | |||||
| Args: | |||||
| lr_init(float): init learning rate | |||||
| total_epochs(int): total epochs of training | |||||
| steps_per_epoch(int): steps of one epoch | |||||
| is_stair(bool): If `True` decay the learning rate at discrete intervals (default=False) | |||||
| Returns: | |||||
| learning_rate, learning rate numpy array | |||||
| """ | |||||
| lr_each_step = [] | |||||
| total_steps = steps_per_epoch * total_epochs | |||||
| for i in range(total_steps): | |||||
| lr = lr_init - lr_init * (i) / (total_steps) | |||||
| lr_each_step.append(lr) | |||||
| learning_rate = np.array(lr_each_step).astype(np.float32) | |||||
| return learning_rate | |||||
| @@ -0,0 +1,124 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train_imagenet.""" | |||||
| import argparse | |||||
| import os | |||||
| import random | |||||
| import numpy as np | |||||
| from network import ShuffleNetV2 | |||||
| import mindspore.nn as nn | |||||
| from mindspore import context | |||||
| from mindspore import dataset as de | |||||
| from mindspore import ParallelMode | |||||
| from mindspore import Tensor | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | |||||
| from mindspore.nn.optim.momentum import Momentum | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src.config import config_gpu as cfg | |||||
| from src.dataset import create_dataset | |||||
| from src.lr_generator import get_lr_basic | |||||
| random.seed(cfg.random_seed) | |||||
| np.random.seed(cfg.random_seed) | |||||
| de.config.set_seed(cfg.random_seed) | |||||
| if __name__ == '__main__': | |||||
| parser = argparse.ArgumentParser(description='image classification training') | |||||
| parser.add_argument('--dataset_path', type=str, default='/home/data/imagenet_jpeg/train/', help='Dataset path') | |||||
| parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint') | |||||
| parser.add_argument('--is_distributed', action='store_true', default=False, | |||||
| help='distributed training') | |||||
| parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU'), help='run platform') | |||||
| parser.add_argument('--model_size', type=str, default='1.0x', help='ShuffleNetV2 model size parameter') | |||||
| args_opt = parser.parse_args() | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False) | |||||
| if os.getenv('DEVICE_ID', "not_set").isdigit(): | |||||
| context.set_context(device_id=int(os.getenv('DEVICE_ID'))) | |||||
| # init distributed | |||||
| if args_opt.is_distributed: | |||||
| if args_opt.platform == "Ascend": | |||||
| init() | |||||
| else: | |||||
| init("nccl") | |||||
| cfg.rank = get_rank() | |||||
| cfg.group_size = get_group_size() | |||||
| parallel_mode = ParallelMode.DATA_PARALLEL | |||||
| context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=cfg.group_size, | |||||
| parameter_broadcast=True, mirror_mean=True) | |||||
| else: | |||||
| cfg.rank = 0 | |||||
| cfg.group_size = 1 | |||||
| # dataloader | |||||
| dataset = create_dataset(args_opt.dataset_path, True, cfg.rank, cfg.group_size) | |||||
| batches_per_epoch = dataset.get_dataset_size() | |||||
| print("Batches Per Epoch: ", batches_per_epoch) | |||||
| # network | |||||
| net = ShuffleNetV2(n_class=cfg.num_classes, model_size=args_opt.model_size) | |||||
| # loss | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, | |||||
| smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes) | |||||
| # learning rate schedule | |||||
| lr = get_lr_basic(lr_init=cfg.lr_init, total_epochs=cfg.epoch_size, | |||||
| steps_per_epoch=batches_per_epoch, is_stair=True) | |||||
| lr = Tensor(lr) | |||||
| # optimizer | |||||
| decayed_params = [] | |||||
| no_decayed_params = [] | |||||
| for param in net.trainable_params(): | |||||
| if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name: | |||||
| decayed_params.append(param) | |||||
| else: | |||||
| no_decayed_params.append(param) | |||||
| group_params = [{'params': decayed_params, 'weight_decay': cfg.weight_decay}, | |||||
| {'params': no_decayed_params}, | |||||
| {'order_params': net.trainable_params()}] | |||||
| optimizer = Momentum(params=net.trainable_params(), learning_rate=Tensor(lr), momentum=cfg.momentum, | |||||
| weight_decay=cfg.weight_decay) | |||||
| eval_metrics = {'Loss': nn.Loss(), | |||||
| 'Top1-Acc': nn.Top1CategoricalAccuracy(), | |||||
| 'Top5-Acc': nn.Top5CategoricalAccuracy()} | |||||
| if args_opt.resume: | |||||
| ckpt = load_checkpoint(args_opt.resume) | |||||
| load_param_into_net(net, ckpt) | |||||
| model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={'acc'}) | |||||
| print("============== Starting Training ==============") | |||||
| loss_cb = LossMonitor(per_print_times=batches_per_epoch) | |||||
| time_cb = TimeMonitor(data_size=batches_per_epoch) | |||||
| callbacks = [loss_cb, time_cb] | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=cfg.keep_checkpoint_max) | |||||
| ckpoint_cb = ModelCheckpoint(prefix=f"shufflenet-rank{cfg.rank}", directory=cfg.ckpt_path, config=config_ck) | |||||
| if args_opt.is_distributed & cfg.is_save_on_master: | |||||
| if cfg.rank == 0: | |||||
| callbacks.append(ckpoint_cb) | |||||
| model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True) | |||||
| else: | |||||
| callbacks.append(ckpoint_cb) | |||||
| model.train(cfg.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=True) | |||||
| print("train success") | |||||