From: @Gogery
Reviewed-by: @guoqi1024
Signed-off-by: @guoqi1024
@@ -0,0 +1,234 @@ README.md
# Contents

- [Contents](#contents)
- [Xception Description](#xception-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision(Ascend)](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
    - [Script and sample code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training process](#training-process)
        - [Usage](#usage)
        - [Launch](#launch)
        - [Result](#result)
    - [Eval process](#eval-process)
        - [Usage](#usage-1)
        - [Launch](#launch-1)
        - [Result](#result-1)
- [Model description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Inference Performance](#inference-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [Xception Description](#contents)

Xception, proposed by Google, is an "extreme" version of Inception: it replaces the Inception modules with modified depthwise separable convolutions and outperforms Inception-v3. The paper was published in 2017.

[Paper](https://arxiv.org/pdf/1610.02357v3.pdf) François Chollet. "Xception: Deep Learning with Depthwise Separable Convolutions." 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), IEEE, 2017.
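A depthwise separable convolution factorizes a standard convolution into a depthwise convolution (one filter per input channel) followed by a 1x1 pointwise convolution, which cuts the parameter count roughly by a factor of the kernel area. A minimal sketch of that comparison (illustrative only, not part of this repository):

```python
# Hedged sketch: parameter counts (ignoring biases) of a standard 3x3
# convolution vs. the depthwise separable factorization used by Xception.
def conv_params(c_in, c_out, k):
    return k * k * c_in * c_out            # standard convolution

def separable_params(c_in, c_out, k):
    depthwise = k * k * c_in               # one k x k filter per input channel
    pointwise = c_in * c_out               # 1x1 convolution mixes channels
    return depthwise + pointwise

c_in, c_out, k = 728, 728, 3               # a middle-flow block of Xception
print(conv_params(c_in, c_out, k))         # 4769856
print(separable_params(c_in, c_out, k))    # 536536, roughly 9x fewer
```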
# [Model architecture](#contents)

The overall network architecture of Xception is shown below:

[Link](https://arxiv.org/pdf/1610.02357v3.pdf)
# [Dataset](#contents)

The dataset used is ImageNet, as in the paper.

- Dataset size: 125G, 1250k color images in 1000 classes
    - Train: 120G, 1200k images
    - Test: 5G, 50k images
- Data format: RGB images
    - Note: Data will be processed in src/dataset.py
# [Features](#contents)

## [Mixed Precision(Ascend)](#contents)

The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the deep learning neural network training process by using both single-precision and half-precision data formats, while maintaining the accuracy achieved by single-precision training. Mixed precision training speeds up computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and searching for 'reduce precision'.
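In this repository, mixed precision is wired up in `train.py`: the backbone is cast to FP16 and the `Model` is built with `amp_level='O3'` while batch norm stays in FP32, together with a fixed loss scale. A condensed sketch of that setup (`net`, `loss` and `opt` are assumed to be defined as in `train.py`):

```python
# Condensed from train.py: FP16 backbone, O3 auto mixed precision,
# FP32 batch norm, and a fixed loss scale of 1024 (see src/config.py).
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import dtype as mstype

net.to_float(mstype.float16)                    # run the backbone in FP16
loss_scale = FixedLossScaleManager(1024, drop_overflow_update=False)
model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
              amp_level='O3', keep_batchnorm_fp32=True)
```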
# [Environment Requirements](#contents)

- Hardware(Ascend)
    - Prepare hardware environment with Ascend. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Script description](#contents)

## [Script and sample code](#contents)

```shell
.
└─Xception
  ├─README.md
  ├─scripts
    ├─run_standalone_train.sh   # launch standalone training with Ascend platform (1p)
    ├─run_distribute_train.sh   # launch distributed training with Ascend platform (8p)
    └─run_eval.sh               # launch evaluation with Ascend platform
  ├─src
    ├─config.py                 # parameter configuration
    ├─dataset.py                # data preprocessing
    ├─Xception.py               # network definition
    ├─loss.py                   # customized CrossEntropySmooth loss function
    └─lr_generator.py           # learning rate generator
  ├─train.py                    # train net
  └─eval.py                     # eval net
```
## [Script Parameters](#contents)

```python
Major parameters in train.py and config.py are:
'class_num': 1000                  # dataset class numbers
'batch_size': 128                  # input batch size
'loss_scale': 1024                 # loss scale
'momentum': 0.9                    # momentum
'weight_decay': 1e-4               # weight decay
'epoch_size': 250                  # total epoch numbers
'save_checkpoint': True            # whether to save checkpoints
'save_checkpoint_epochs': 1        # save a checkpoint every N epochs
'keep_checkpoint_max': 5           # max number of checkpoints to keep
'save_checkpoint_path': "./"       # path to save checkpoints
'warmup_epochs': 1                 # warmup epoch numbers
'lr_decay_mode': "liner"           # lr decay mode (linear decay)
'use_label_smooth': True           # whether to use label smoothing
'finish_epoch': 0                  # number of epochs already finished (for resuming)
'label_smooth_factor': 0.1         # label smoothing factor
'lr_init': 0.00004                 # initial learning rate
'lr_max': 0.4                      # max bound of learning rate
'lr_end': 0.00004                  # min bound of learning rate
'weight_init': 'xavier_uniform'    # weight initialization mode
```
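These values live in the `config` EasyDict defined in `src/config.py`, so they can be edited there or overridden in code before training starts, the same way `train.py` itself overrides `label_smooth_factor` when label smoothing is disabled. A small sketch (the override values are arbitrary examples):

```python
# Hedged sketch: config is an EasyDict, so entries can be reassigned as
# attributes before the training script reads them. Values are examples only.
from src.config import config

config.batch_size = 64      # e.g. to fit a smaller memory budget
config.epoch_size = 100     # shorter training run
print(config.batch_size, config.lr_max)   # 64 0.4
```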
## [Training process](#contents)

### Usage

You can start training using python or shell scripts. The usage of the shell scripts is as follows:

- Ascend:

```shell
# distribute training example(8p)
sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
# standalone training
sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
```

> Notes: RANK_TABLE_FILE can refer to [Link](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/distributed_training_ascend.html), and the device_ip can be got as [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
> The distributed script binds each rank to a range of processor cores based on `device_num` and the total number of processor cores. If you do not want this, remove the `taskset` operations in `scripts/run_distribute_train.sh`.
### Launch

```shell
# training example
  python:
      Ascend:
          python train.py --device_target Ascend --dataset_path /dataset/train
  shell:
      Ascend:
          # distribute training example(8p)
          sh scripts/run_distribute_train.sh RANK_TABLE_FILE DATA_PATH
          # standalone training
          sh scripts/run_standalone_train.sh DEVICE_ID DATA_PATH
```

### Result

Training results will be stored in the example path. Checkpoints will be stored at `./model_0` by default, and the training log will be redirected to `log.txt`, which looks like the following.

```shell
epoch: [  0/250], step:[ 1250/ 1251], loss:[4.761/5.613], time:[529.305], lr:[0.400]
epoch time: 1128662.862, per step time: 902.209, avg loss: 5.609
epoch: [  1/250], step:[ 1250/ 1251], loss:[4.164/4.318], time:[503.708], lr:[0.398]
epoch time: 889163.081, per step time: 710.762, avg loss: 4.312
```
## [Eval process](#contents)

### Usage

You can start evaluation using python or shell scripts. The usage of the shell scripts is as follows:

```shell
sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```

### Launch

```shell
# eval example
  python:
      Ascend: python eval.py --device_target Ascend --checkpoint_path PATH_CHECKPOINT --dataset_path DATA_DIR
  shell:
      Ascend: sh scripts/run_eval.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```

> The checkpoint can be produced during the training process.

### Result

Evaluation results will be stored in the example path. You can find results like the following in `eval.log`.

```shell
result: {'Loss': 1.7797744848789312, 'Top_1_Acc': 0.7985777243589743, 'Top_5_Acc': 0.9485777243589744}
```
# [Model description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | Ascend                                         |
| -------------------------- | ---------------------------------------------- |
| Model Version              | Xception                                       |
| Resource                   | HUAWEI CLOUD ModelArts                         |
| Uploaded Date              | 11/15/2020                                     |
| MindSpore Version          | 1.0.0                                          |
| Dataset                    | 1200k images                                   |
| Batch_size                 | 128                                            |
| Training Parameters        | see src/config.py                              |
| Optimizer                  | Momentum                                       |
| Loss Function              | CrossEntropySmooth                             |
| Loss                       | 1.78                                           |
| Accuracy (8p)              | Top1[79.9%] Top5[94.9%]                        |
| Total time (8p)            | 63h                                            |
| Params (M)                 | 180                                            |
| Scripts                    | [Xception script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/Xception) |
### Inference Performance

| Parameters          | Ascend                      |
| ------------------- | --------------------------- |
| Model Version       | Xception                    |
| Resource            | HUAWEI CLOUD ModelArts      |
| Uploaded Date       | 11/15/2020                  |
| MindSpore Version   | 1.0.0                       |
| Dataset             | 50k images                  |
| Batch_size          | 128                         |
| Accuracy            | Top1[79.9%] Top5[94.9%]     |
| Total time          | 3 minutes                   |
# [Description of Random Situation](#contents)

We set a global random seed with `set_seed(1)` in `train.py` and `eval.py`; dataset shuffling in `src/dataset.py` also depends on this seed.

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@@ -0,0 +1,63 @@ eval.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval Xception."""
import argparse

from mindspore import context, nn
from mindspore.train.model import Model
from mindspore.common import set_seed
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.Xception import xception
from src.config import config
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth

set_seed(1)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Image classification')
    parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
    parser.add_argument('--device_id', type=int, default=0, help='Device id')
    parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
    parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
    args_opt = parser.parse_args()

    context.set_context(device_id=args_opt.device_id)
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)

    # create dataset
    dataset = create_dataset(args_opt.dataset_path, do_train=False, batch_size=config.batch_size, device_num=1, rank=0)
    step_size = dataset.get_dataset_size()

    # define net
    net = xception(class_num=config.class_num)

    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)

    # define loss
    loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

    # define model
    eval_metrics = {'Loss': nn.Loss(),
                    'Top_1_Acc': nn.Top1CategoricalAccuracy(),
                    'Top_5_Acc': nn.Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)

    # eval model
    res = model.eval(dataset, dataset_sink_mode=False)
    print("result:", res, "ckpt=", args_opt.checkpoint_path)
@@ -0,0 +1,50 @@ scripts/run_distribute_train.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$2
export RANK_TABLE_FILE=$1
export RANK_SIZE=8

# bind each rank to a contiguous range of processor cores
cores=`cat /proc/cpuinfo | grep "processor" | wc -l`
echo "the number of logical cores:" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
core_gap=`expr $avg_core_per_rank \- 1`
echo "avg_core_per_rank" $avg_core_per_rank
echo "core_gap" $core_gap
for((i=0;i<RANK_SIZE;i++))
do
    start=`expr $i \* $avg_core_per_rank`
    export DEVICE_ID=$i
    export RANK_ID=$i
    export DEPLOY_MODE=0
    export GE_USE_STATIC_MEMORY=1
    end=`expr $start \+ $core_gap`
    cmdopt=$start"-"$end

    rm -rf train_parallel$i
    mkdir ./train_parallel$i
    cp *.py ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    taskset -c $cmdopt python ../train.py \
        --is_distributed \
        --device_target=Ascend \
        --dataset_path=$DATA_DIR > log.txt 2>&1 &
    cd ../
done
@@ -0,0 +1,25 @@ scripts/run_eval.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3

python ./eval.py \
    --device_target=Ascend \
    --device_id=$DEVICE_ID \
    --checkpoint_path=$PATH_CHECKPOINT \
    --dataset_path=$DATA_DIR > eval.log 2>&1 &
@@ -0,0 +1,22 @@ scripts/run_standalone_train.sh
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2

python ./train.py \
    --device_target=Ascend \
    --dataset_path=$DATA_DIR > log.txt 2>&1 &
@@ -0,0 +1,186 @@ src/Xception.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Xception."""
import mindspore.nn as nn
import mindspore.ops.operations as P

from src.config import config


class SeparableConv2d(nn.Cell):
    """Depthwise separable convolution: a depthwise conv followed by a 1x1 pointwise conv."""
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super(SeparableConv2d, self).__init__()
        # group=in_channels makes this a depthwise convolution
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, group=in_channels, pad_mode='pad',
                               padding=padding, weight_init=config.weight_init)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, pad_mode='valid',
                                   weight_init=config.weight_init)

    def construct(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Cell):
    """Xception block: repeated ReLU + SeparableConv2d + BatchNorm with a residual skip connection."""
    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()
        # 1x1 convolution on the skip path when the shape changes
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, pad_mode='valid', has_bias=False,
                                  weight_init=config.weight_init)
            self.skipbn = nn.BatchNorm2d(out_filters, momentum=0.9)
        else:
            self.skip = None

        self.relu = nn.ReLU()
        rep = []
        filters = in_filters
        if grow_first:
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))
            filters = out_filters

        for _ in range(reps - 1):
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(filters, filters, kernel_size=3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(filters, momentum=0.9))

        if not grow_first:
            rep.append(nn.ReLU())
            rep.append(SeparableConv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            rep.append(nn.BatchNorm2d(out_filters, momentum=0.9))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU()

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, pad_mode="same"))
        self.rep = nn.SequentialCell(*rep)
        self.add = P.TensorAdd()

    def construct(self, inp):
        x = self.rep(inp)
        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp
        x = self.add(x, skip)
        return x
class Xception(nn.Cell):
    """
    Xception optimized for the ImageNet dataset, as specified in
    https://arxiv.org/abs/1610.02357
    """
    def __init__(self, num_classes=1000):
        """ Constructor
        Args:
            num_classes: number of classes.
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # stem: two ordinary convolutions
        self.conv1 = nn.Conv2d(3, 32, 3, 2, pad_mode='valid', weight_init=config.weight_init)
        self.bn1 = nn.BatchNorm2d(32, momentum=0.9)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, 3, pad_mode='valid', weight_init=config.weight_init)
        self.bn2 = nn.BatchNorm2d(64, momentum=0.9)

        # Entry flow
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Middle flow: eight identical residual blocks
        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Exit flow
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536, momentum=0.9)
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048, momentum=0.9)
        self.avg_pool = nn.AvgPool2d(10)  # the feature map is 10x10 for 299x299 inputs
        self.dropout = nn.Dropout()
        self.fc = nn.Dense(2048, num_classes)

    def construct(self, x):
        shape = P.Shape()
        reshape = P.Reshape()
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = self.avg_pool(x)
        x = self.dropout(x)
        x = reshape(x, (shape(x)[0], -1))
        x = self.fc(x)
        return x


def xception(class_num=1000):
    """
    Get Xception neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of Xception neural network.

    Examples:
        >>> net = xception(1000)
    """
    return Xception(class_num)
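
# Sanity-check sketch (hedged, not part of the original file): build the network
# and verify the output shape on random input; assumes a configured MindSpore
# context with a supported device target.
#
#     import numpy as np
#     from mindspore import Tensor
#     net = xception(class_num=1000)
#     out = net(Tensor(np.random.randn(1, 3, 299, 299).astype(np.float32)))
#     print(out.shape)  # expected: (1, 1000) for 299x299 inputs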
@@ -0,0 +1,41 @@ src/config.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed

# config for Xception, imagenet2012.
config = ed({
    "class_num": 1000,
    "batch_size": 128,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 250,
    "save_checkpoint": True,
    "save_checkpoint_epochs": 1,
    "keep_checkpoint_max": 5,
    "save_checkpoint_path": "./",
    "warmup_epochs": 1,
    "lr_decay_mode": "liner",  # falls through to the default linear decay in get_lr
    "use_label_smooth": True,
    "finish_epoch": 0,
    "label_smooth_factor": 0.1,
    "lr_init": 0.00004,
    "lr_max": 0.4,
    "lr_end": 0.00004,
    "weight_init": 'xavier_uniform'
})
@@ -0,0 +1,66 @@ src/dataset.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C


def create_dataset(dataset_path, do_train, batch_size=16, device_num=1, rank=0):
    """
    Create a train or eval dataset.

    Args:
        dataset_path (string): the path of dataset.
        do_train (bool): whether dataset is used for train or eval.
        batch_size (int): the batch size of dataset. Default: 16.
        device_num (int): number of shards that the dataset should be divided into. Default: 1.
        rank (int): the shard ID within num_shards. Default: 0.

    Returns:
        dataset
    """
    if device_num == 1:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True)
    else:
        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                   num_shards=device_num, shard_id=rank)

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(299),
            C.RandomHorizontalFlip(prob=0.5),
            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
        ]
    else:
        trans = [
            C.Decode(),
            C.Resize(320),
            C.CenterCrop(299)
        ]
    trans += [
        C.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
        C.HWC2CHW(),
        C2.TypeCast(mstype.float32)
    ]
    type_cast_op = C2.TypeCast(mstype.int32)
    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=8)
    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=8)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
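
# Usage sketch (hedged, not part of the original file): build the eval pipeline
# and inspect one batch; assumes an ImageNet-style class-per-folder layout.
#
#     ds = create_dataset("/path/to/imagenet/val", do_train=False, batch_size=128)
#     for image, label in ds.create_tuple_iterator():
#         print(image.shape, label.shape)   # (128, 3, 299, 299) (128,)
#         break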
@@ -0,0 +1,38 @@ src/loss.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P


class CrossEntropySmooth(_Loss):
    """Cross entropy loss with label smoothing."""
    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.sparse = sparse
        # smoothed one-hot targets: the true class gets 1 - smooth_factor and
        # the remaining probability mass is spread over the other classes
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss
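
# Worked example (hedged, not part of the original file): with the defaults from
# src/config.py, smooth_factor=0.1 and num_classes=1000, the smoothed targets are
#     on_value  = 1.0 - 0.1        = 0.9
#     off_value = 0.1 / (1000 - 1) ≈ 0.0001
# so every wrong class keeps a small nonzero target probability, which penalizes
# over-confident logits.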
@@ -0,0 +1,87 @@ src/lr_generator.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math
import numpy as np


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
    """
    Generate learning rate array.

    Args:
        lr_init (float): init learning rate
        lr_end (float): end learning rate
        lr_max (float): max learning rate
        warmup_epochs (int): number of warmup epochs
        total_epochs (int): total epochs of training
        steps_per_epoch (int): steps of one epoch
        lr_decay_mode (string): learning rate decay mode, including 'steps', 'poly', 'cosine' or default (linear)

    Returns:
        np.array, learning rate array
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if lr_decay_mode == 'steps':
        # piecewise constant decay at 30%, 60% and 80% of training
        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
        for i in range(total_steps):
            if i < decay_epoch_index[0]:
                lr = lr_max
            elif i < decay_epoch_index[1]:
                lr = lr_max * 0.1
            elif i < decay_epoch_index[2]:
                lr = lr_max * 0.01
            else:
                lr = lr_max * 0.001
            lr_each_step.append(lr)
    elif lr_decay_mode == 'poly':
        # linear warmup followed by quadratic polynomial decay
        if warmup_steps != 0:
            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
        else:
            inc_each_step = 0
        for i in range(total_steps):
            if i < warmup_steps:
                lr = float(lr_init) + inc_each_step * float(i)
            else:
                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
                lr = float(lr_max) * base * base
                if lr < 0.0:
                    lr = 0.0
            lr_each_step.append(lr)
    elif lr_decay_mode == 'cosine':
        # linear warmup followed by a linearly damped cosine decay
        decay_steps = total_steps - warmup_steps
        for i in range(total_steps):
            if i < warmup_steps:
                lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
                lr = float(lr_init) + lr_inc * (i + 1)
            else:
                linear_decay = (total_steps - i) / decay_steps
                cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
                decayed = linear_decay * cosine_decay + 0.00001
                lr = lr_max * decayed
            lr_each_step.append(lr)
    else:
        # default: linear warmup followed by linear decay from lr_max to lr_end
        for i in range(total_steps):
            if i < warmup_steps:
                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
            else:
                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
            lr_each_step.append(lr)
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    return lr_each_step
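
# Usage sketch (hedged, not part of the original file): generate the schedule with
# the defaults from src/config.py. With steps_per_epoch=1251 (as in the training
# log in the README), warmup covers the first 1251 steps and the rate then decays
# linearly toward lr_end.
#
#     lr = get_lr(lr_init=0.00004, lr_end=0.00004, lr_max=0.4, warmup_epochs=1,
#                 total_epochs=250, steps_per_epoch=1251, lr_decay_mode='liner')
#     print(lr.shape)          # (312750,)
#     print(lr[0], lr[1251])   # ~0.00004 at step 0, 0.4 right after warmup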
@@ -0,0 +1,173 @@ train.py
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train Xception."""
import os
import time
import argparse

import numpy as np
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import dtype as mstype
from mindspore.common import set_seed

from src.lr_generator import get_lr
from src.Xception import xception
from src.config import config
from src.dataset import create_dataset
from src.loss import CrossEntropySmooth

set_seed(1)


class Monitor(Callback):
    """
    Monitor loss and time.

    Args:
        lr_init (numpy array): train lr

    Returns:
        None

    Examples:
        >>> Monitor(lr_init=Tensor([0.05]*100).asnumpy())
    """
    def __init__(self, lr_init=None):
        super(Monitor, self).__init__()
        self.lr_init = lr_init
        self.lr_init_len = len(lr_init)

    def epoch_begin(self, run_context):
        self.losses = []
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / cb_params.batch_num
        print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}".format(epoch_mseconds,
                                                                                      per_step_mseconds,
                                                                                      np.mean(self.losses)))

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        step_mseconds = (time.time() - self.step_time) * 1000
        step_loss = cb_params.net_outputs

        if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
            step_loss = step_loss[0]
        if isinstance(step_loss, Tensor):
            step_loss = np.mean(step_loss.asnumpy())

        self.losses.append(step_loss)
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num

        print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], loss:[{:5.3f}/{:5.3f}], time:[{:5.3f}], lr:[{:5.3f}]".format(
            cb_params.cur_epoch_num - 1 + config.finish_epoch, cb_params.epoch_num + config.finish_epoch,
            cur_step_in_epoch, cb_params.batch_num, step_loss,
            np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='image classification training')
    parser.add_argument('--is_distributed', action='store_true', default=False, help='distributed training')
    parser.add_argument('--device_target', type=str, default='Ascend', help='run platform')
    parser.add_argument('--dataset_path', type=str, default=None, help='dataset path')
    parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
    args_opt = parser.parse_args()

    # init distributed: init() must run before get_rank()/get_group_size()
    if args_opt.is_distributed:
        if os.getenv('DEVICE_ID', "not_set").isdigit():
            context.set_context(device_id=int(os.getenv('DEVICE_ID')))
        init()
        rank = get_rank()
        group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
    else:
        rank = 0
        group_size = 1
        context.set_context(device_id=0)

    if args_opt.device_target == "Ascend":
        # train on Ascend
        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)

        # define network
        net = xception(class_num=config.class_num)
        net.to_float(mstype.float16)

        # define loss
        if not config.use_label_smooth:
            config.label_smooth_factor = 0.0
        loss = CrossEntropySmooth(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

        # define dataset
        dataset = create_dataset(args_opt.dataset_path, do_train=True, batch_size=config.batch_size,
                                 device_num=group_size, rank=rank)
        step_size = dataset.get_dataset_size()

        # resume
        if args_opt.resume:
            ckpt = load_checkpoint(args_opt.resume)
            load_param_into_net(net, ckpt)

        # get learning rate
        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
        lr = Tensor(get_lr(lr_init=config.lr_init,
                           lr_end=config.lr_end,
                           lr_max=config.lr_max,
                           warmup_epochs=config.warmup_epochs,
                           total_epochs=config.epoch_size,
                           steps_per_epoch=step_size,
                           lr_decay_mode=config.lr_decay_mode))

        # define optimizer
        opt = Momentum(net.trainable_params(), lr, config.momentum, config.weight_decay, config.loss_scale)

        # define model
        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'},
                      amp_level='O3', keep_batchnorm_fp32=True)

        # define callbacks; in the distributed case only rank 0 saves checkpoints
        cb = [Monitor(lr_init=lr.asnumpy())]
        if config.save_checkpoint and (not args_opt.is_distributed or rank == 0):
            save_ckpt_path = os.path.join(config.save_checkpoint_path, 'model_' + str(rank) + '/')
            config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
                                         keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(f"Xception-rank{rank}", directory=save_ckpt_path, config=config_ck)
            cb += [ckpt_cb]

        # begin training
        model.train(config.epoch_size - config.finish_epoch, dataset, callbacks=cb, dataset_sink_mode=False)
        print("train success")
    else:
        raise ValueError("Unsupported device_target.")