Browse Source

!4041 modify mobilenetv2 quant scripts and fix bug

Merge pull request !4041 from chengxb7532/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
92d93ebce5
3 changed files with 15 additions and 9 deletions
  1. +5
    -5
      model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh
  2. +8
    -2
      model_zoo/official/cv/mobilenetv2_quant/src/dataset.py
  3. +2
    -2
      model_zoo/official/cv/mobilenetv2_quant/train.py

+ 5
- 5
model_zoo/official/cv/mobilenetv2_quant/scripts/run_train_quant.sh View File

@@ -75,15 +75,15 @@ run_gpu()
python ${BASEPATH}/../train.py \
--dataset_path=$4 \
--device_target=$1 \
--quantization_aware=True \
&> ../train.log & # dataset train folder
--pre_trained=$5 \
--quantization_aware=True &> ../train.log & # dataset train folder
}

if [ $# -gt 6 ] || [ $# -lt 4 ]
if [ $# -gt 6 ] || [ $# -lt 5 ]
then
echo "Usage:\n \
Ascend: sh run_train.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]\n \
Ascend: sh run_train_quant.sh Ascend [DEVICE_NUM] [SERVER_IP(x.x.x.x)] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
GPU: sh run_train_quant.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH]\n \
"
exit 1
fi


+ 8
- 2
model_zoo/official/cv/mobilenetv2_quant/src/dataset.py View File

@@ -22,7 +22,6 @@ import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.transforms.vision.py_transforms as P
from src.config import config_ascend


def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32):
@@ -42,7 +41,7 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
rank_size = int(os.getenv("RANK_SIZE"))
rank_id = int(os.getenv("RANK_ID"))
columns_list = ['image', 'label']
if config_ascend.data_load_mode == "mindrecord":
if config.data_load_mode == "mindrecord":
load_func = partial(de.MindDataset, dataset_path, columns_list)
else:
load_func = partial(de.ImageFolderDatasetV2, dataset_path)
@@ -54,6 +53,13 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
num_shards=rank_size, shard_id=rank_id)
else:
ds = load_func(num_parallel_workers=8, shuffle=False)
elif device_target == "GPU":
if do_train:
from mindspore.communication.management import get_rank, get_group_size
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=get_group_size(), shard_id=get_rank())
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
else:
raise ValueError("Unsupport device_target.")



+ 2
- 2
model_zoo/official/cv/mobilenetv2_quant/train.py View File

@@ -56,7 +56,7 @@ if args_opt.device_target == "Ascend":
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id, save_graphs=False)
elif args_opt.platform == "GPU":
elif args_opt.device_target == "GPU":
init("nccl")
context.set_auto_parallel_context(device_num=get_group_size(),
parallel_mode=ParallelMode.DATA_PARALLEL,
@@ -205,5 +205,5 @@ def train_on_gpu():
if __name__ == '__main__':
if args_opt.device_target == "Ascend":
train_on_ascend()
elif args_opt.platform == "GPU":
elif args_opt.device_target == "GPU":
train_on_gpu()

Loading…
Cancel
Save