diff --git a/model_zoo/official/cv/mobilenetv3/README_CN.md b/model_zoo/official/cv/mobilenetv3/README_CN.md
index 89755c6644..549ccdcd50 100644
--- a/model_zoo/official/cv/mobilenetv3/README_CN.md
+++ b/model_zoo/official/cv/mobilenetv3/README_CN.md
@@ -50,8 +50,8 @@ MobileNetV3总体网络架构如下:
 
 # 环境要求
 
-- 硬件:GPU
-    - 准备GPU处理器搭建硬件环境。
+- 硬件:GPU/CPU
+    - 准备GPU/CPU处理器搭建硬件环境。
 - 框架
     - [MindSpore](https://www.mindspore.cn/install)
 - 如需查看详情,请参见如下资源:
@@ -86,6 +86,7 @@ MobileNetV3总体网络架构如下:
 使用python或shell脚本开始训练。shell脚本的使用方法如下:
 
 - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
+- CPU: sh run_train.sh CPU [DATASET_PATH]
 
 ### 启动
@@ -93,8 +94,10 @@ MobileNetV3总体网络架构如下:
 # 训练示例
   python:
       GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU
+      CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU
   shell:
       GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
+      CPU: sh run_train.sh CPU ~/cifar10/train/
 ```
 
 ### 结果
@@ -115,6 +118,7 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917
 使用python或shell脚本开始训练。shell脚本的使用方法如下:
 
 - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
+- CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH]
 
 ### 启动
@@ -122,9 +126,11 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917
 # 推理示例
   python:
       GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU
+      CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU
   shell:
       GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
+      CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt
 ```
 
 > 训练过程中可以生成检查点。
diff --git a/model_zoo/official/cv/mobilenetv3/Readme.md b/model_zoo/official/cv/mobilenetv3/Readme.md
index 86ff26f3fd..c5ae0d32a9 100644
--- a/model_zoo/official/cv/mobilenetv3/Readme.md
+++ b/model_zoo/official/cv/mobilenetv3/Readme.md
@@ -19,7 +19,6 @@
 # [MobileNetV3 Description](#contents)
 
-
 MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware- aware network architecture search (NAS) complemented by the NetAdapt algorithm and then subsequently improved through novel architecture advances.Nov 20, 2019.
 
 [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019.
@@ -35,37 +34,35 @@ The overall network architecture of MobileNetV3 is show below:
 Dataset used: [imagenet](http://www.image-net.org/)
 
 - Dataset size: ~125G, 1.2W colorful images in 1000 classes
-  - Train: 120G, 1.2W images
-  - Test: 5G, 50000 images
+    - Train: 120G, 1.2W images
+    - Test: 5G, 50000 images
 - Data format: RGB images.
-  - Note: Data will be processed in src/dataset.py
-
+    - Note: Data will be processed in src/dataset.py
 
 # [Environment Requirements](#contents)
 
-- Hardware(GPU)
-  - Prepare hardware environment with GPU processor.
+- Hardware(GPU/CPU)
+    - Prepare hardware environment with GPU/CPU processor.
 - Framework
-  - [MindSpore](https://www.mindspore.cn/install/en)
+    - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
-  - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
-  - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
-
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
 
 # [Script description](#contents)
 
 ## [Script and sample code](#contents)
 
 ```python
-├── MobileNetV3
-  ├── Readme.md              # descriptions about MobileNetV3
-  ├── scripts
-  │   ├──run_train.sh        # shell script for train
-  │   ├──run_eval.sh         # shell script for evaluation
-  ├── src
-  │   ├──config.py           # parameter configuration
+├── MobileNetV3
+  ├── Readme.md              # descriptions about MobileNetV3
+  ├── scripts
+  │   ├──run_train.sh        # shell script for train
+  │   ├──run_eval.sh         # shell script for evaluation
+  ├── src
+  │   ├──config.py           # parameter configuration
   │   ├──dataset.py          # creating dataset
-  │   ├──lr_generator.py     # learning rate config
+  │   ├──lr_generator.py     # learning rate config
   │   ├──mobilenetV3.py      # MobileNetV3 architecture
   ├── train.py               # training script
   ├── eval.py                # evaluation script
@@ -80,22 +77,25 @@ Dataset used: [imagenet](http://www.image-net.org/)
 You can start training using python or shell scripts. The usage of shell scripts as follows:
 
 - GPU: sh run_trian.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
+- CPU: sh run_train.sh CPU [DATASET_PATH]
 
 ### Launch
 
-```
+```shell
 # training example
   python:
       GPU: python train.py --dataset_path ~/imagenet/train/ --device_targe GPU
+      CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU
   shell:
      GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
+     CPU: sh run_train.sh CPU ~/cifar10/train/
 ```
 
 ### Result
 
-Training result will be stored in the example path. Checkpoints will be stored at `. /checkpoint` by default, and training log will be redirected to `./train/train.log` like followings. 
+Training results will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log` as follows.
 
-```
+```bash
 epoch: [  0/200], step:[  624/  625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
 epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
 epoch: [  1/200], step:[  624/  625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]
 epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
@@ -109,25 +109,28 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
 You can start training using python or shell scripts. The usage of shell scripts as follows:
 
 - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
+- CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH]
 
 ### Launch
 
-```
+```shell
 # infer example
   python:
       GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_targe GPU
+      CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU
   shell:
       GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
+      CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt
 ```
 
-> checkpoint can be produced in training process. 
+> Checkpoints can be produced during the training process.
 
 ### Result
 
-Inference result will be stored in the example path, you can find result like the followings in `val.log`. 
+Inference results will be stored in the example path; you can find results like the following in `val.log`.
 
-```
+```bash
 result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
 ```
@@ -135,7 +138,7 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.
 Change the export mode and export file in `src/config.py`, and run `export.py`.
 
-```
+```shell
 python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH]
 ```
@@ -168,5 +171,5 @@ python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH]
 In dataset.py, we set the seed inside “create_dataset" function. We also use random seed in train.py.
 
 # [ModelZoo Homepage](#contents)
-
-Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). 
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
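A note on the new CPU quick-start paths above, not part of the patch itself: the examples point at a local CIFAR-10 directory such as `~/cifar10/train/`, while the Dataset section still documents ImageNet. The CPU pipeline added below is built on MindSpore's `Cifar10Dataset`, which reads the CIFAR-10 binary release, so a small pre-flight check can save a failed launch. The path and file names here are assumptions based on the standard CIFAR-10 binary layout.

```python
# Sketch only: verify that the directory used in the CPU examples looks like the
# CIFAR-10 binary release (data_batch_1.bin ... data_batch_5.bin for training).
import os

cifar10_train = os.path.expanduser("~/cifar10/train/")   # placeholder path
expected = ["data_batch_%d.bin" % i for i in range(1, 6)]
missing = [name for name in expected
           if not os.path.exists(os.path.join(cifar10_train, name))]
print("missing CIFAR-10 files:", missing if missing else "none")
```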
diff --git a/model_zoo/official/cv/mobilenetv3/eval.py b/model_zoo/official/cv/mobilenetv3/eval.py
index 1babebf1dc..b99fbe809e 100644
--- a/model_zoo/official/cv/mobilenetv3/eval.py
+++ b/model_zoo/official/cv/mobilenetv3/eval.py
@@ -21,7 +21,9 @@ from mindspore import nn
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.dataset import create_dataset
+from src.dataset import create_dataset_cifar
 from src.config import config_gpu
+from src.config import config_cpu
 from src.mobilenetV3 import mobilenet_v3_large
 
 
@@ -38,17 +40,24 @@ if __name__ == '__main__':
         config = config_gpu
         context.set_context(mode=context.GRAPH_MODE,
                             device_target="GPU", save_graphs=False)
+        dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                                 do_train=False,
+                                 config=config,
+                                 device_target=args_opt.device_target,
+                                 batch_size=config.batch_size)
+    elif args_opt.device_target == "CPU":
+        config = config_cpu
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target="CPU", save_graphs=False)
+        dataset = create_dataset_cifar(dataset_path=args_opt.dataset_path,
+                                       do_train=False,
+                                       batch_size=config.batch_size)
     else:
         raise ValueError("Unsupported device_target.")
 
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax")
 
-    dataset = create_dataset(dataset_path=args_opt.dataset_path,
-                             do_train=False,
-                             config=config,
-                             device_target=args_opt.device_target,
-                             batch_size=config.batch_size)
     step_size = dataset.get_dataset_size()
 
     if args_opt.checkpoint_path:
diff --git a/model_zoo/official/cv/mobilenetv3/export.py b/model_zoo/official/cv/mobilenetv3/export.py
index f073887f68..d8b9ac52b8 100644
--- a/model_zoo/official/cv/mobilenetv3/export.py
+++ b/model_zoo/official/cv/mobilenetv3/export.py
@@ -19,6 +19,7 @@ import argparse
 import numpy as np
 from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
 from src.config import config_gpu
+from src.config import config_cpu
 from src.mobilenetV3 import mobilenet_v3_large
 
 
@@ -32,6 +33,9 @@ if __name__ == '__main__':
     if args_opt.device_target == "GPU":
         cfg = config_gpu
         context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elif args_opt.device_target == "CPU":
+        cfg = config_cpu
+        context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
     else:
         raise ValueError("Unsupported device_target.")
diff --git a/model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh b/model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh
index f59c552c51..b8d600f1c9 100644
--- a/model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh
+++ b/model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh
@@ -16,6 +16,7 @@
 if [ $# != 3 ]
 then
     echo "GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
+    echo "CPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
     exit 1
 fi
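For reference, the CPU branch added to `export.py` only swaps in `config_cpu` and the CPU context; the rest of the export flow is unchanged by this patch. A minimal sketch of what a CPU export run amounts to, assuming a placeholder checkpoint path and an NCHW float32 dummy input sized from `config_cpu`, might look like this:

```python
# Sketch under assumptions: the checkpoint path is a placeholder, and the network
# input is NCHW float32 at the resolution given in config_cpu.
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config_cpu
from src.mobilenetV3 import mobilenet_v3_large

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
net = mobilenet_v3_large(num_classes=config_cpu.num_classes)
load_param_into_net(net, load_checkpoint("/path/to/mobilenetV3.ckpt"))   # placeholder
dummy = Tensor(np.zeros([1, 3, config_cpu.image_height, config_cpu.image_width], np.float32))
export(net, dummy, file_name=config_cpu.export_file, file_format=config_cpu.export_format)
```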
diff --git a/model_zoo/official/cv/mobilenetv3/scripts/run_train.sh b/model_zoo/official/cv/mobilenetv3/scripts/run_train.sh
index e9f1ac745d..075b3721ca 100644
--- a/model_zoo/official/cv/mobilenetv3/scripts/run_train.sh
+++ b/model_zoo/official/cv/mobilenetv3/scripts/run_train.sh
@@ -16,6 +16,14 @@ run_gpu()
 {
+    if [ $# -gt 5 ] || [ $# -lt 4 ]
+    then
+        echo "Usage:\n \
+              GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
+              CPU: sh run_train.sh CPU [DATASET_PATH]\n \
+              "
+        exit 1
+    fi
     if [ $2 -lt 1 ] && [ $2 -gt 8 ]
     then
         echo "error: DEVICE_NUM=$2 is not in (1-8)"
@@ -45,16 +53,42 @@ run_gpu()
     &> ../train.log &  # dataset train folder
 }
 
-if [ $# -gt 5 ] || [ $# -lt 4 ]
-then
-    echo "Usage:\n \
-          GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
-          "
-exit 1
-fi
+run_cpu()
+{
+    if [ $# -gt 3 ] || [ $# -lt 2 ]
+    then
+        echo "Usage:\n \
+              GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
+              CPU: sh run_train.sh CPU [DATASET_PATH]\n \
+              "
+        exit 1
+    fi
+
+    if [ ! -d $2 ]
+    then
+        echo "error: DATASET_PATH=$2 is not a directory"
+        exit 1
+    fi
+
+    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
+    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
+    if [ -d "../train" ];
+    then
+        rm -rf ../train
+    fi
+    mkdir ../train
+    cd ../train || exit
+
+    python ${BASEPATH}/../train.py \
+        --dataset_path=$2 \
+        --device_target=$1 \
+        &> ../train.log &  # dataset train folder
+}
 
 if [ $1 = "GPU" ] ; then
     run_gpu "$@"
+elif [ $1 = "CPU" ] ; then
+    run_cpu "$@"
 else
     echo "Unsupported device_target"
 fi;
diff --git a/model_zoo/official/cv/mobilenetv3/src/config.py b/model_zoo/official/cv/mobilenetv3/src/config.py
index ebc9e2e549..578edab27c 100644
--- a/model_zoo/official/cv/mobilenetv3/src/config.py
+++ b/model_zoo/official/cv/mobilenetv3/src/config.py
@@ -36,3 +36,23 @@ config_gpu = ed({
     "export_format": "MINDIR",
     "export_file": "mobilenetv3"
 })
+
+config_cpu = ed({
+    "num_classes": 10,
+    "image_height": 224,
+    "image_width": 224,
+    "batch_size": 32,
+    "epoch_size": 120,
+    "warmup_epochs": 5,
+    "lr": 0.1,
+    "momentum": 0.9,
+    "weight_decay": 1e-4,
+    "label_smooth": 0.1,
+    "loss_scale": 1024,
+    "save_checkpoint": True,
+    "save_checkpoint_epochs": 1,
+    "keep_checkpoint_max": 500,
+    "save_checkpoint_path": "./checkpoint",
+    "export_format": "MINDIR",
+    "export_file": "mobilenetv3"
+})
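The new `config_cpu` dict above feeds the same training code paths as `config_gpu`. The sketch below shows how its `lr`, `warmup_epochs`, and `epoch_size` values translate into the per-step learning-rate schedule that `train.py` builds with `get_lr`; the `steps_per_epoch` value here is only an assumption (CIFAR-10's 50000 training images at batch size 32 with `drop_remainder=True`), whereas in `train.py` it comes from the dataset object.

```python
# Sketch only: preview the learning-rate schedule train.py will construct from config_cpu.
from src.config import config_cpu
from src.lr_generator import get_lr

steps_per_epoch = 50000 // config_cpu.batch_size   # assumed CIFAR-10 train split size
lr_each_step = get_lr(global_step=0,
                      lr_init=0,
                      lr_end=0,
                      lr_max=config_cpu.lr,
                      warmup_epochs=config_cpu.warmup_epochs,
                      total_epochs=config_cpu.epoch_size,
                      steps_per_epoch=steps_per_epoch)
print(len(lr_each_step), float(lr_each_step[0]), float(max(lr_each_step)))
```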
diff --git a/model_zoo/official/cv/mobilenetv3/src/dataset.py b/model_zoo/official/cv/mobilenetv3/src/dataset.py
index ec082919f3..cd2f057cb9 100644
--- a/model_zoo/official/cv/mobilenetv3/src/dataset.py
+++ b/model_zoo/official/cv/mobilenetv3/src/dataset.py
@@ -83,3 +83,60 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
     data_set = data_set.repeat(repeat_num)
 
     return data_set
+
+def create_dataset_cifar(dataset_path,
+                         do_train,
+                         repeat_num=1,
+                         batch_size=32,
+                         target="CPU"):
+    """
+    create a train or evaluate cifar10 dataset
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+        target(str): the device target. Default: CPU
+
+    Returns:
+        dataset
+    """
+    data_set = ds.Cifar10Dataset(dataset_path,
+                                 num_parallel_workers=8,
+                                 shuffle=True)
+    # define map operations
+    if do_train:
+        trans = [
+            C.RandomCrop((32, 32), (4, 4, 4, 4)),
+            C.RandomHorizontalFlip(prob=0.5),
+            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
+            C.Resize((224, 224)),
+            C.Rescale(1.0 / 255.0, 0.0),
+            C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+            C.CutOut(112),
+            C.HWC2CHW()
+        ]
+    else:
+        trans = [
+            C.Resize((224, 224)),
+            C.Rescale(1.0 / 255.0, 0.0),
+            C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+            C.HWC2CHW()
+        ]
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+
+    data_set = data_set.map(operations=type_cast_op,
+                            input_columns="label",
+                            num_parallel_workers=8)
+    data_set = data_set.map(operations=trans,
+                            input_columns="image",
+                            num_parallel_workers=8)
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
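`create_dataset_cifar` resizes the 32x32 CIFAR-10 images to the 224x224 resolution used elsewhere in the CPU config and converts them to channel-first layout. A quick sanity check of the evaluation-mode pipeline output, with a placeholder dataset path, could look like the sketch below.

```python
# Sketch only: the dataset path is a placeholder for a local CIFAR-10 binary directory.
from src.dataset import create_dataset_cifar

data_set = create_dataset_cifar("/path/to/cifar10/val", do_train=False, batch_size=32)
for image, label in data_set.create_tuple_iterator():
    # After Resize((224, 224)) and HWC2CHW each batch is NCHW float32; labels are int32.
    print(image.shape, label.shape)   # expected: (32, 3, 224, 224) (32,)
    break
```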
diff --git a/model_zoo/official/cv/mobilenetv3/train.py b/model_zoo/official/cv/mobilenetv3/train.py
index 1e351088b3..09724b64db 100644
--- a/model_zoo/official/cv/mobilenetv3/train.py
+++ b/model_zoo/official/cv/mobilenetv3/train.py
@@ -37,8 +37,10 @@ from mindspore.common import set_seed
 from mindspore.communication.management import init, get_group_size, get_rank
 from src.dataset import create_dataset
+from src.dataset import create_dataset_cifar
 from src.lr_generator import get_lr
 from src.config import config_gpu
+from src.config import config_cpu
 from src.mobilenetV3 import mobilenet_v3_large
 
 
 set_seed(1)
@@ -59,6 +61,10 @@ if args_opt.device_target == "GPU":
         context.set_auto_parallel_context(device_num=get_group_size(),
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
+elif args_opt.device_target == "CPU":
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target="CPU",
+                        save_graphs=False)
 else:
     raise ValueError("Unsupported device_target.")
 
@@ -151,58 +157,71 @@ class Monitor(Callback):
 
 
 if __name__ == '__main__':
+    config_ = None
+    if args_opt.device_target == "GPU":
+        config_ = config_gpu
+    elif args_opt.device_target == "CPU":
+        config_ = config_cpu
+    else:
+        raise ValueError("Unsupported device_target.")
+    # train on device
+    print("train args: ", args_opt)
+    print("cfg: ", config_)
+
+    # define net
+    net = mobilenet_v3_large(num_classes=config_.num_classes)
+    # define loss
+    if config_.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(
+            smooth_factor=config_.label_smooth, num_classes=config_.num_classes)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    # define dataset
+    epoch_size = config_.epoch_size
     if args_opt.device_target == "GPU":
-        # train on gpu
-        print("train args: ", args_opt)
-        print("cfg: ", config_gpu)
-
-        # define net
-        net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
-        # define loss
-        if config_gpu.label_smooth > 0:
-            loss = CrossEntropyWithLabelSmooth(
-                smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-        # define dataset
-        epoch_size = config_gpu.epoch_size
         dataset = create_dataset(dataset_path=args_opt.dataset_path,
                                  do_train=True,
-                                 config=config_gpu,
+                                 config=config_,
                                  device_target=args_opt.device_target,
                                  repeat_num=1,
-                                 batch_size=config_gpu.batch_size,
-                                 run_distribute=args_opt.run_distribute)
-        step_size = dataset.get_dataset_size()
-        # resume
-        if args_opt.pre_trained:
-            param_dict = load_checkpoint(args_opt.pre_trained)
-            load_param_into_net(net, param_dict)
-        # define optimizer
-        loss_scale = FixedLossScaleManager(
-            config_gpu.loss_scale, drop_overflow_update=False)
-        lr = Tensor(get_lr(global_step=0,
-                           lr_init=0,
-                           lr_end=0,
-                           lr_max=config_gpu.lr,
-                           warmup_epochs=config_gpu.warmup_epochs,
-                           total_epochs=epoch_size,
-                           steps_per_epoch=step_size))
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
-                       config_gpu.weight_decay, config_gpu.loss_scale)
-        # define model
-        model = Model(net, loss_fn=loss, optimizer=opt,
-                      loss_scale_manager=loss_scale)
-
-        cb = [Monitor(lr_init=lr.asnumpy())]
-        if args_opt.run_distribute:
-            ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
-        else:
-            ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/"
-        if config_gpu.save_checkpoint:
-            config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
-                                         keep_checkpoint_max=config_gpu.keep_checkpoint_max)
-            ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
-            cb += [ckpt_cb]
-        # begine train
-        model.train(epoch_size, dataset, callbacks=cb)
+                                 batch_size=config_.batch_size,
+                                 run_distribute=False)
+    elif args_opt.device_target == "CPU":
+        dataset = create_dataset_cifar(args_opt.dataset_path,
+                                       do_train=True,
+                                       batch_size=config_.batch_size)
+    else:
+        raise ValueError("Unsupported device_target.")
+    step_size = dataset.get_dataset_size()
+    # resume
+    if args_opt.pre_trained:
+        param_dict = load_checkpoint(args_opt.pre_trained)
+        load_param_into_net(net, param_dict)
+    # define optimizer
+    loss_scale = FixedLossScaleManager(
+        config_.loss_scale, drop_overflow_update=False)
+    lr = Tensor(get_lr(global_step=0,
+                       lr_init=0,
+                       lr_end=0,
+                       lr_max=config_.lr,
+                       warmup_epochs=config_.warmup_epochs,
+                       total_epochs=epoch_size,
+                       steps_per_epoch=step_size))
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_.momentum,
+                   config_.weight_decay, config_.loss_scale)
+    # define model
+    model = Model(net, loss_fn=loss, optimizer=opt,
+                  loss_scale_manager=loss_scale)
+
+    cb = [Monitor(lr_init=lr.asnumpy())]
+    if args_opt.run_distribute:
+        ckpt_save_dir = config_.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+    else:
+        ckpt_save_dir = config_.save_checkpoint_path + "ckpt_" + "/"
+    if config_.save_checkpoint:
+        config_ck = CheckpointConfig(save_checkpoint_steps=config_.save_checkpoint_epochs * step_size,
+                                     keep_checkpoint_max=config_.keep_checkpoint_max)
+        ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
+        cb += [ckpt_cb]
+    # begin training
+    model.train(epoch_size, dataset, callbacks=cb)
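Before committing to a full CPU run (120 epochs of CIFAR-10 training on a CPU will be slow), a single forward pass is a cheap way to confirm that the pieces wired together in `train.py` fit. This is a sketch only, not part of the patch; the dataset path is a placeholder.

```python
# Sketch only: one-batch wiring check for the CPU path (placeholder dataset path).
from mindspore import context
from src.config import config_cpu
from src.dataset import create_dataset_cifar
from src.mobilenetV3 import mobilenet_v3_large

context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=False)
net = mobilenet_v3_large(num_classes=config_cpu.num_classes)
dataset = create_dataset_cifar("/path/to/cifar10/train", do_train=True,
                               batch_size=config_cpu.batch_size)
image, label = next(dataset.create_tuple_iterator())
print(net(image).shape)   # expected: (32, 10), one logit per CIFAR-10 class
```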