Merge pull request !3719 from RobinGrosman/mastertags/v0.7.0-beta
| @@ -58,11 +58,8 @@ if __name__ == '__main__': | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False, device_id=device_id) | |||
| # create dataset | |||
| if args_opt.net == "resnet50": | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size, | |||
| target=target) | |||
| else: | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size) | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, batch_size=config.batch_size, | |||
| target=target) | |||
| step_size = dataset.get_dataset_size() | |||
| # define net | |||
| @@ -0,0 +1,93 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# ---------------------------------------------------------------------------
# Argument validation.
# Expected arguments: <net> <dataset> <dataset_path> [pretrained_ckpt_path]
# Every positional parameter is quoted so that an empty argument or one
# containing whitespace is rejected here instead of making `[` fail with a
# syntax error (which would silently skip the check).
# ---------------------------------------------------------------------------
if [ $# != 3 ] && [ $# != 4 ]
then
    echo "Usage: sh run_distribute_train_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Only the two ResNet variants are accepted.
if [ "$1" != "resnet50" ] && [ "$1" != "resnet101" ]
then
    echo "error: the selected net is neither resnet50 nor resnet101"
    exit 1
fi

# Only the two known datasets are accepted.
if [ "$2" != "cifar10" ] && [ "$2" != "imagenet2012" ]
then
    echo "error: the selected dataset is neither cifar10 nor imagenet2012"
    exit 1
fi

# This particular net/dataset combination is explicitly rejected.
if [ "$1" == "resnet101" ] && [ "$2" == "cifar10" ]
then
    echo "error: training resnet101 with cifar10 dataset is unsupported now!"
    exit 1
fi
# Print an absolute form of the path given as $1.
#   - A path that already starts with "/" is printed unchanged.
#   - Any other path is resolved against the current working directory with
#     `realpath -m`, so the path does not have to exist yet.
# The expansion is quoted so that paths containing whitespace stay a single
# argument instead of being split into several paths.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}
# ---------------------------------------------------------------------------
# Resolve and validate the paths, then launch distributed GPU training.
# PATH1 = dataset directory ($3), PATH2 = optional pre-trained checkpoint ($4).
# ---------------------------------------------------------------------------
PATH1=$(get_real_path "$3")
if [ $# == 4 ]
then
    PATH2=$(get_real_path "$4")
fi

# The dataset directory is what must exist here. The previous check tested
# PATH2, which is unset when only three arguments are given.
if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

# The checkpoint is only supplied in the four-argument form. The previous
# condition compared against 5, which can never be true after the argument
# count check at the top of the script.
if [ $# == 4 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8

# Start from a fresh working directory holding a copy of the sources.
rm -rf ./train_parallel
mkdir ./train_parallel
cp ../*.py ./train_parallel
cp *.sh ./train_parallel
cp -r ../src ./train_parallel
cd ./train_parallel || exit

# Build the train.py argument list once; the checkpoint option is appended
# only when a fourth argument was supplied.
train_args=(
    --net="$1"
    --dataset="$2"
    --run_distribute=True
    --device_num="$DEVICE_NUM"
    --device_target="GPU"
    --dataset_path="$PATH1"
)
if [ $# == 4 ]
then
    train_args+=(--pre_trained="$PATH2")
fi

# Run in the background; stdout and stderr both go to ./log.
mpirun --allow-run-as-root -n "$RANK_SIZE" \
    python train.py "${train_args[@]}" &> log &
| @@ -0,0 +1,95 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
# ---------------------------------------------------------------------------
# Argument validation.
# Expected arguments: <net> <dataset> <dataset_path> [pretrained_ckpt_path]
# Every positional parameter is quoted so that an empty argument or one
# containing whitespace is rejected here instead of making `[` fail with a
# syntax error (which would silently skip the check).
# ---------------------------------------------------------------------------
if [ $# != 3 ] && [ $# != 4 ]
then
    echo "Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
    exit 1
fi

# Only the two ResNet variants are accepted.
if [ "$1" != "resnet50" ] && [ "$1" != "resnet101" ]
then
    echo "error: the selected net is neither resnet50 nor resnet101"
    exit 1
fi

# Only the two known datasets are accepted.
if [ "$2" != "cifar10" ] && [ "$2" != "imagenet2012" ]
then
    echo "error: the selected dataset is neither cifar10 nor imagenet2012"
    exit 1
fi

# This particular net/dataset combination is explicitly rejected.
if [ "$1" == "resnet101" ] && [ "$2" == "cifar10" ]
then
    echo "error: training resnet101 with cifar10 dataset is unsupported now!"
    exit 1
fi
# Print an absolute form of the path given as $1.
#   - A path that already starts with "/" is printed unchanged.
#   - Any other path is resolved against the current working directory with
#     `realpath -m`, so the path does not have to exist yet.
# The expansion is quoted so that paths containing whitespace stay a single
# argument instead of being split into several paths.
get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}
# ---------------------------------------------------------------------------
# Resolve and validate the paths, then launch single-device GPU training.
# PATH1 = dataset directory ($3), PATH2 = optional pre-trained checkpoint ($4).
# All expansions are quoted so paths containing whitespace are checked and
# passed on as a single argument.
# ---------------------------------------------------------------------------
PATH1=$(get_real_path "$3")
if [ $# == 4 ]
then
    PATH2=$(get_real_path "$4")
fi

if [ ! -d "$PATH1" ]
then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

# The checkpoint is only supplied in the four-argument form.
if [ $# == 4 ] && [ ! -f "$PATH2" ]
then
    echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

# Start from a fresh working directory holding a copy of the sources.
if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
# Record the environment the training process is started with.
env > env.log

# Build the train.py argument list once; the checkpoint option is appended
# only when a fourth argument was supplied.
train_args=(
    --net="$1"
    --dataset="$2"
    --device_target="GPU"
    --dataset_path="$PATH1"
)
if [ $# == 4 ]
then
    train_args+=(--pre_trained="$PATH2")
fi

# Run in the background; stdout and stderr both go to ./log.
python train.py "${train_args[@]}" &> log &
cd ..
| @@ -139,7 +139,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= | |||
| return ds | |||
| def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): | |||
| def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"): | |||
| """ | |||
| create a train or eval imagenet2012 dataset for resnet101 | |||
| Args: | |||
| @@ -158,36 +158,26 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32): | |||
| else: | |||
| ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=device_num, shard_id=rank_id) | |||
| resize_height = 224 | |||
| rescale = 1.0 / 255.0 | |||
| shift = 0.0 | |||
| image_size = 224 | |||
| mean = [0.475 * 255, 0.451 * 255, 0.392 * 255] | |||
| std = [0.275 * 255, 0.267 * 255, 0.278 * 255] | |||
| # define map operations | |||
| decode_op = C.Decode() | |||
| random_resize_crop_op = C.RandomResizedCrop(resize_height, (0.08, 1.0), (0.75, 1.33), max_attempts=100) | |||
| horizontal_flip_op = C.RandomHorizontalFlip(rank_id / (rank_id + 1)) | |||
| resize_op_256 = C.Resize((256, 256)) | |||
| center_crop = C.CenterCrop(224) | |||
| rescale_op = C.Rescale(rescale, shift) | |||
| normalize_op = C.Normalize((0.475, 0.451, 0.392), (0.275, 0.267, 0.278)) | |||
| changeswap_op = C.HWC2CHW() | |||
| if do_train: | |||
| trans = [decode_op, | |||
| random_resize_crop_op, | |||
| horizontal_flip_op, | |||
| rescale_op, | |||
| normalize_op, | |||
| changeswap_op] | |||
| trans = [ | |||
| C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)), | |||
| C.RandomHorizontalFlip(rank_id/ (rank_id +1)), | |||
| C.Normalize(mean=mean, std=std), | |||
| C.HWC2CHW() | |||
| ] | |||
| else: | |||
| trans = [decode_op, | |||
| resize_op_256, | |||
| center_crop, | |||
| rescale_op, | |||
| normalize_op, | |||
| changeswap_op] | |||
| trans = [ | |||
| C.Decode(), | |||
| C.Resize(256), | |||
| C.CenterCrop(image_size), | |||
| C.Normalize(mean=mean, std=std), | |||
| C.HWC2CHW() | |||
| ] | |||
| type_cast_op = C2.TypeCast(mstype.int32) | |||
| @@ -86,12 +86,8 @@ if __name__ == '__main__': | |||
| ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||
| # create dataset | |||
| if args_opt.net == "resnet50": | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, | |||
| batch_size=config.batch_size, target=target) | |||
| else: | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, | |||
| batch_size=config.batch_size) | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, repeat_num=1, | |||
| batch_size=config.batch_size, target=target) | |||
| step_size = dataset.get_dataset_size() | |||
| # define net | |||
| @@ -122,7 +118,7 @@ if __name__ == '__main__': | |||
| lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs, | |||
| total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine') | |||
| else: | |||
| lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, 120, | |||
| lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size, | |||
| config.pretrain_epoch_size * step_size) | |||
| lr = Tensor(lr) | |||
| @@ -147,9 +143,13 @@ if __name__ == '__main__': | |||
| amp_level="O2", keep_batchnorm_fp32=False) | |||
| else: | |||
| # GPU target | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, is_grad=False, reduction='mean') | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum) | |||
| loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", is_grad=False, | |||
| smooth_factor=config.label_smooth_factor, num_classes=config.class_num) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| ##Mixed precision | |||
| #model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||
| # amp_level="O2", keep_batchnorm_fp32=True) | |||
| # define callbacks | |||
| time_cb = TimeMonitor(data_size=step_size) | |||