| @@ -16,6 +16,8 @@ | |||
| ##############test googlenet example on cifar10################# | |||
| python eval.py | |||
| """ | |||
| import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| @@ -26,10 +28,15 @@ from src.config import cifar_cfg as cfg | |||
| from src.dataset import create_dataset | |||
| from src.googlenet import GoogleNet | |||
| parser = argparse.ArgumentParser(description='googlenet') | |||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | |||
| args_opt = parser.parse_args() | |||
| if __name__ == '__main__': | |||
| device_target = cfg.device_target | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | |||
| context.set_context(device_id=cfg.device_id) | |||
| if device_target == "Ascend": | |||
| context.set_context(device_id=cfg.device_id) | |||
| net = GoogleNet(num_classes=cfg.num_classes) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, | |||
| @@ -37,7 +44,11 @@ if __name__ == '__main__': | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | |||
| param_dict = load_checkpoint(cfg.checkpoint_path) | |||
| if device_target == "Ascend": | |||
| param_dict = load_checkpoint(cfg.checkpoint_path) | |||
| else: # GPU | |||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||
| load_param_into_net(net, param_dict) | |||
| net.set_train(False) | |||
| dataset = create_dataset(cfg.data_path, 1, False) | |||
| @@ -0,0 +1,43 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| ulimit -u unlimited | |||
| if [ $# != 1 ] | |||
| then | |||
| echo "GPU: sh run_eval_gpu.sh [CHECKPOINT_PATH]" | |||
| exit 1 | |||
| fi | |||
| # check checkpoint file | |||
| if [ ! -f $1 ] | |||
| then | |||
| echo "error: CHECKPOINT_PATH=$1 is not a file" | |||
| exit 1 | |||
| fi | |||
| BASEPATH=$(cd "`dirname $0`" || exit; pwd) | |||
| export PYTHONPATH=${BASEPATH}:$PYTHONPATH | |||
| export DEVICE_ID=0 | |||
| if [ -d "../eval" ]; | |||
| then | |||
| rm -rf ../eval | |||
| fi | |||
| mkdir ../eval | |||
| cd ../eval || exit | |||
| python3 ${BASEPATH}/../eval.py --checkpoint_path=$1 > ./eval.log 2>&1 & | |||
| @@ -0,0 +1,45 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| if [ $# -lt 2 ] | |||
| then | |||
| echo "Usage:\n \ | |||
| sh run_train.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]\n \ | |||
| " | |||
| exit 1 | |||
| fi | |||
| if [ $1 -lt 1 ] && [ $1 -gt 8 ] | |||
| then | |||
| echo "error: DEVICE_NUM=$1 is not in (1-8)" | |||
| exit 1 | |||
| fi | |||
| export DEVICE_NUM=$1 | |||
| export RANK_SIZE=$1 | |||
| BASEPATH=$(cd "`dirname $0`" || exit; pwd) | |||
| export PYTHONPATH=${BASEPATH}:$PYTHONPATH | |||
| if [ -d "../train" ]; | |||
| then | |||
| rm -rf ../train | |||
| fi | |||
| mkdir ../train | |||
| cd ../train || exit | |||
| export CUDA_VISIBLE_DEVICES="$2" | |||
| mpirun -n $1 --allow-run-as-root \ | |||
| python3 ${BASEPATH}/../train.py > train.log 2>&1 & | |||
| @@ -25,7 +25,7 @@ import numpy as np | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.communication.management import init | |||
| from mindspore.communication.management import init, get_rank | |||
| from mindspore.nn.optim.momentum import Momentum | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | |||
| from mindspore.train.model import Model, ParallelMode | |||
| @@ -38,7 +38,6 @@ from src.googlenet import GoogleNet | |||
| random.seed(1) | |||
| np.random.seed(1) | |||
| def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): | |||
| """Set learning rate.""" | |||
| lr_each_step = [] | |||
| @@ -65,18 +64,31 @@ if __name__ == '__main__': | |||
| parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | |||
| args_opt = parser.parse_args() | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | |||
| if args_opt.device_id is not None: | |||
| context.set_context(device_id=args_opt.device_id) | |||
| else: | |||
| context.set_context(device_id=cfg.device_id) | |||
| device_target = cfg.device_target | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | |||
| device_num = int(os.environ.get("DEVICE_NUM", 1)) | |||
| if device_num > 1: | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| mirror_mean=True) | |||
| init() | |||
| if device_target == "Ascend": | |||
| if args_opt.device_id is not None: | |||
| context.set_context(device_id=args_opt.device_id) | |||
| else: | |||
| context.set_context(device_id=cfg.device_id) | |||
| if device_num > 1: | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| mirror_mean=True) | |||
| init() | |||
| elif device_target == "GPU": | |||
| init("nccl") | |||
| if device_num > 1: | |||
| context.reset_auto_parallel_context() | |||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| mirror_mean=True) | |||
| else: | |||
| raise ValueError("Unsupport platform.") | |||
| dataset = create_dataset(cfg.data_path, 1) | |||
| batch_num = dataset.get_dataset_size() | |||
| @@ -90,12 +102,19 @@ if __name__ == '__main__': | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, | |||
| weight_decay=cfg.weight_decay) | |||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||
| if device_target == "Ascend": | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||
| ckpt_save_dir = "./" | |||
| else: # GPU | |||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||
| amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None) | |||
| ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| time_cb = TimeMonitor(data_size=batch_num) | |||
| ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", config=config_ck) | |||
| ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck) | |||
| loss_cb = LossMonitor() | |||
| model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) | |||
| print("train success") | |||