Add MobileNetV3 CPU scripts

tags/v1.2.0-rc1
wuxuejian · 4 years ago · commit 72cbb44d6e
9 changed files with 246 additions and 93 deletions
  1. +8  -2   model_zoo/official/cv/mobilenetv3/README_CN.md
  2. +32 -29  model_zoo/official/cv/mobilenetv3/Readme.md
  3. +14 -5   model_zoo/official/cv/mobilenetv3/eval.py
  4. +4  -0   model_zoo/official/cv/mobilenetv3/export.py
  5. +1  -0   model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh
  6. +41 -7   model_zoo/official/cv/mobilenetv3/scripts/run_train.sh
  7. +20 -0   model_zoo/official/cv/mobilenetv3/src/config.py
  8. +57 -0   model_zoo/official/cv/mobilenetv3/src/dataset.py
  9. +69 -50  model_zoo/official/cv/mobilenetv3/train.py

+8 -2  model_zoo/official/cv/mobilenetv3/README_CN.md

@@ -50,8 +50,8 @@ MobileNetV3总体网络架构如下:
 
 # 环境要求
 
-- 硬件:GPU
-    - 准备GPU处理器搭建硬件环境。
+- 硬件:GPU/CPU
+    - 准备GPU/CPU处理器搭建硬件环境。
 - 框架
     - [MindSpore](https://www.mindspore.cn/install)
 - 如需查看详情,请参见如下资源:

@@ -86,6 +86,7 @@ MobileNetV3总体网络架构如下:
 使用python或shell脚本开始训练。shell脚本的使用方法如下:
 
 - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
+- CPU: sh run_train.sh CPU [DATASET_PATH]
 
 ### 启动
 
@@ -93,8 +94,10 @@ MobileNetV3总体网络架构如下:
 # 训练示例
 python:
     GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU
+    CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU
 shell:
     GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
+    CPU: sh run_train.sh CPU ~/cifar10/train/
 ```
 
 ### 结果

@@ -115,6 +118,7 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917
 使用python或shell脚本开始推理。shell脚本的使用方法如下:
 
 - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
+- CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH]
 
 ### 启动
 
@@ -122,9 +126,11 @@ epoch time:138331.250, per step time:221.330, avg loss:3.917
 # 推理示例
 python:
     GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU
+    CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU
 
 shell:
     GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
+    CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt
 ```
 
 > 训练过程中可以生成检查点。


+32 -29  model_zoo/official/cv/mobilenetv3/Readme.md

@@ -19,7 +19,6 @@
 
 # [MobileNetV3 Description](#contents)
 
-
 MobileNetV3 is tuned to mobile phone CPUs through a combination of hardware-aware network architecture search (NAS) complemented by the NetAdapt algorithm, and then subsequently improved through novel architecture advances (Nov 20, 2019).
 
 [Paper](https://arxiv.org/pdf/1905.02244) Howard, Andrew, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang et al. "Searching for mobilenetv3." In Proceedings of the IEEE International Conference on Computer Vision, pp. 1314-1324. 2019.

@@ -35,37 +34,35 @@ The overall network architecture of MobileNetV3 is show below:
 Dataset used: [imagenet](http://www.image-net.org/)
 
 - Dataset size: ~125G, 1.2W colorful images in 1000 classes
     - Train: 120G, 1.2W images
     - Test: 5G, 50000 images
 - Data format: RGB images.
     - Note: Data will be processed in src/dataset.py
 
 # [Environment Requirements](#contents)
 
-- Hardware(GPU)
-    - Prepare hardware environment with GPU processor.
+- Hardware(GPU/CPU)
+    - Prepare hardware environment with GPU/CPU processor.
 - Framework
     - [MindSpore](https://www.mindspore.cn/install/en)
 - For more information, please check the resources below:
     - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
     - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
 
 # [Script description](#contents)
 
 ## [Script and sample code](#contents)
 
 ```python
 ├── MobileNetV3
   ├── Readme.md              # descriptions about MobileNetV3
   ├── scripts
   │   ├── run_train.sh       # shell script for train
   │   ├── run_eval.sh        # shell script for evaluation
   ├── src
   │   ├── config.py          # parameter configuration
   │   ├── dataset.py         # creating dataset
   │   ├── lr_generator.py    # learning rate config
   │   ├── mobilenetV3.py     # MobileNetV3 architecture
   ├── train.py               # training script
   ├── eval.py                # evaluation script

@@ -80,22 +77,25 @@ Dataset used: [imagenet](http://www.image-net.org/)
 You can start training using python or shell scripts. The usage of the shell scripts is as follows:
 
 - GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]
+- CPU: sh run_train.sh CPU [DATASET_PATH]
 
 ### Launch
 
-```
+```shell
 # training example
 python:
     GPU: python train.py --dataset_path ~/imagenet/train/ --device_target GPU
+    CPU: python train.py --dataset_path ~/cifar10/train/ --device_target CPU
 shell:
     GPU: sh run_train.sh GPU 8 0,1,2,3,4,5,6,7 ~/imagenet/train/
+    CPU: sh run_train.sh CPU ~/cifar10/train/
 ```
 
 ### Result
 
 Training result will be stored in the example path. Checkpoints will be stored at `./checkpoint` by default, and the training log will be redirected to `./train/train.log` like the following.
 
-```
+```bash
 epoch: [ 0/200], step:[ 624/ 625], loss:[5.258/5.258], time:[140412.236], lr:[0.100]
 epoch time: 140522.500, per step time: 224.836, avg loss: 5.258
 epoch: [ 1/200], step:[ 624/ 625], loss:[3.917/3.917], time:[138221.250], lr:[0.200]

@@ -109,25 +109,28 @@ epoch time: 138331.250, per step time: 221.330, avg loss: 3.917
 You can start evaluation using python or shell scripts. The usage of the shell scripts is as follows:
 
 - GPU: sh run_infer.sh GPU [DATASET_PATH] [CHECKPOINT_PATH]
+- CPU: sh run_infer.sh CPU [DATASET_PATH] [CHECKPOINT_PATH]
 
 ### Launch
 
-```
+```shell
 # infer example
 python:
     GPU: python eval.py --dataset_path ~/imagenet/val/ --checkpoint_path mobilenet_199.ckpt --device_target GPU
+    CPU: python eval.py --dataset_path ~/cifar10/val/ --checkpoint_path mobilenet_199.ckpt --device_target CPU
 
 shell:
     GPU: sh run_infer.sh GPU ~/imagenet/val/ ~/train/mobilenet-200_625.ckpt
+    CPU: sh run_infer.sh CPU ~/cifar10/val/ ~/train/mobilenet-200_625.ckpt
 ```
 
 > Checkpoints can be produced during the training process.
 
 ### Result
 
 Inference results will be stored in the example path; you can find results like the following in `val.log`.
 
-```
+```bash
 result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.ckpt
 ```

@@ -135,7 +138,7 @@ result: {'acc': 0.71976314102564111} ckpt=/path/to/checkpoint/mobilenet-200_625.
 
 Change the export mode and export file in `src/config.py`, and run `export.py`.
 
-```
+```python
 python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH]
 ```

@@ -168,5 +171,5 @@ python export.py --device_target [PLATFORM] --checkpoint_path [CKPT_PATH]
 In dataset.py, we set the seed inside the "create_dataset" function. We also use a random seed in train.py.
 
 # [ModelZoo Homepage](#contents)
 Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

+14 -5  model_zoo/official/cv/mobilenetv3/eval.py

@@ -21,7 +21,9 @@ from mindspore import nn
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.dataset import create_dataset
+from src.dataset import create_dataset_cifar
 from src.config import config_gpu
+from src.config import config_cpu
 from src.mobilenetV3 import mobilenet_v3_large
 
 

@@ -38,17 +40,24 @@ if __name__ == '__main__':
         config = config_gpu
         context.set_context(mode=context.GRAPH_MODE,
                             device_target="GPU", save_graphs=False)
+        dataset = create_dataset(dataset_path=args_opt.dataset_path,
+                                 do_train=False,
+                                 config=config,
+                                 device_target=args_opt.device_target,
+                                 batch_size=config.batch_size)
+    elif args_opt.device_target == "CPU":
+        config = config_cpu
+        context.set_context(mode=context.GRAPH_MODE,
+                            device_target="CPU", save_graphs=False)
+        dataset = create_dataset_cifar(dataset_path=args_opt.dataset_path,
+                                       do_train=False,
+                                       batch_size=config.batch_size)
     else:
         raise ValueError("Unsupported device_target.")
 
     loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
     net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax")
 
-    dataset = create_dataset(dataset_path=args_opt.dataset_path,
-                             do_train=False,
-                             config=config,
-                             device_target=args_opt.device_target,
-                             batch_size=config.batch_size)
     step_size = dataset.get_dataset_size()
 
     if args_opt.checkpoint_path:
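The remaining evaluation code is unchanged by this hunk; only the config/dataset selection moves into the branches above. As a rough, hedged sketch of how the CPU evaluation path then flows end to end — the `metrics={'acc'}` / `model.eval` part is not shown in this diff and is assumed from the standard MindSpore `Model` API, and the paths and checkpoint name are placeholders:

```python
# Hedged sketch of the CPU evaluation path; mirrors the branch added above.
# model.eval / metrics={'acc'} are assumed from MindSpore's Model API, not shown in this hunk.
from mindspore import context, nn
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_cpu
from src.dataset import create_dataset_cifar
from src.mobilenetV3 import mobilenet_v3_large

context.set_context(mode=context.GRAPH_MODE, device_target="CPU", save_graphs=False)
dataset = create_dataset_cifar(dataset_path="/data/cifar10/val/",  # placeholder path
                               do_train=False,
                               batch_size=config_cpu.batch_size)
net = mobilenet_v3_large(num_classes=config_cpu.num_classes, activation="Softmax")
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
load_param_into_net(net, load_checkpoint("mobilenet_199.ckpt"))  # placeholder checkpoint
net.set_train(False)
model = Model(net, loss_fn=loss, metrics={'acc'})
print(model.eval(dataset))  # prints a dict like {'acc': ...}
```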


+4 -0  model_zoo/official/cv/mobilenetv3/export.py

@@ -19,6 +19,7 @@ import argparse
 import numpy as np
 from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
 from src.config import config_gpu
+from src.config import config_cpu
 from src.mobilenetV3 import mobilenet_v3_large
 
 

@@ -32,6 +33,9 @@ if __name__ == '__main__':
     if args_opt.device_target == "GPU":
         cfg = config_gpu
         context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    elif args_opt.device_target == "CPU":
+        cfg = config_cpu
+        context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
     else:
         raise ValueError("Unsupported device_target.")
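Downstream of this branch, export.py (judging from its imports above) builds the network, loads the checkpoint, and calls `mindspore.export`. A hedged sketch of the CPU export flow; the variable names, the 224x224 NCHW input shape, and the checkpoint name are assumptions, while the format and file name come from `config_cpu`:

```python
# Hedged sketch of exporting the CPU model; not the literal export.py body.
import numpy as np
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from src.config import config_cpu
from src.mobilenetV3 import mobilenet_v3_large

cfg = config_cpu
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

net = mobilenet_v3_large(num_classes=cfg.num_classes)
load_param_into_net(net, load_checkpoint("mobilenetv3.ckpt"))  # placeholder checkpoint

# Dummy NCHW input matching image_height/image_width from the config (batch size 1 assumed).
input_data = Tensor(np.zeros([1, 3, cfg.image_height, cfg.image_width], np.float32))
export(net, input_data, file_name=cfg.export_file, file_format=cfg.export_format)  # MINDIR
```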




+1 -0  model_zoo/official/cv/mobilenetv3/scripts/run_infer.sh

@@ -16,6 +16,7 @@
 if [ $# != 3 ]
 then
     echo "GPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
+    echo "CPU: sh run_infer.sh [DEVICE_TARGET] [DATASET_PATH] [CHECKPOINT_PATH]"
     exit 1
 fi




+41 -7  model_zoo/official/cv/mobilenetv3/scripts/run_train.sh

@@ -16,6 +16,14 @@
 
 run_gpu()
 {
+    if [ $# -gt 5 ] || [ $# -lt 4 ]
+    then
+        echo "Usage:\n \
+              GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
+              CPU: sh run_train.sh CPU [DATASET_PATH]\n \
+              "
+        exit 1
+    fi
     if [ $2 -lt 1 ] && [ $2 -gt 8 ]
     then
         echo "error: DEVICE_NUM=$2 is not in (1-8)"

@@ -45,16 +53,42 @@ run_gpu()
         &> ../train.log &  # dataset train folder
 }
 
-if [ $# -gt 5 ] || [ $# -lt 4 ]
-then
-    echo "Usage:\n \
-          GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
-          "
-    exit 1
-fi
+run_cpu()
+{
+    if [ $# -gt 3 ] || [ $# -lt 2 ]
+    then
+        echo "Usage:\n \
+              GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
+              CPU: sh run_train.sh CPU [DATASET_PATH]\n \
+              "
+        exit 1
+    fi
+
+    if [ ! -d $2 ]
+    then
+        echo "error: DATASET_PATH=$2 is not a directory"
+        exit 1
+    fi
+
+    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
+    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
+    if [ -d "../train" ];
+    then
+        rm -rf ../train
+    fi
+    mkdir ../train
+    cd ../train || exit
+
+    python ${BASEPATH}/../train.py \
+        --dataset_path=$2 \
+        --device_target=$1 \
+        &> ../train.log &  # dataset train folder
+}
 
 if [ $1 = "GPU" ] ; then
     run_gpu "$@"
+elif [ $1 = "CPU" ] ; then
+    run_cpu "$@"
 else
     echo "Unsupported device_target"
 fi;


+20 -0  model_zoo/official/cv/mobilenetv3/src/config.py

@@ -36,3 +36,23 @@ config_gpu = ed({
     "export_format": "MINDIR",
     "export_file": "mobilenetv3"
 })
+
+config_cpu = ed({
+    "num_classes": 10,
+    "image_height": 224,
+    "image_width": 224,
+    "batch_size": 32,
+    "epoch_size": 120,
+    "warmup_epochs": 5,
+    "lr": 0.1,
+    "momentum": 0.9,
+    "weight_decay": 1e-4,
+    "label_smooth": 0.1,
+    "loss_scale": 1024,
+    "save_checkpoint": True,
+    "save_checkpoint_epochs": 1,
+    "keep_checkpoint_max": 500,
+    "save_checkpoint_path": "./checkpoint",
+    "export_format": "MINDIR",
+    "export_file": "mobilenetv3"
+})
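For context, the new `config_cpu` (CIFAR-10 settings, 10 classes) is consumed the same way as `config_gpu`: the scripts switch on `--device_target`. A minimal sketch of that selection logic, mirroring what train.py, eval.py, and export.py in this commit do inline; the `select_config` helper name is hypothetical, for illustration only:

```python
# Illustration only: select_config is a hypothetical helper;
# train.py/eval.py/export.py write this if/elif/else inline.
from src.config import config_gpu
from src.config import config_cpu


def select_config(device_target):
    """Pick the per-backend hyperparameter dict (easydict, so attribute access works)."""
    if device_target == "GPU":
        return config_gpu   # ImageNet settings, 1000 classes
    if device_target == "CPU":
        return config_cpu   # CIFAR-10 settings, 10 classes
    raise ValueError("Unsupported device_target.")


cfg = select_config("CPU")
print(cfg.num_classes, cfg.batch_size, cfg.epoch_size)  # 10 32 120
```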

+57 -0  model_zoo/official/cv/mobilenetv3/src/dataset.py

@@ -83,3 +83,60 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
     data_set = data_set.repeat(repeat_num)
 
     return data_set
+
+def create_dataset_cifar(dataset_path,
+                         do_train,
+                         repeat_num=1,
+                         batch_size=32,
+                         target="CPU"):
+    """
+    create a train or evaluate cifar10 dataset
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        batch_size(int): the batch size of dataset. Default: 32
+        target(str): the device target. Default: CPU
+
+    Returns:
+        dataset
+    """
+    data_set = ds.Cifar10Dataset(dataset_path,
+                                 num_parallel_workers=8,
+                                 shuffle=True)
+    # define map operations
+    if do_train:
+        trans = [
+            C.RandomCrop((32, 32), (4, 4, 4, 4)),
+            C.RandomHorizontalFlip(prob=0.5),
+            C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),
+            C.Resize((224, 224)),
+            C.Rescale(1.0 / 255.0, 0.0),
+            C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+            C.CutOut(112),
+            C.HWC2CHW()
+        ]
+    else:
+        trans = [
+            C.Resize((224, 224)),
+            C.Rescale(1.0 / 255.0, 0.0),
+            C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+            C.HWC2CHW()
+        ]
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+
+    data_set = data_set.map(operations=type_cast_op,
+                            input_columns="label",
+                            num_parallel_workers=8)
+    data_set = data_set.map(operations=trans,
+                            input_columns="image",
+                            num_parallel_workers=8)
+
+    # apply batch operations
+    data_set = data_set.batch(batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
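A short usage sketch of the new helper, matching the calls this commit adds to train.py and eval.py. The CIFAR-10 paths are placeholders; `Cifar10Dataset` expects the binary CIFAR-10 files under the given directory:

```python
# Placeholder paths; point them at directories containing the CIFAR-10 binary files.
from src.config import config_cpu
from src.dataset import create_dataset_cifar

train_ds = create_dataset_cifar(dataset_path="/data/cifar10/train/",
                                do_train=True,
                                batch_size=config_cpu.batch_size)
eval_ds = create_dataset_cifar(dataset_path="/data/cifar10/val/",
                               do_train=False,
                               batch_size=config_cpu.batch_size)

# Each batch is a (32, 3, 224, 224) image tensor plus int32 labels, because the
# transforms resize CIFAR-10 to 224x224 and apply HWC2CHW before batching.
print(train_ds.get_dataset_size())  # number of drop-remainder batches per epoch
```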

+69 -50  model_zoo/official/cv/mobilenetv3/train.py

@@ -37,8 +37,10 @@ from mindspore.common import set_seed
 from mindspore.communication.management import init, get_group_size, get_rank
 
 from src.dataset import create_dataset
+from src.dataset import create_dataset_cifar
 from src.lr_generator import get_lr
 from src.config import config_gpu
+from src.config import config_cpu
 from src.mobilenetV3 import mobilenet_v3_large
 
 set_seed(1)

@@ -59,6 +61,10 @@ if args_opt.device_target == "GPU":
         context.set_auto_parallel_context(device_num=get_group_size(),
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
+elif args_opt.device_target == "CPU":
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target="CPU",
+                        save_graphs=False)
 else:
     raise ValueError("Unsupported device_target.")

@@ -151,58 +157,71 @@ class Monitor(Callback):
 
 
 if __name__ == '__main__':
+    config_ = None
+    if args_opt.device_target == "GPU":
+        config_ = config_gpu
+    elif args_opt.device_target == "CPU":
+        config_ = config_cpu
+    else:
+        raise ValueError("Unsupported device_target.")
+    # train on device
+    print("train args: ", args_opt)
+    print("cfg: ", config_)
+
+    # define net
+    net = mobilenet_v3_large(num_classes=config_.num_classes)
+    # define loss
+    if config_.label_smooth > 0:
+        loss = CrossEntropyWithLabelSmooth(
+            smooth_factor=config_.label_smooth, num_classes=config_.num_classes)
+    else:
+        loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+    # define dataset
+    epoch_size = config_.epoch_size
     if args_opt.device_target == "GPU":
-        # train on gpu
-        print("train args: ", args_opt)
-        print("cfg: ", config_gpu)
-
-        # define net
-        net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
-        # define loss
-        if config_gpu.label_smooth > 0:
-            loss = CrossEntropyWithLabelSmooth(
-                smooth_factor=config_gpu.label_smooth, num_classes=config_gpu.num_classes)
-        else:
-            loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-        # define dataset
-        epoch_size = config_gpu.epoch_size
         dataset = create_dataset(dataset_path=args_opt.dataset_path,
                                  do_train=True,
-                                 config=config_gpu,
+                                 config=config_,
                                  device_target=args_opt.device_target,
                                  repeat_num=1,
-                                 batch_size=config_gpu.batch_size,
-                                 run_distribute=args_opt.run_distribute)
-        step_size = dataset.get_dataset_size()
-        # resume
-        if args_opt.pre_trained:
-            param_dict = load_checkpoint(args_opt.pre_trained)
-            load_param_into_net(net, param_dict)
-        # define optimizer
-        loss_scale = FixedLossScaleManager(
-            config_gpu.loss_scale, drop_overflow_update=False)
-        lr = Tensor(get_lr(global_step=0,
-                           lr_init=0,
-                           lr_end=0,
-                           lr_max=config_gpu.lr,
-                           warmup_epochs=config_gpu.warmup_epochs,
-                           total_epochs=epoch_size,
-                           steps_per_epoch=step_size))
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_gpu.momentum,
-                       config_gpu.weight_decay, config_gpu.loss_scale)
-        # define model
-        model = Model(net, loss_fn=loss, optimizer=opt,
-                      loss_scale_manager=loss_scale)
-
-        cb = [Monitor(lr_init=lr.asnumpy())]
-        if args_opt.run_distribute:
-            ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
-        else:
-            ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/"
-        if config_gpu.save_checkpoint:
-            config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size,
-                                         keep_checkpoint_max=config_gpu.keep_checkpoint_max)
-            ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
-            cb += [ckpt_cb]
-        # begine train
-        model.train(epoch_size, dataset, callbacks=cb)
+                                 batch_size=config_.batch_size,
+                                 run_distribute=False)
+    elif args_opt.device_target == "CPU":
+        dataset = create_dataset_cifar(args_opt.dataset_path,
+                                       do_train=True,
+                                       batch_size=config_.batch_size)
+    else:
+        raise ValueError("Unsupported device_target.")
+    step_size = dataset.get_dataset_size()
+    # resume
+    if args_opt.pre_trained:
+        param_dict = load_checkpoint(args_opt.pre_trained)
+        load_param_into_net(net, param_dict)
+    # define optimizer
+    loss_scale = FixedLossScaleManager(
+        config_.loss_scale, drop_overflow_update=False)
+    lr = Tensor(get_lr(global_step=0,
+                       lr_init=0,
+                       lr_end=0,
+                       lr_max=config_.lr,
+                       warmup_epochs=config_.warmup_epochs,
+                       total_epochs=epoch_size,
+                       steps_per_epoch=step_size))
+    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config_.momentum,
+                   config_.weight_decay, config_.loss_scale)
+    # define model
+    model = Model(net, loss_fn=loss, optimizer=opt,
+                  loss_scale_manager=loss_scale)
+
+    cb = [Monitor(lr_init=lr.asnumpy())]
+    if args_opt.run_distribute:
+        ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/"
+    else:
+        ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/"
+    if config_.save_checkpoint:
+        config_ck = CheckpointConfig(save_checkpoint_steps=config_.save_checkpoint_epochs * step_size,
+                                     keep_checkpoint_max=config_.keep_checkpoint_max)
+        ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck)
+        cb += [ckpt_cb]
+    # begine train
+    model.train(epoch_size, dataset, callbacks=cb)
