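| @@ -50,8 +50,8 @@ The overall network architecture of MobileNetV3 is shown below: | |||||
| # Environment Requirements | # Environment Requirements | ||||
| - Hardware: GPU | |||||
| - Prepare the hardware environment with a GPU processor. | |||||
| - Hardware: GPU/CPU | |||||
| - Prepare the hardware environment with a GPU/CPU processor. | |||||
| - Framework | - Framework | ||||
| - [MindSpore](https://www.mindspore.cn/install) | - [MindSpore](https://www.mindspore.cn/install) | ||||
| - For more information, see the following resources: | - For more information, see the following resources: | ||||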
| @@ -86,6 +86,7 @@ The overall network architecture of MobileNetV3 is shown below: | |||||
| You can start training using Python or shell scripts. Shell script usage is as follows: | You can start training using Python or shell scripts. Shell script usage is as follows: | ||||
| - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | ||||
| - CPU: sh run_train.sh CPU [DATASET_PATH] | |||||
| ### Launch | ### Launch | ||||
| @@ -93,8 +94,10 @@ The overall network architecture of MobileNetV3 is shown below: | |||||
| # training example | # training example | ||||
| python: | python: | ||||
| GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU | GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU | ||||
| CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU | |||||
| shell: | shell: | ||||
| GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ | GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ | ||||
| CPU: sh run_train.sh CPU ~/cifar10/train/ | |||||
| ``` | ``` | ||||
| ### Result | ### Result | ||||
| @@ -115,6 +118,7 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917 | |||||
| You can start inference using Python or shell scripts. Shell script usage is as follows: | You can start inference using Python or shell scripts. Shell script usage is as follows: | ||||
| - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] | - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] | ||||
| - CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH] | |||||
| ### Launch | ### Launch | ||||
| @@ -122,9 +126,11 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917 | |||||
| # inference example | # inference example | ||||
| python: | python: | ||||
| GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU | GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU | ||||
| CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU | |||||
| shell: | shell: | ||||
| GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | ||||
| CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt | |||||
| ``` | ``` | ||||
| > Checkpoints can be produced during training. | > Checkpoints can be produced during training. | ||||
| @@ -19,7 +19,6 @@ | |||||
| # [MobileNetV3 Description](#contents) | # [MobileNetV3 Description](#contents) | ||||
| MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware-aware network architecture search (NAS) complemented by the NetAdapt algorithm, and then subsequently improved through novel architecture advances. | MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware-aware network architecture search (NAS) complemented by the NetAdapt algorithm, and then subsequently improved through novel architecture advances. | ||||
| [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. | [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. | ||||
| @@ -35,37 +34,35 @@ The overall network architecture of MobileNetV3 is shown below: | |||||
| Dataset used: [imagenet](http://www.image-net.org/) | Dataset used: [imagenet](http://www.image-net.org/) | ||||
| - Dataset size: ~125G, about 1.28 million color images in 1000 classes | - Dataset size: ~125G, about 1.28 million color images in 1000 classes | ||||
| - Train: 120G, about 1.28 million images | |||||
| - Test: 5G, 50,000 images | |||||
| - Train: 120G, about 1.28 million images | |||||
| - Test: 5G, 50,000 images | |||||
| - Data format: RGB images. | - Data format: RGB images. | ||||
| - Note: Data will be processed in src/dataset.py | |||||
| - Note: Data will be processed in src/dataset.py | |||||
| # [Environment Requirements](#contents) | # [Environment Requirements](#contents) | ||||
| - Hardware (GPU) | |||||
| - Prepare the hardware environment with a GPU processor. | |||||
| - Hardware (GPU/CPU) | |||||
| - Prepare the hardware environment with a GPU/CPU processor. | |||||
| - Framework | - Framework | ||||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||||
| - For more information, please check the resources below: | - For more information, please check the resources below: | ||||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||||
| # [Script description](#contents) | # [Script description](#contents) | ||||
| ## [Script and sample code](#contents) | ## [Script and sample code](#contents) | ||||
| ```python | ```python | ||||
| ├── MobileNetV3 | |||||
| ├── Readme.md # descriptions about MobileNetV3 | |||||
| ├── scripts | |||||
| │ ├──run_train.sh # shell script for train | |||||
| │ ├──run_eval.sh # shell script for evaluation | |||||
| ├── src | |||||
| │ ├──config.py # parameter configuration | |||||
| ├── MobileNetV3 | |||||
| ├── Readme.md # descriptions about MobileNetV3 | |||||
| ├── scripts | |||||
| │ ├──run_train.sh # shell script for train | |||||
| │ ├──run_eval.sh # shell script for evaluation | |||||
| ├── src | |||||
| │ ├──config.py # parameter configuration | |||||
| │ ├──dataset.py # creating dataset | │ ├──dataset.py # creating dataset | ||||
| │ ├──lr_generator.py # learning rate config | |||||
| │ ├──lr_generator.py # learning rate config | |||||
| │ ├──mobilenetV3.py # MobileNetV3 architecture | │ ├──mobilenetV3.py # MobileNetV3 architecture | ||||
| ├── train.py # training script | ├── train.py # training script | ||||
| ├── eval.py # evaluation script | ├── eval.py # evaluation script | ||||
| @@ -80,22 +77,25 @@ Dataset used: [imagenet](http://www.image-net.org/) | |||||
| You can start training using Python or shell scripts. Shell script usage is as follows: | You can start training using Python or shell scripts. Shell script usage is as follows: | ||||
| - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | ||||
| - CPU: sh run_train.sh CPU [DATASET_PATH] | |||||
| ### Launch | ### Launch | ||||
| ``` | |||||
| ```shell | |||||
| # training example | # training example | ||||
| python: | python: | ||||
| GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU | GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU | ||||
| CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU | |||||
| shell: | shell: | ||||
| GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ | GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ | ||||
| CPU: sh run_train.sh CPU ~/cifar10/train/ | |||||
| ``` | ``` | ||||
| ### Result | ### Result | ||||
| Training results will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log`, as shown below. | |||||
| Training results will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log`, as shown below. | |||||
| ``` | |||||
| ```bash | |||||
| epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] | epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] | ||||
| epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 | epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 | ||||
| epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] | epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] | ||||
| @@ -109,25 +109,28 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 | |||||
| You can start inference using Python or shell scripts. Shell script usage is as follows: | You can start inference using Python or shell scripts. Shell script usage is as follows: | ||||
| - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] | - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] | ||||
| - CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH] | |||||
| ### Launch | ### Launch | ||||
| ``` | |||||
| ```shell | |||||
| # inference example | # inference example | ||||
| python: | python: | ||||
| GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU | GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU | ||||
| CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU | |||||
| shell: | shell: | ||||
| GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | ||||
| CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt | |||||
| ``` | ``` | ||||
| > Checkpoints can be produced during training. | |||||
| > Checkpoints can be produced during training. | |||||
| ### Result | ### Result | ||||
| Inference results will be stored in the example path. You can find results like the following in `val.log`. | |||||
| Inference results will be stored in the example path. You can find results like the following in `val.log`. | |||||
| ``` | |||||
| ```bash | |||||
| result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt | result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt | ||||
| ``` | ``` | ||||
| @@ -135,7 +138,7 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625. | |||||
| Change the export mode and export file in `src/config.py`, and run `export.py`. | Change the export mode and export file in `src/config.py`, and run `export.py`. | ||||
| ``` | |||||
| ```shell | |||||
| python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH] | python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH] | ||||
| ``` | ``` | ||||
| @@ -168,5 +171,5 @@ python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH] | |||||
| In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py. | In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py. | ||||
| # [ModelZoo Homepage](#contents) | # [ModelZoo Homepage](#contents) | ||||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||||
| @@ -21,7 +21,9 @@ from mindspore import nn | |||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.dataset import create_dataset_cifar | |||||
| from src.config import config_gpu | from src.config import config_gpu | ||||
| from src.config import config_cpu | |||||
| from src.mobilenetV3 import mobilenet_v3_large | from src.mobilenetV3 import mobilenet_v3_large | ||||
| @@ -38,17 +40,24 @@ if __name__ == '__main__': | |||||
| config = config_gpu | config = config_gpu | ||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="GPU", save_graphs=False) | device_target="GPU", save_graphs=False) | ||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||||
| do_train=False, | |||||
| config=config, | |||||
| device_target=args_opt.device_target, | |||||
| batch_size=config.batch_size) | |||||
| elif args_opt.device_target == "CPU": | |||||
| config = config_cpu | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="CPU", save_graphs=False) | |||||
| dataset = create_dataset_cifar(dataset_path=args_opt.dataset_path, | |||||
| do_train=False, | |||||
| batch_size=config.batch_size) | |||||
| else: | else: | ||||
| raise ValueError("Unsupported device_target.") | raise ValueError("Unsupported device_target.") | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | ||||
| net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax") | net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax") | ||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||||
| do_train=False, | |||||
| config=config, | |||||
| device_target=args_opt.device_target, | |||||
| batch_size=config.batch_size) | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| if args_opt.checkpoint_path: | if args_opt.checkpoint_path: | ||||
| @@ -19,6 +19,7 @@ import argparse | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | ||||
| from src.config import config_gpu | from src.config import config_gpu | ||||
| from src.config import config_cpu | |||||
| from src.mobilenetV3 import mobilenet_v3_large | from src.mobilenetV3 import mobilenet_v3_large | ||||
| @@ -32,6 +33,9 @@ if __name__ == '__main__': | |||||
| if args_opt.device_target == "GPU": | if args_opt.device_target == "GPU": | ||||
| cfg = config_gpu | cfg = config_gpu | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | ||||
| elif args_opt.device_target == "CPU": | |||||
| cfg = config_cpu | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||||
| else: | else: | ||||
| raise ValueError("Unsupported device_target.") | raise ValueError("Unsupported device_target.") | ||||
| @@ -16,6 +16,7 @@ | |||||
| if [ $# != 3 ] | if [ $# != 3 ] | ||||
| then | then | ||||
| echo "GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]" | echo "GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]" | ||||
| echo "CPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -16,6 +16,14 @@ | |||||
| run_gpu() | run_gpu() | ||||
| { | { | ||||
| if [ $# -gt 5 ] || [ $# -lt 4 ] | |||||
| then | |||||
| echo "Usage:\n \ | |||||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ | |||||
| CPU: sh run_train.sh CPU [DATASET_PATH]\n \ | |||||
| " | |||||
| exit 1 | |||||
| fi | |||||
| if [ $2 -lt 1 ] || [ $2 -gt 8 ] | if [ $2 -lt 1 ] || [ $2 -gt 8 ] | ||||
| then | then | ||||
| echo "error: DEVICE_NUM=$2 is not in (1-8)" | echo "error: DEVICE_NUM=$2 is not in (1-8)" | ||||
| @@ -45,16 +53,42 @@ run_gpu() | |||||
| &> ../train.log & # dataset train folder | &> ../train.log & # dataset train folder | ||||
| } | } | ||||
| if [ $# -gt 5 ] || [ $# -lt 4 ] | |||||
| then | |||||
| echo "Usage:\n \ | |||||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ | |||||
| " | |||||
| exit 1 | |||||
| fi | |||||
| run_cpu() | |||||
| { | |||||
| if [ $# -gt 3 ] || [ $# -lt 2 ] | |||||
| then | |||||
| echo "Usage:\n \ | |||||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIBLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ | |||||
| CPU: sh run_train.sh CPU [DATASET_PATH]\n \ | |||||
| " | |||||
| exit 1 | |||||
| fi | |||||
| if [ ! -d $2 ] | |||||
| then | |||||
| echo "error: DATASET_PATH=$2 is not a directory" | |||||
| exit 1 | |||||
| fi | |||||
| BASEPATH=$(cd "`dirname $0`" || exit; pwd) | |||||
| export PYTHONPATH=${BASEPATH}:$PYTHONPATH | |||||
| if [ -d "../train" ]; | |||||
| then | |||||
| rm -rf ../train | |||||
| fi | |||||
| mkdir ../train | |||||
| cd ../train || exit | |||||
| python ${BASEPATH}/../train.py \ | |||||
| --dataset_path=$2 \ | |||||
| --device_target=$1 \ | |||||
| &> ../train.log & # dataset train folder | |||||
| } | |||||
| if [ $1 = "GPU" ] ; then | if [ $1 = "GPU" ] ; then | ||||
| run_gpu "$@" | run_gpu "$@" | ||||
| elif [ $1 = "CPU" ] ; then | |||||
| run_cpu "$@" | |||||
| else | else | ||||
| echo "Unsupported device_target" | echo "Unsupported device_target" | ||||
| fi; | fi; | ||||
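
Review note on the DEVICE_NUM range check in `run_gpu`: a value is out of range when it is below the minimum or above the maximum, so the two comparisons need to be joined with "or". The Python sketch below is illustrative only; the function names are hypothetical and not part of this PR. It contrasts the "or" form with an "and" form, which no integer can satisfy.

```python
def out_of_range_or(device_num, low=1, high=8):
    """True when device_num lies outside the inclusive range [low, high]."""
    return device_num < low or device_num > high


def out_of_range_and(device_num, low=1, high=8):
    """Same comparisons joined with `and`; no integer is both < low and > high."""
    return device_num < low and device_num > high


if __name__ == "__main__":
    # Print both results for values below, inside, and above the range.
    for n in (0, 1, 8, 9):
        print(n, out_of_range_or(n), out_of_range_and(n))
```

With the "and" form the error branch is unreachable, so an invalid DEVICE_NUM would pass validation silently.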
| @@ -36,3 +36,23 @@ config_gpu = ed({ | |||||
| "export_format": "MINDIR", | "export_format": "MINDIR", | ||||
| "export_file": "mobilenetv3" | "export_file": "mobilenetv3" | ||||
| }) | }) | ||||
| config_cpu = ed({ | |||||
| "num_classes": 10, | |||||
| "image_height": 224, | |||||
| "image_width": 224, | |||||
| "batch_size": 32, | |||||
| "epoch_size": 120, | |||||
| "warmup_epochs": 5, | |||||
| "lr": 0.1, | |||||
| "momentum": 0.9, | |||||
| "weight_decay": 1e-4, | |||||
| "label_smooth": 0.1, | |||||
| "loss_scale": 1024, | |||||
| "save_checkpoint": True, | |||||
| "save_checkpoint_epochs": 1, | |||||
| "keep_checkpoint_max": 500, | |||||
| "save_checkpoint_path": "./checkpoint", | |||||
| "export_format": "MINDIR", | |||||
| "export_file": "mobilenetv3" | |||||
| }) | |||||
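
Review note: `train.py`, `eval.py`, and `export.py` each repeat the same `if`/`elif` chain to map `device_target` to a config. One possible alternative is a dictionary lookup, sketched below. This is a minimal sketch under assumptions: `SimpleNamespace` stands in for the `ed(...)` objects, only a few fields are shown, the GPU `num_classes` of 1000 is assumed from the ImageNet description in the README, and `select_config` is a hypothetical helper that does not exist in this PR.

```python
from types import SimpleNamespace

# Stand-ins for the objects in src/config.py; only a few fields are shown.
# The CPU values are copied from this diff; the GPU value is an assumption.
config_gpu = SimpleNamespace(num_classes=1000)
config_cpu = SimpleNamespace(num_classes=10, batch_size=32, epoch_size=120)

CONFIGS = {"GPU": config_gpu, "CPU": config_cpu}


def select_config(device_target):
    """Return the config for a device target, or raise like the scripts do."""
    try:
        return CONFIGS[device_target]
    except KeyError:
        raise ValueError("Unsupported device_target.") from None


if __name__ == "__main__":
    cfg = select_config("CPU")
    print(cfg.num_classes, cfg.batch_size)
```

Adding another backend would then mean adding one dictionary entry rather than editing three scripts.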
| @@ -83,3 +83,60 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, | |||||
| data_set = data_set.repeat(repeat_num) | data_set = data_set.repeat(repeat_num) | ||||
| return data_set | return data_set | ||||
| def create_dataset_cifar(dataset_path, | |||||
| do_train, | |||||
| repeat_num=1, | |||||
| batch_size=32, | |||||
| target="CPU"): | |||||
| """ | |||||
| create a train or evaluate cifar10 dataset | |||||
| Args: | |||||
| dataset_path(string): the path of dataset. | |||||
| do_train(bool): whether dataset is used for train or eval. | |||||
| repeat_num(int): the repeat times of dataset. Default: 1 | |||||
| batch_size(int): the batch size of dataset. Default: 32 | |||||
| target(str): the device target. Default: CPU | |||||
| Returns: | |||||
| dataset | |||||
| """ | |||||
| data_set = ds.Cifar10Dataset(dataset_path, | |||||
| num_parallel_workers=8, | |||||
| shuffle=True) | |||||
| # define map operations | |||||
| if do_train: | |||||
| trans = [ | |||||
| C.RandomCrop((32, 32), (4, 4, 4, 4)), | |||||
| C.RandomHorizontalFlip(prob=0.5), | |||||
| C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4), | |||||
| C.Resize((224, 224)), | |||||
| C.Rescale(1.0 / 255.0, 0.0), | |||||
| C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), | |||||
| C.CutOut(112), | |||||
| C.HWC2CHW() | |||||
| ] | |||||
| else: | |||||
| trans = [ | |||||
| C.Resize((224, 224)), | |||||
| C.Rescale(1.0 / 255.0, 0.0), | |||||
| C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), | |||||
| C.HWC2CHW() | |||||
| ] | |||||
| type_cast_op = C2.TypeCast(mstype.int32) | |||||
| data_set = data_set.map(operations=type_cast_op, | |||||
| input_columns="label", | |||||
| num_parallel_workers=8) | |||||
| data_set = data_set.map(operations=trans, | |||||
| input_columns="image", | |||||
| num_parallel_workers=8) | |||||
| # apply batch operations | |||||
| data_set = data_set.batch(batch_size, drop_remainder=True) | |||||
| # apply dataset repeat operation | |||||
| data_set = data_set.repeat(repeat_num) | |||||
| return data_set | |||||
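
Review note: the evaluation branch of `create_dataset_cifar` applies `Rescale(1.0 / 255.0, 0.0)`, then `Normalize(mean, std)`, then `HWC2CHW()`. The pure-Python sketch below illustrates what those three steps compute per pixel. It is not the MindSpore implementation: it omits `Resize`, works on nested lists, and the function name is hypothetical. The mean and std values are copied from the diff.

```python
# Per-channel statistics copied from the C.Normalize call in this diff.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def rescale_normalize_to_chw(image_hwc):
    """Map a nested list [H][W][3] of 0-255 values to [3][H][W] floats."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    chw = []
    for c in range(3):
        plane = []
        for y in range(height):
            row = []
            for x in range(width):
                # Rescale(1/255, 0): multiply by the scale, then add the shift.
                scaled = image_hwc[y][x][c] * (1.0 / 255.0) + 0.0
                # Normalize(mean, std): subtract the mean, divide by the std.
                row.append((scaled - MEAN[c]) / STD[c])
            plane.append(row)
        chw.append(plane)  # Channel-first layout, as HWC2CHW produces.
    return chw


if __name__ == "__main__":
    image = [[[255, 0, 128], [0, 0, 0]]]  # a 1 x 2 RGB image
    out = rescale_normalize_to_chw(image)
    print(len(out), len(out[0]), len(out[0][0]))
```

Because the mean and std are expressed on a 0 to 1 scale, the rescale step has to come before normalization, which matches the order used in the diff.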
| @@ -37,8 +37,10 @@ from mindspore.common import set_seed | |||||
| from mindspore.communication.management import init, get_group_size, get_rank | from mindspore.communication.management import init, get_group_size, get_rank | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.dataset import create_dataset_cifar | |||||
| from src.lr_generator import get_lr | from src.lr_generator import get_lr | ||||
| from src.config import config_gpu | from src.config import config_gpu | ||||
| from src.config import config_cpu | |||||
| from src.mobilenetV3 import mobilenet_v3_large | from src.mobilenetV3 import mobilenet_v3_large | ||||
| set_seed(1) | set_seed(1) | ||||
| @@ -59,6 +61,10 @@ if args_opt.device_target == "GPU": | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), | context.set_auto_parallel_context(device_num=get_group_size(), | ||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True) | gradients_mean=True) | ||||
| elif args_opt.device_target == "CPU": | |||||
| context.set_context(mode=context.GRAPH_MODE, | |||||
| device_target="CPU", | |||||
| save_graphs=False) | |||||
| else: | else: | ||||
| raise ValueError("Unsupported device_target.") | raise ValueError("Unsupported device_target.") | ||||
| @@ -151,58 +157,71 @@ class Monitor(Callback): | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| config_ = None | |||||
| if args_opt.device_target == "GPU": | |||||
| config_ = config_gpu | |||||
| elif args_opt.device_target == "CPU": | |||||
| config_ = config_cpu | |||||
| else: | |||||
| raise ValueError("Unsupported device_target.") | |||||
| # train on device | |||||
| print("train args: ", args_opt) | |||||
| print("cfg: ", config_) | |||||
| # define net | |||||
| net = mobilenet_v3_large(num_classes=config_.num_classes) | |||||
| # define loss | |||||
| if config_.label_smooth > 0: | |||||
| loss = CrossEntropyWithLabelSmooth( | |||||
| smooth_factor=config_.label_smooth, num_classes=config_.num_classes) | |||||
| else: | |||||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||||
| # define dataset | |||||
| epoch_size = config_.epoch_size | |||||
| if args_opt.device_target == "GPU": | if args_opt.device_target == "GPU": | ||||
| # train on gpu | |||||
| print("train args: ", args_opt) | |||||
| print("cfg: ", config_gpu) | |||||
| # define net | |||||
| net = mobilenet_v3_large(num_classes=config_gpu.num_classes) | |||||
| # define loss | |||||
| if config_gpu.label_smooth > 0: | |||||
| loss = CrossEntropyWithLabelSmooth( | |||||
| smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes) | |||||
| else: | |||||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||||
| # define dataset | |||||
| epoch_size = config_gpu.epoch_size | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | dataset = create_dataset(dataset_path=args_opt.dataset_path, | ||||
| do_train=True, | do_train=True, | ||||
| config=config_gpu, | |||||
| config=config_, | |||||
| device_target=args_opt.device_target, | device_target=args_opt.device_target, | ||||
| repeat_num=1, | repeat_num=1, | ||||
| batch_size=config_gpu.batch_size, | |||||
| run_distribute=args_opt.run_distribute) | |||||
| step_size = dataset.get_dataset_size() | |||||
| # resume | |||||
| if args_opt.pre_trained: | |||||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||||
| load_param_into_net(net, param_dict) | |||||
| # define optimizer | |||||
| loss_scale = FixedLossScaleManager( | |||||
| config_gpu.loss_scale, drop_overflow_update=False) | |||||
| lr = Tensor(get_lr(global_step=0, | |||||
| lr_init=0, | |||||
| lr_end=0, | |||||
| lr_max=config_gpu.lr, | |||||
| warmup_epochs=config_gpu.warmup_epochs, | |||||
| total_epochs=epoch_size, | |||||
| steps_per_epoch=step_size)) | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum, | |||||
| config_gpu.weight_decay, config_gpu.loss_scale) | |||||
| # define model | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, | |||||
| loss_scale_manager=loss_scale) | |||||
| cb = [Monitor(lr_init=lr.asnumpy())] | |||||
| if args_opt.run_distribute: | |||||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||||
| else: | |||||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/" | |||||
| if config_gpu.save_checkpoint: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | |||||
| keep_checkpoint_max=config_gpu.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) | |||||
| cb += [ckpt_cb] | |||||
| # begine train | |||||
| model.train(epoch_size, dataset, callbacks=cb) | |||||
| batch_size=config_.batch_size, | |||||
| run_distribute=False) | |||||
| elif args_opt.device_target == "CPU": | |||||
| dataset = create_dataset_cifar(args_opt.dataset_path, | |||||
| do_train=True, | |||||
| batch_size=config_.batch_size) | |||||
| else: | |||||
| raise ValueError("Unsupported device_target.") | |||||
| step_size = dataset.get_dataset_size() | |||||
| # resume | |||||
| if args_opt.pre_trained: | |||||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||||
| load_param_into_net(net, param_dict) | |||||
| # define optimizer | |||||
| loss_scale = FixedLossScaleManager( | |||||
| config_.loss_scale, drop_overflow_update=False) | |||||
| lr = Tensor(get_lr(global_step=0, | |||||
| lr_init=0, | |||||
| lr_end=0, | |||||
| lr_max=config_.lr, | |||||
| warmup_epochs=config_.warmup_epochs, | |||||
| total_epochs=epoch_size, | |||||
| steps_per_epoch=step_size)) | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_.momentum, | |||||
| config_.weight_decay, config_.loss_scale) | |||||
| # define model | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, | |||||
| loss_scale_manager=loss_scale) | |||||
| cb = [Monitor(lr_init=lr.asnumpy())] | |||||
| if args_opt.run_distribute: | |||||
| ckpt_save_dir = config_.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||||
| else: | |||||
| ckpt_save_dir = config_.save_checkpoint_path + "ckpt_" + "/" | |||||
| if config_.save_checkpoint: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_.save_checkpoint_epochs * step_size, | |||||
| keep_checkpoint_max=config_.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) | |||||
| cb += [ckpt_cb] | |||||
| # begin training | |||||
| model.train(epoch_size, dataset, callbacks=cb) | |||||
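
Review note: the checkpoint interval is computed as `save_checkpoint_epochs * step_size`, where `step_size` comes from `dataset.get_dataset_size()` after batching with `drop_remainder=True`. The sketch below works through that arithmetic for the CPU config. It assumes the loader sees 50,000 training images (the standard CIFAR-10 training split; the actual count depends on what `dataset_path` contains), and the helper names are hypothetical.

```python
def steps_per_epoch(num_samples, batch_size):
    """Number of full batches per epoch; the partial last batch is dropped."""
    return num_samples // batch_size


def checkpoint_interval(save_checkpoint_epochs, step_size):
    """Steps between checkpoints, mirroring the CheckpointConfig argument."""
    return save_checkpoint_epochs * step_size


if __name__ == "__main__":
    # batch_size=32, save_checkpoint_epochs=1, epoch_size=120 are from config_cpu.
    step_size = steps_per_epoch(num_samples=50000, batch_size=32)
    interval = checkpoint_interval(save_checkpoint_epochs=1, step_size=step_size)
    total_steps = 120 * step_size
    print(step_size, interval, total_steps)
```

Under these assumptions one checkpoint is written per epoch, so `keep_checkpoint_max` of 500 would retain all 120 epoch checkpoints.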