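| @@ -50,8 +50,8 @@ The overall network architecture of MobileNetV3 is shown below: | |||
| # Environment Requirements | |||
| - Hardware: GPU | |||
| - Prepare a hardware environment with a GPU processor. | |||
| - Hardware: GPU/CPU | |||
| - Prepare a hardware environment with a GPU or CPU processor. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install) | |||
| - For more information, see the following resources: | |||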
| @@ -86,6 +86,7 @@ The overall network architecture of MobileNetV3 is shown below: | |||
| Start training with Python or shell scripts. Shell script usage is as follows: | |||
| - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | |||
| - CPU: sh run_train.sh CPU [DATASET_PATH] | |||
| ### Launch | |||
| @@ -93,8 +94,10 @@ The overall network architecture of MobileNetV3 is shown below: | |||
| # training example | |||
| python: | |||
| GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU | |||
| CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU | |||
| shell: | |||
| GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ | |||
| CPU: sh run_train.sh CPU ~/cifar10/train/ | |||
| ``` | |||
| ### Result | |||
| @@ -115,6 +118,7 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917 | |||
| Start inference with Python or shell scripts. Shell script usage is as follows: | |||
| - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] | |||
| - CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH] | |||
| ### Launch | |||
| @@ -122,9 +126,11 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917 | |||
| # inference example | |||
| python: | |||
| GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU | |||
| CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU | |||
| shell: | |||
| GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | |||
| CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt | |||
| ``` | |||
| > Checkpoints are generated during training. | |||
| @@ -19,7 +19,6 @@ | |||
| # [MobileNetV3 Description](#contents) | |||
| MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware-aware network architecture search (NAS) complemented by the NetAdapt algorithm, and then improved through novel architecture advances. | |||
| [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019. | |||
| @@ -35,37 +34,35 @@ The overall network architecture of MobileNetV3 is shown below: | |||
| Dataset used: [imagenet](http://www.image-net.org/) | |||
| - Dataset size: ~125G, about 1.28 million color images in 1000 classes | |||
| - Train: 120G, about 1.28 million images | |||
| - Test: 5G, 50000 images | |||
| - Train: 120G, about 1.28 million images | |||
| - Test: 5G, 50000 images | |||
| - Data format: RGB images. | |||
| - Note: Data will be processed in src/dataset.py | |||
| - Note: Data will be processed in src/dataset.py | |||
| # [Environment Requirements](#contents) | |||
| - Hardware(GPU) | |||
| - Prepare hardware environment with GPU processor. | |||
| - Hardware(GPU/CPU) | |||
| - Prepare hardware environment with GPU/CPU processor. | |||
| - Framework | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - [MindSpore](https://www.mindspore.cn/install/en) | |||
| - For more information, please check the resources below: | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) | |||
| - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) | |||
| # [Script description](#contents) | |||
| ## [Script and sample code](#contents) | |||
| ```python | |||
| ├── MobileNetV3 | |||
| ├── Readme.md # descriptions about MobileNetV3 | |||
| ├── scripts | |||
| │ ├──run_train.sh # shell script for train | |||
| │ ├──run_eval.sh # shell script for evaluation | |||
| ├── src | |||
| │ ├──config.py # parameter configuration | |||
| ├── MobileNetV3 | |||
| ├── Readme.md # descriptions about MobileNetV3 | |||
| ├── scripts | |||
| │ ├──run_train.sh # shell script for train | |||
| │ ├──run_eval.sh # shell script for evaluation | |||
| ├── src | |||
| │ ├──config.py # parameter configuration | |||
| │ ├──dataset.py # creating dataset | |||
| │ ├──lr_generator.py # learning rate config | |||
| │ ├──lr_generator.py # learning rate config | |||
| │ ├──mobilenetV3.py # MobileNetV3 architecture | |||
| ├── train.py # training script | |||
| ├── eval.py # evaluation script | |||
| @@ -80,22 +77,25 @@ Dataset used: [imagenet](http://www.image-net.org/) | |||
| You can start training using Python or shell scripts. Shell script usage is as follows: | |||
| - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] | |||
| - CPU: sh run_train.sh CPU [DATASET_PATH] | |||
| ### Launch | |||
| ``` | |||
| ```shell | |||
| # training example | |||
| python: | |||
| GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU | |||
| CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU | |||
| shell: | |||
| GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/ | |||
| CPU: sh run_train.sh CPU ~/cifar10/train/ | |||
| ``` | |||
| ### Result | |||
| Training results are stored in the example path. Checkpoints are stored in `./checkpoint` by default, and the training log is redirected to `./train/train.log`, as shown below. | |||
| Training results are stored in the example path. Checkpoints are stored in `./checkpoint` by default, and the training log is redirected to `./train/train.log`, as shown below. | |||
| ``` | |||
| ```bash | |||
| epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100] | |||
| epoch time: 140522.500, per step time: 224.836, avg loss: 5.258 | |||
| epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200] | |||
| @@ -109,25 +109,28 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917 | |||
| You can start inference using Python or shell scripts. Shell script usage is as follows: | |||
| - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH] | |||
| - CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH] | |||
| ### Launch | |||
| ``` | |||
| ```shell | |||
| # infer example | |||
| python: | |||
| GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU | |||
| CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU | |||
| shell: | |||
| GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt | |||
| CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt | |||
| ``` | |||
| > Checkpoints are produced during training. | |||
| > Checkpoints are produced during training. | |||
| ### Result | |||
| Inference results are stored in the example path. You can find results like the following in `val.log`. | |||
| Inference results are stored in the example path. You can find results like the following in `val.log`. | |||
| ``` | |||
| ```bash | |||
| result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt | |||
| ``` | |||
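The `val.log` line combines a Python dict literal with a `ckpt=` field. A minimal sketch for extracting the accuracy and checkpoint path, assuming every result line follows the `result: {...} ckpt=...` layout of the sample; `parse_result_line` is a hypothetical helper and not part of the repository:

```python
import ast
import re

# Assumed layout, taken from the sample log line: result: {<dict>} ckpt=<path>
RESULT_RE = re.compile(r"^result:\s*(\{.*\})\s+ckpt=(\S+)$")


def parse_result_line(line):
    """Return (metrics_dict, checkpoint_path) parsed from one val.log line."""
    match = RESULT_RE.match(line.strip())
    if match is None:
        raise ValueError("unexpected log line: %r" % line)
    # literal_eval only accepts Python literals, so no code in the log is executed.
    metrics = ast.literal_eval(match.group(1))
    return metrics, match.group(2)


if __name__ == "__main__":
    sample = ("result: {'acc': 0.71976314102564111} "
              "ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt")
    metrics, ckpt = parse_result_line(sample)
    print("accuracy: %.2f%%, checkpoint: %s" % (metrics["acc"] * 100, ckpt))
```

Using `ast.literal_eval` rather than `eval` keeps the parser safe on untrusted log files.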
| @@ -135,7 +138,7 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625. | |||
| Change the export mode and export file in `src/config.py`, and run `export.py`. | |||
| ``` | |||
| ```python | |||
| python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH] | |||
| ``` | |||
| @@ -168,5 +171,5 @@ python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH] | |||
| In dataset.py, we set the seed inside the `create_dataset` function. We also use a random seed in train.py. | |||
| # [ModelZoo Homepage](#contents) | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). | |||
| @@ -21,7 +21,9 @@ from mindspore import nn | |||
| from mindspore.train.model import Model | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from src.dataset import create_dataset | |||
| from src.dataset import create_dataset_cifar | |||
| from src.config import config_gpu | |||
| from src.config import config_cpu | |||
| from src.mobilenetV3 import mobilenet_v3_large | |||
| @@ -38,17 +40,24 @@ if __name__ == '__main__': | |||
| config = config_gpu | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="GPU", save_graphs=False) | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||
| do_train=False, | |||
| config=config, | |||
| device_target=args_opt.device_target, | |||
| batch_size=config.batch_size) | |||
| elif args_opt.device_target == "CPU": | |||
| config = config_cpu | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="CPU", save_graphs=False) | |||
| dataset = create_dataset_cifar(dataset_path=args_opt.dataset_path, | |||
| do_train=False, | |||
| batch_size=config.batch_size) | |||
| else: | |||
| raise ValueError("Unsupported device_target.") | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||
| net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax") | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||
| do_train=False, | |||
| config=config, | |||
| device_target=args_opt.device_target, | |||
| batch_size=config.batch_size) | |||
| step_size = dataset.get_dataset_size() | |||
| if args_opt.checkpoint_path: | |||
| @@ -19,6 +19,7 @@ import argparse | |||
| import numpy as np | |||
| from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export | |||
| from src.config import config_gpu | |||
| from src.config import config_cpu | |||
| from src.mobilenetV3 import mobilenet_v3_large | |||
| @@ -32,6 +33,9 @@ if __name__ == '__main__': | |||
| if args_opt.device_target == "GPU": | |||
| cfg = config_gpu | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| elif args_opt.device_target == "CPU": | |||
| cfg = config_cpu | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="CPU") | |||
| else: | |||
| raise ValueError("Unsupported device_target.") | |||
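The `if`/`elif` chain above maps a device target to a config object and rejects anything else. The same dispatch can be sketched with a lookup table. The two dicts below are trimmed stand-ins for `config_gpu` and `config_cpu`; the GPU `num_classes` of 1000 is an assumption based on the ImageNet dataset, and `select_config` is a hypothetical helper:

```python
# Trimmed stand-ins for the repository's config objects.
CONFIGS = {
    "GPU": {"num_classes": 1000, "export_format": "MINDIR"},  # num_classes assumed (ImageNet)
    "CPU": {"num_classes": 10, "export_format": "MINDIR"},    # values from config_cpu
}


def select_config(device_target):
    """Return the config for a supported device target, else raise ValueError."""
    try:
        return CONFIGS[device_target]
    except KeyError:
        raise ValueError("Unsupported device_target.") from None


if __name__ == "__main__":
    print(select_config("CPU")["num_classes"])
```

A table keeps the supported targets in one place, so adding a platform means adding one entry instead of another branch in each script.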
| @@ -16,6 +16,7 @@ | |||
| if [ $# != 3 ] | |||
| then | |||
| echo "GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]" | |||
| echo "CPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]" | |||
| exit 1 | |||
| fi | |||
| @@ -16,6 +16,14 @@ | |||
| run_gpu() | |||
| { | |||
| if [ $# -gt 5 ] || [ $# -lt 4 ] | |||
| then | |||
| echo "Usage:\n \ | |||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ | |||
| CPU: sh run_train.sh CPU [DATASET_PATH]\n \ | |||
| " | |||
| exit 1 | |||
| fi | |||
| if [ $2 -lt 1 ] || [ $2 -gt 8 ] | |||
| then | |||
| echo "error: DEVICE_NUM=$2 is not in (1-8)" | |||
| @@ -45,16 +53,42 @@ run_gpu() | |||
| &> ../train.log & # dataset train folder | |||
| } | |||
| if [ $# -gt 5 ] || [ $# -lt 4 ] | |||
| then | |||
| echo "Usage:\n \ | |||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ | |||
| " | |||
| exit 1 | |||
| fi | |||
| run_cpu() | |||
| { | |||
| if [ $# -gt 3 ] || [ $# -lt 2 ] | |||
| then | |||
| echo "Usage:\n \ | |||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \ | |||
| CPU: sh run_train.sh CPU [DATASET_PATH]\n \ | |||
| " | |||
| exit 1 | |||
| fi | |||
| if [ ! -d $2 ] | |||
| then | |||
| echo "error: DATASET_PATH=$2 is not a directory" | |||
| exit 1 | |||
| fi | |||
| BASEPATH=$(cd "`dirname $0`" || exit; pwd) | |||
| export PYTHONPATH=${BASEPATH}:$PYTHONPATH | |||
| if [ -d "../train" ]; | |||
| then | |||
| rm -rf ../train | |||
| fi | |||
| mkdir ../train | |||
| cd ../train || exit | |||
| python ${BASEPATH}/../train.py \ | |||
| --dataset_path=$2 \ | |||
| --device_target=$1 \ | |||
| &> ../train.log & # dataset train folder | |||
| } | |||
| if [ "$1" = "GPU" ] ; then | |||
| run_gpu "$@" | |||
| elif [ "$1" = "CPU" ] ; then | |||
| run_cpu "$@" | |||
| else | |||
| echo "Unsupported device_target" | |||
| fi; | |||
| @@ -36,3 +36,23 @@ config_gpu = ed({ | |||
| "export_format": "MINDIR", | |||
| "export_file": "mobilenetv3" | |||
| }) | |||
| config_cpu = ed({ | |||
| "num_classes": 10, | |||
| "image_height": 224, | |||
| "image_width": 224, | |||
| "batch_size": 32, | |||
| "epoch_size": 120, | |||
| "warmup_epochs": 5, | |||
| "lr": 0.1, | |||
| "momentum": 0.9, | |||
| "weight_decay": 1e-4, | |||
| "label_smooth": 0.1, | |||
| "loss_scale": 1024, | |||
| "save_checkpoint": True, | |||
| "save_checkpoint_epochs": 1, | |||
| "keep_checkpoint_max": 500, | |||
| "save_checkpoint_path": "./checkpoint", | |||
| "export_format": "MINDIR", | |||
| "export_file": "mobilenetv3" | |||
| }) | |||
| @@ -83,3 +83,60 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, | |||
| data_set = data_set.repeat(repeat_num) | |||
| return data_set | |||
| def create_dataset_cifar(dataset_path, | |||
| do_train, | |||
| repeat_num=1, | |||
| batch_size=32, | |||
| target="CPU"): | |||
| """ | |||
| create a train or evaluate cifar10 dataset | |||
| Args: | |||
| dataset_path(string): the path of dataset. | |||
| do_train(bool): whether dataset is used for train or eval. | |||
| repeat_num(int): the repeat times of dataset. Default: 1 | |||
| batch_size(int): the batch size of dataset. Default: 32 | |||
| target(str): the device target. Default: CPU | |||
| Returns: | |||
| dataset | |||
| """ | |||
| data_set = ds.Cifar10Dataset(dataset_path, | |||
| num_parallel_workers=8, | |||
| shuffle=True) | |||
| # define map operations | |||
| if do_train: | |||
| trans = [ | |||
| C.RandomCrop((32, 32), (4, 4, 4, 4)), | |||
| C.RandomHorizontalFlip(prob=0.5), | |||
| C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4), | |||
| C.Resize((224, 224)), | |||
| C.Rescale(1.0 / 255.0, 0.0), | |||
| C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), | |||
| C.CutOut(112), | |||
| C.HWC2CHW() | |||
| ] | |||
| else: | |||
| trans = [ | |||
| C.Resize((224, 224)), | |||
| C.Rescale(1.0 / 255.0, 0.0), | |||
| C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), | |||
| C.HWC2CHW() | |||
| ] | |||
| type_cast_op = C2.TypeCast(mstype.int32) | |||
| data_set = data_set.map(operations=type_cast_op, | |||
| input_columns="label", | |||
| num_parallel_workers=8) | |||
| data_set = data_set.map(operations=trans, | |||
| input_columns="image", | |||
| num_parallel_workers=8) | |||
| # apply batch operations | |||
| data_set = data_set.batch(batch_size, drop_remainder=True) | |||
| # apply dataset repeat operation | |||
| data_set = data_set.repeat(repeat_num) | |||
| return data_set | |||
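Each image passes through `Rescale(1.0 / 255.0, 0.0)` and then `Normalize(mean, std)`. The per-pixel arithmetic can be sketched in plain Python without MindSpore. `normalize_pixel` is a hypothetical helper; the mean and standard deviation are the CIFAR-10 values used in `create_dataset_cifar`:

```python
# Per-channel statistics copied from create_dataset_cifar.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2023, 0.1994, 0.2010)


def normalize_pixel(rgb, scale=1.0 / 255.0, shift=0.0):
    """Apply rescale (value * scale + shift), then (x - mean) / std per channel."""
    out = []
    for value, mean, std in zip(rgb, MEAN, STD):
        rescaled = value * scale + shift   # maps 0..255 into 0..1
        out.append((rescaled - mean) / std)
    return tuple(out)


if __name__ == "__main__":
    # A black pixel lands at -mean/std in every channel.
    print(normalize_pixel((0, 0, 0)))
    # A white pixel lands at (1 - mean)/std in every channel.
    print(normalize_pixel((255, 255, 255)))
```

The order matters: the mean and standard deviation are expressed on the 0 to 1 scale, so the rescale step has to run before the normalization.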
| @@ -37,8 +37,10 @@ from mindspore.common import set_seed | |||
| from mindspore.communication.management import init, get_group_size, get_rank | |||
| from src.dataset import create_dataset | |||
| from src.dataset import create_dataset_cifar | |||
| from src.lr_generator import get_lr | |||
| from src.config import config_gpu | |||
| from src.config import config_cpu | |||
| from src.mobilenetV3 import mobilenet_v3_large | |||
| set_seed(1) | |||
| @@ -59,6 +61,10 @@ if args_opt.device_target == "GPU": | |||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| elif args_opt.device_target == "CPU": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="CPU", | |||
| save_graphs=False) | |||
| else: | |||
| raise ValueError("Unsupported device_target.") | |||
| @@ -151,58 +157,71 @@ class Monitor(Callback): | |||
| if __name__ == '__main__': | |||
| config_ = None | |||
| if args_opt.device_target == "GPU": | |||
| config_ = config_gpu | |||
| elif args_opt.device_target == "CPU": | |||
| config_ = config_cpu | |||
| else: | |||
| raise ValueError("Unsupported device_target.") | |||
| # train on device | |||
| print("train args: ", args_opt) | |||
| print("cfg: ", config_) | |||
| # define net | |||
| net = mobilenet_v3_large(num_classes=config_.num_classes) | |||
| # define loss | |||
| if config_.label_smooth > 0: | |||
| loss = CrossEntropyWithLabelSmooth( | |||
| smooth_factor=config_.label_smooth, num_classes=config_.num_classes) | |||
| else: | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||
| # define dataset | |||
| epoch_size = config_.epoch_size | |||
| if args_opt.device_target == "GPU": | |||
| # train on gpu | |||
| print("train args: ", args_opt) | |||
| print("cfg: ", config_gpu) | |||
| # define net | |||
| net = mobilenet_v3_large(num_classes=config_gpu.num_classes) | |||
| # define loss | |||
| if config_gpu.label_smooth > 0: | |||
| loss = CrossEntropyWithLabelSmooth( | |||
| smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes) | |||
| else: | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | |||
| # define dataset | |||
| epoch_size = config_gpu.epoch_size | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, | |||
| do_train=True, | |||
| config=config_gpu, | |||
| config=config_, | |||
| device_target=args_opt.device_target, | |||
| repeat_num=1, | |||
| batch_size=config_gpu.batch_size, | |||
| run_distribute=args_opt.run_distribute) | |||
| step_size = dataset.get_dataset_size() | |||
| # resume | |||
| if args_opt.pre_trained: | |||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||
| load_param_into_net(net, param_dict) | |||
| # define optimizer | |||
| loss_scale = FixedLossScaleManager( | |||
| config_gpu.loss_scale, drop_overflow_update=False) | |||
| lr = Tensor(get_lr(global_step=0, | |||
| lr_init=0, | |||
| lr_end=0, | |||
| lr_max=config_gpu.lr, | |||
| warmup_epochs=config_gpu.warmup_epochs, | |||
| total_epochs=epoch_size, | |||
| steps_per_epoch=step_size)) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum, | |||
| config_gpu.weight_decay, config_gpu.loss_scale) | |||
| # define model | |||
| model = Model(net, loss_fn=loss, optimizer=opt, | |||
| loss_scale_manager=loss_scale) | |||
| cb = [Monitor(lr_init=lr.asnumpy())] | |||
| if args_opt.run_distribute: | |||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||
| else: | |||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/" | |||
| if config_gpu.save_checkpoint: | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | |||
| keep_checkpoint_max=config_gpu.keep_checkpoint_max) | |||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) | |||
| cb += [ckpt_cb] | |||
| # begin training | |||
| model.train(epoch_size, dataset, callbacks=cb) | |||
| batch_size=config_.batch_size, | |||
| run_distribute=False) | |||
| elif args_opt.device_target == "CPU": | |||
| dataset = create_dataset_cifar(args_opt.dataset_path, | |||
| do_train=True, | |||
| batch_size=config_.batch_size) | |||
| else: | |||
| raise ValueError("Unsupported device_target.") | |||
| step_size = dataset.get_dataset_size() | |||
| # resume | |||
| if args_opt.pre_trained: | |||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||
| load_param_into_net(net, param_dict) | |||
| # define optimizer | |||
| loss_scale = FixedLossScaleManager( | |||
| config_.loss_scale, drop_overflow_update=False) | |||
| lr = Tensor(get_lr(global_step=0, | |||
| lr_init=0, | |||
| lr_end=0, | |||
| lr_max=config_.lr, | |||
| warmup_epochs=config_.warmup_epochs, | |||
| total_epochs=epoch_size, | |||
| steps_per_epoch=step_size)) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_.momentum, | |||
| config_.weight_decay, config_.loss_scale) | |||
| # define model | |||
| model = Model(net, loss_fn=loss, optimizer=opt, | |||
| loss_scale_manager=loss_scale) | |||
| cb = [Monitor(lr_init=lr.asnumpy())] | |||
| if args_opt.run_distribute: | |||
| ckpt_save_dir = config_.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||
| else: | |||
| ckpt_save_dir = config_.save_checkpoint_path + "ckpt_" + "/" | |||
| if config_.save_checkpoint: | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_.save_checkpoint_epochs * step_size, | |||
| keep_checkpoint_max=config_.keep_checkpoint_max) | |||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) | |||
| cb += [ckpt_cb] | |||
| # begin training | |||
| model.train(epoch_size, dataset, callbacks=cb) | |||
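Because `create_dataset_cifar` batches with `drop_remainder=True`, the step count per epoch is the floor of the dataset size divided by the batch size, and the interval passed to `CheckpointConfig` is `save_checkpoint_epochs * step_size`. A worked sketch, assuming the standard 50,000-image CIFAR-10 training split (the dataset size is an assumption; the other values come from `config_cpu`, and both function names are hypothetical):

```python
def steps_per_epoch(num_samples, batch_size):
    """Number of full batches per epoch when the last partial batch is dropped."""
    return num_samples // batch_size


def checkpoint_interval(save_checkpoint_epochs, step_size):
    """Number of steps between checkpoints, as computed in train.py."""
    return save_checkpoint_epochs * step_size


if __name__ == "__main__":
    step_size = steps_per_epoch(50000, 32)        # assumed CIFAR-10 train split
    interval = checkpoint_interval(1, step_size)  # save_checkpoint_epochs = 1
    total_steps = 120 * step_size                 # epoch_size = 120
    print(step_size, interval, total_steps)
```

With these values one checkpoint is written per epoch, so `keep_checkpoint_max` of 500 is never reached in a 120-epoch run.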