diff --git a/model_zoo/official/cv/inceptionv4/README.md b/model_zoo/official/cv/inceptionv4/README.md index aed4fb7cc6..eac92df58d 100644 --- a/model_zoo/official/cv/inceptionv4/README.md +++ b/model_zoo/official/cv/inceptionv4/README.md @@ -1,4 +1,4 @@ -# InceptionV4 for Ascend +# InceptionV4 for Ascend/GPU - [InceptionV4 Description](#InceptionV4-description) - [Model Architecture](#model-architecture) @@ -12,7 +12,7 @@ - [Evaluation Process](#evaluation-process) - [Evaluation](#evaluation) - [Model Description](#model-description) - - [Performance](#performance) + - [Performance](#performance) - [Training Performance](#evaluation-performance) - [Inference Performance](#evaluation-performance) - [Description of Random Situation](#description-of-random-situation) @@ -50,8 +50,9 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil # [Environment Requirements](#contents) -- Hardware(Ascend) - - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. +- Hardware(Ascend/GPU) + - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. + - or prepare GPU processor. 
- Framework - [MindSpore](https://www.mindspore.cn/install/en) - For more information, please check the resources below: @@ -67,6 +68,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil └─Inception-v4 ├─README.md ├─scripts + ├─run_distribute_train_gpu.sh # launch distributed training with gpu platform(8p) + ├─run_eval_gpu.sh # launch evaluating with gpu platform ├─run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p) ├─run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p) └─run_eval_ascend.sh # launch evaluating with ascend platform @@ -125,6 +128,13 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR > > This is processor cores binding operation regarding the `device_num` and total processor numbers. If you are not expect to do it, remove the operations `taskset` in `scripts/run_distribute_train.sh` +- GPU: + +```bash +# distribute training example(8p) +sh scripts/run_distribute_train_gpu.sh DATA_PATH +``` + ### Launch ```bash @@ -135,11 +145,16 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR sh scripts/run_distribute_train_ascend.sh RANK_TABLE_FILE DATA_PATH DATA_DIR # standalone training sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR + GPU: + # distribute training example(8p) + sh scripts/run_distribute_train_gpu.sh DATA_PATH ``` ### Result -Training result will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and training log will be redirected to `./log.txt` like following. +Training result will be stored in the example path. Checkpoints will be stored at `ckpt_path` by default, and training log will be redirected to `./log.txt` like followings. 
+ +- Ascend ```python epoch: 1 step: 1251, loss is 5.4833196 @@ -150,6 +165,17 @@ epoch: 3 step: 1251, loss is 3.6242008 Epoch time: 288507.506, per step time: 230.622 ``` +- GPU + +```python +epoch: 1 step: 1251, loss is 6.49775 +Epoch time: 1487493.604, per step time: 1189.044 +epoch: 2 step: 1251, loss is 5.6884665 +Epoch time: 1421838.433, per step time: 1136.561 +epoch: 3 step: 1251, loss is 5.5168786 +Epoch time: 1423009.501, per step time: 1137.498 +``` + ## [Eval process](#contents) ### Usage @@ -162,6 +188,12 @@ You can start training using python or shell scripts. The usage of shell scripts sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH ``` +- GPU + +```bash + sh scripts/run_eval_gpu.sh DATA_DIR CHECKPOINT_PATH +``` + ### Launch ```bash @@ -169,57 +201,67 @@ You can start training using python or shell scripts. The usage of shell scripts shell: Ascend: sh scripts/run_eval_ascend.sh DEVICE_ID DATA_DIR CHECKPOINT_PATH + GPU: + sh scripts/run_eval_gpu.sh DATA_DIR CHECKPOINT_PATH ``` > checkpoint can be produced in training process. ### Result -Evaluation result will be stored in the example path, you can find result like the following in `eval.log`. +Evaluation result will be stored in the example path, you can find result like the followings in `eval.log`. 
+ +- Ascend ```python metric: {'Loss': 0.9849, 'Top1-Acc':0.7985, 'Top5-Acc':0.9460} ``` +- GPU(8p) + +```python +metric: {'Loss': 0.8144, 'Top1-Acc': 0.8009, 'Top5-Acc': 0.9457} +``` + # [Model description](#contents) ## [Performance](#contents) ### Training Performance -| Parameters | Ascend | -| -------------------------- | ------------------------------------------------------------ | -| Model Version | InceptionV4 | -| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G | -| uploaded Date | 11/04/2020 | -| MindSpore Version | 1.0.0 | -| Dataset | 1200k images | -| Batch_size | 128 | -| Training Parameters | src/config.py | -| Optimizer | RMSProp | -| Loss Function | SoftmaxCrossEntropyWithLogits | -| Outputs | probability | -| Loss | 0.98486 | -| Accuracy (8p) | ACC1[79.85%] ACC5[94.60%] | -| Total time (8p) | 20h | -| Params (M) | 153M | -| Checkpoint for Fine tuning | 2135M | -| Scripts | [inceptionv4 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4) | +| Parameters | Ascend | GPU | +| -------------------------- | --------------------------------------------- | -------------------------------- | +| Model Version | InceptionV4 | InceptionV4 | +| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G | NV SMX2 V100-32G | +| uploaded Date | 11/04/2020 | 03/05/2021 | +| MindSpore Version | 1.0.0 | 1.0.0 | +| Dataset | 1200k images | 1200K images | +| Batch_size | 128 | 128 | +| Training Parameters | src/config.py (Ascend) | src/config.py (GPU) | +| Optimizer | RMSProp | RMSProp | +| Loss Function | SoftmaxCrossEntropyWithLogits | SoftmaxCrossEntropyWithLogits | +| Outputs | probability | probability | +| Loss | 0.98486 | 0.8144 | +| Accuracy (8p) | ACC1[79.85%] ACC5[94.60%] | ACC1[80.09%] ACC5[94.57%] | +| Total time (8p) | 20h | 95h | +| Params (M) | 153M | 153M | +| Checkpoint for Fine tuning | 2135M | 489M | +| Scripts | [inceptionv4 
script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4) | [inceptionv4 script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/inceptionv4) | #### Inference Performance -| Parameters | Ascend | -| ------------------- | --------------------------- | -| Model Version | InceptionV4 | -| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G | -| Uploaded Date | 11/04/2020 | -| MindSpore Version | 1.0.0 | -| Dataset | 50k images | -| Batch_size | 128 | -| Outputs | probability | -| Accuracy | ACC1[79.85%] ACC5[94.60%] | -| Total time | 2mins | -| Model for inference | 2135M (.ckpt file) | +| Parameters | Ascend | GPU | +| ------------------- | --------------------------------------------- | ---------------------------------- | +| Model Version | InceptionV4 | InceptionV4 | +| Resource | Ascend 910, cpu:2.60GHz 192cores, memory:755G | NV SMX2 V100-32G | +| Uploaded Date | 11/04/2020 | 03/05/2021 | +| MindSpore Version | 1.0.0 | 1.0.0 | +| Dataset | 50k images | 50K images | +| Batch_size | 128 | 128 | +| Outputs | probability | probability | +| Accuracy | ACC1[79.85%] ACC5[94.60%] | ACC1[80.09%] ACC5[94.57%] | +| Total time | 2mins | 2mins | +| Model for inference | 2135M (.ckpt file) | 489M (.ckpt file) | #### Training performance results @@ -229,7 +271,11 @@ metric: {'Loss': 0.9849, 'Top1-Acc':0.7985, 'Top5-Acc':0.9460} | **Ascend** | train performance | | :--------: | :---------------: | -| 8p | 4430 img/s | +| 8p | 4430 img/s | + +| **GPU** | train performance | +| :--------: | :---------------: | +| 8p | 906 img/s | # [Description of Random Situation](#contents) @@ -237,4 +283,4 @@ In dataset.py, we set the seed inside “create_dataset" function. We also use r # [ModelZoo Homepage](#contents) -Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). 
\ No newline at end of file +Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/cv/inceptionv4/eval.py b/model_zoo/official/cv/inceptionv4/eval.py index 6c356742a3..b02059f86d 100644 --- a/model_zoo/official/cv/inceptionv4/eval.py +++ b/model_zoo/official/cv/inceptionv4/eval.py @@ -24,7 +24,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from src.dataset import create_dataset from src.inceptionv4 import Inceptionv4 -from src.config import config_ascend as config +from src.config import config def parse_args(): '''parse_args''' @@ -39,7 +39,7 @@ if __name__ == '__main__': args = parse_args() if args.platform == 'Ascend': - device_id = int(os.getenv('DEVICE_ID')) + device_id = int(os.getenv('DEVICE_ID', '0')) context.set_context(device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target=args.platform) diff --git a/model_zoo/official/cv/inceptionv4/export.py b/model_zoo/official/cv/inceptionv4/export.py index 16dcd03f5f..0131287e50 100644 --- a/model_zoo/official/cv/inceptionv4/export.py +++ b/model_zoo/official/cv/inceptionv4/export.py @@ -20,7 +20,7 @@ import mindspore as ms from mindspore import Tensor from mindspore.train.serialization import load_checkpoint, load_param_into_net, export, context -from src.config import config_ascend as config +from src.config import config from src.inceptionv4 import Inceptionv4 parser = argparse.ArgumentParser(description='inceptionv4 export') diff --git a/model_zoo/official/cv/inceptionv4/scripts/run_distribute_train_gpu.sh b/model_zoo/official/cv/inceptionv4/scripts/run_distribute_train_gpu.sh new file mode 100644 index 0000000000..d24d716b39 --- /dev/null +++ b/model_zoo/official/cv/inceptionv4/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except 
in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +rm -rf device +mkdir device +cp ./*.py ./device +cp -r ./src ./device +cd ./device + +DATA_DIR=$1 + +export DEVICE_ID=0 +export RANK_SIZE=8 + +echo "start training" + +mpirun -n $RANK_SIZE --allow-run-as-root python train.py --dataset_path=$DATA_DIR --platform='GPU' > train.log 2>&1 & diff --git a/model_zoo/official/cv/inceptionv4/scripts/run_eval_gpu.sh b/model_zoo/official/cv/inceptionv4/scripts/run_eval_gpu.sh new file mode 100644 index 0000000000..115e5e8f22 --- /dev/null +++ b/model_zoo/official/cv/inceptionv4/scripts/run_eval_gpu.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +rm -rf evaluation +mkdir evaluation +cp ./*.py ./evaluation +cp -r ./src ./evaluation +cd ./evaluation + +export DEVICE_ID=0 +export RANK_SIZE=1 + +DATA_DIR=$1 +CKPT_DIR=$2 + +echo "start evaluation" + +python eval.py --dataset_path=$DATA_DIR --checkpoint_path=$CKPT_DIR --platform='GPU' > eval.log 2>&1 & diff --git a/model_zoo/official/cv/inceptionv4/scripts/run_standalone_train_ascend.sh b/model_zoo/official/cv/inceptionv4/scripts/run_standalone_train_ascend.sh index 9a00d4e07b..326f2a7d5f 100644 --- a/model_zoo/official/cv/inceptionv4/scripts/run_standalone_train_ascend.sh +++ b/model_zoo/official/cv/inceptionv4/scripts/run_standalone_train_ascend.sh @@ -26,4 +26,4 @@ env > env.log python -u ../train.py \ --device_id=$1 \ --dataset_path=$DATA_DIR > log.txt 2>&1 & -cd ../ \ No newline at end of file +cd ../ diff --git a/model_zoo/official/cv/inceptionv4/src/config.py b/model_zoo/official/cv/inceptionv4/src/config.py index 9ece7a578e..f1882e97df 100644 --- a/model_zoo/official/cv/inceptionv4/src/config.py +++ b/model_zoo/official/cv/inceptionv4/src/config.py @@ -17,7 +17,7 @@ network config setting, will be used in main.py """ from easydict import EasyDict as edict -config_ascend = edict({ +config = edict({ 'is_save_on_master': False, 'batch_size': 128, diff --git a/model_zoo/official/cv/inceptionv4/src/dataset.py b/model_zoo/official/cv/inceptionv4/src/dataset.py index eb03c6458c..b31b93acb1 100644 --- a/model_zoo/official/cv/inceptionv4/src/dataset.py +++ b/model_zoo/official/cv/inceptionv4/src/dataset.py @@ -18,14 +18,14 @@ import mindspore.common.dtype as mstype import mindspore.dataset as de import mindspore.dataset.vision.c_transforms as C import mindspore.dataset.transforms.c_transforms as C2 -from src.config import config_ascend as config +from src.config import config -device_id = int(os.getenv('DEVICE_ID')) -device_num = int(os.getenv('RANK_SIZE')) +device_id = 
int(os.getenv('DEVICE_ID', '0')) +device_num = int(os.getenv('RANK_SIZE', '1')) -def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): +def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shard_id=0): """ Create a train or eval dataset. @@ -45,7 +45,7 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle) else: ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, - shuffle=do_shuffle, num_shards=device_num, shard_id=device_id) + shuffle=do_shuffle, num_shards=device_num, shard_id=shard_id) image_length = 299 if do_train: diff --git a/model_zoo/official/cv/inceptionv4/src/inceptionv4.py b/model_zoo/official/cv/inceptionv4/src/inceptionv4.py index 15e5d0bee8..70e612d1ee 100644 --- a/model_zoo/official/cv/inceptionv4/src/inceptionv4.py +++ b/model_zoo/official/cv/inceptionv4/src/inceptionv4.py @@ -286,7 +286,6 @@ class Inceptionv4(nn.Cell): self.avgpool = P.ReduceMean(keep_dims=False) self.softmax = nn.DenseBnAct( 1536, classes, weight_init="XavierUniform", has_bias=True, has_bn=True, activation="logsoftmax") - if is_train: self.dropout = nn.Dropout(0.20) else: diff --git a/model_zoo/official/cv/inceptionv4/train.py b/model_zoo/official/cv/inceptionv4/train.py index d1185ee3ab..6ae12cd628 100644 --- a/model_zoo/official/cv/inceptionv4/train.py +++ b/model_zoo/official/cv/inceptionv4/train.py @@ -34,7 +34,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from src.inceptionv4 import Inceptionv4 from src.dataset import create_dataset, device_num -from src.config import config_ascend as config +from src.config import config os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' set_seed(1) @@ -82,12 +82,20 @@ def inception_v4_train(): """ print('epoch_size: {} batch_size: {} class_num {}'.format(config.epoch_size, config.batch_size, 
config.num_classes)) - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - context.set_context(device_id=args.device_id) - context.set_context(enable_graph_kernel=False) + context.set_context(mode=context.GRAPH_MODE, device_target=args.platform) + if args.platform == "Ascend": + context.set_context(device_id=args.device_id) + context.set_context(enable_graph_kernel=False) + rank = 0 if device_num > 1: - init(backend_name='hccl') + if args.platform == "Ascend": + init(backend_name='hccl') + elif args.platform == "GPU": + init() + else: + raise ValueError("Unsupported device target.") + rank = get_rank() context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, @@ -96,7 +104,7 @@ def inception_v4_train(): # create dataset train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True, - repeat_num=1, batch_size=config.batch_size) + repeat_num=1, batch_size=config.batch_size, shard_id=rank) train_step_size = train_dataset.get_dataset_size() # create model @@ -131,8 +139,16 @@ def inception_v4_train(): load_param_into_net(net, ckpt) loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) - model = Model(net, loss_fn=loss, optimizer=opt, metrics={ - 'acc', 'top_1_accuracy', 'top_5_accuracy'}, loss_scale_manager=loss_scale_manager, amp_level=config.amp_level) + + + if args.platform == "Ascend": + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'}, + loss_scale_manager=loss_scale_manager, amp_level=config.amp_level) + elif args.platform == "GPU": + model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'}, + loss_scale_manager=loss_scale_manager, amp_level='O0') + else: + raise ValueError("Unsupported device target.") # define callbacks performance_cb = TimeMonitor(data_size=train_step_size) @@ -156,6 +172,8 @@ def parse_args(): arg_parser = 
argparse.ArgumentParser(description='InceptionV4 image classification training') arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') arg_parser.add_argument('--device_id', type=int, default=0, help='device id') + arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU"), + help='Platform, support Ascend, GPU.') arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint') args_opt = arg_parser.parse_args() return args_opt
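
The platform handling this patch adds to `train.py` and `dataset.py` can be restated as a small sketch that runs without MindSpore. It is not part of the patch: `env_int` and `resolve_run_settings` are hypothetical helper names used here only for illustration. The sketch assumes the behaviour visible in the diff: `DEVICE_ID` and `RANK_SIZE` fall back to `'0'` and `'1'`, Ascend calls `init(backend_name='hccl')` while GPU calls `init()` with no arguments, and the GPU branch builds the `Model` with `amp_level='O0'` instead of `config.amp_level`.

```python
import os


def env_int(name, default):
    """Read an integer environment variable, falling back to a default.

    Mirrors the int(os.getenv('DEVICE_ID', '0')) pattern from the diff.
    """
    return int(os.getenv(name, str(default)))


def resolve_run_settings(platform, device_num):
    """Return (distributed, backend_name, amp_level_override) for a platform.

    Hypothetical helper: it restates the branching in train.py as plain data
    so it can be checked without MindSpore installed.
    """
    if platform not in ("Ascend", "GPU"):
        raise ValueError("Unsupported device target.")
    # Communication is only initialised when more than one device is used.
    distributed = device_num > 1
    # Ascend passes backend_name='hccl' to init(); GPU calls init() bare.
    backend = "hccl" if (distributed and platform == "Ascend") else None
    # GPU hardcodes amp_level='O0'; None means "keep config.amp_level".
    amp_override = "O0" if platform == "GPU" else None
    return distributed, backend, amp_override


if __name__ == "__main__":
    # run_distribute_train_gpu.sh exports RANK_SIZE=8 before calling mpirun.
    os.environ["RANK_SIZE"] = "8"
    print(resolve_run_settings("GPU", env_int("RANK_SIZE", 1)))
```

In the patch itself the shard index is taken from `get_rank()` after `init()` rather than from `DEVICE_ID`, which matters for the GPU script because it exports `DEVICE_ID=0` for every `mpirun` process.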