| @@ -16,6 +16,8 @@ | |||||
| ##############test googlenet example on cifar10################# | ##############test googlenet example on cifar10################# | ||||
| python eval.py | python eval.py | ||||
| """ | """ | ||||
| import argparse | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.nn.optim.momentum import Momentum | from mindspore.nn.optim.momentum import Momentum | ||||
| @@ -26,10 +28,15 @@ from src.config import cifar_cfg as cfg | |||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.googlenet import GoogleNet | from src.googlenet import GoogleNet | ||||
| parser = argparse.ArgumentParser(description='googlenet') | |||||
| parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') | |||||
| args_opt = parser.parse_args() | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| device_target = cfg.device_target | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | ||||
| context.set_context(device_id=cfg.device_id) | |||||
| if device_target == "Ascend": | |||||
| context.set_context(device_id=cfg.device_id) | |||||
| net = GoogleNet(num_classes=cfg.num_classes) | net = GoogleNet(num_classes=cfg.num_classes) | ||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, | opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, | ||||
| @@ -37,7 +44,11 @@ if __name__ == '__main__': | |||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | ||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) | ||||
| param_dict = load_checkpoint(cfg.checkpoint_path) | |||||
| if device_target == "Ascend": | |||||
| param_dict = load_checkpoint(cfg.checkpoint_path) | |||||
| else: # GPU | |||||
| param_dict = load_checkpoint(args_opt.checkpoint_path) | |||||
| load_param_into_net(net, param_dict) | load_param_into_net(net, param_dict) | ||||
| net.set_train(False) | net.set_train(False) | ||||
| dataset = create_dataset(cfg.data_path, 1, False) | dataset = create_dataset(cfg.data_path, 1, False) | ||||
| @@ -0,0 +1,43 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Evaluate a GoogleNet checkpoint on GPU.
#
# Usage: sh run_eval_gpu.sh [CHECKPOINT_PATH]
#
# A fresh ../eval directory (relative to the caller's working directory) is
# created, eval.py is started there in the background, and its stdout/stderr
# are redirected to ../eval/eval.log.

ulimit -u unlimited

if [ $# -ne 1 ]
then
    echo "GPU: sh run_eval_gpu.sh [CHECKPOINT_PATH]"
    exit 1
fi

# check checkpoint file
if [ ! -f "$1" ]
then
    echo "error: CHECKPOINT_PATH=$1 is not a file"
    exit 1
fi

# Resolve the checkpoint to an absolute path before changing directory:
# the script cds into ../eval below, so a relative path that was valid for
# the check above would no longer point at the same file.
CHECKPOINT_PATH="$(cd "$(dirname "$1")" || exit; pwd)/$(basename "$1")"

BASEPATH=$(cd "$(dirname "$0")" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0

# Start from an empty working directory so stale logs are not mixed in.
if [ -d "../eval" ];
then
    rm -rf ../eval
fi
mkdir ../eval
cd ../eval || exit

python3 "${BASEPATH}/../eval.py" --checkpoint_path="$CHECKPOINT_PATH" > ./eval.log 2>&1 &
| @@ -0,0 +1,45 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
# Launch data-parallel GoogleNet training on GPU through mpirun.
#
# Arguments:
#   $1  DEVICE_NUM        number of processes/devices, must be an integer in 1-8
#   $2  VISIABLE_DEVICES  comma-separated device list exported as
#                         CUDA_VISIBLE_DEVICES
#
# A fresh ../train directory (relative to the caller's working directory) is
# created, train.py is started there in the background under mpirun, and its
# stdout/stderr are redirected to ../train/train.log.

if [ $# -lt 2 ]
then
    # printf instead of echo: bash's builtin echo does not expand "\n"
    # without -e, so the original message printed the escapes literally.
    printf 'Usage:\n    sh run_train.sh [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)]\n'
    exit 1
fi

# Reject empty or non-numeric DEVICE_NUM up front; the numeric tests below
# would otherwise only emit a "integer expression expected" error and let
# the script continue.
case "$1" in
    ''|*[!0-9]*)
        echo "error: DEVICE_NUM=$1 is not in (1-8)"
        exit 1
        ;;
esac

# Out of range when below 1 OR above 8 (the two conditions can never hold
# at the same time, so joining them with && disabled the check).
if [ "$1" -lt 1 ] || [ "$1" -gt 8 ]
then
    echo "error: DEVICE_NUM=$1 is not in (1-8)"
    exit 1
fi

export DEVICE_NUM=$1
export RANK_SIZE=$1

BASEPATH=$(cd "$(dirname "$0")" || exit; pwd)
export PYTHONPATH=${BASEPATH}:$PYTHONPATH

# Start from an empty working directory so stale logs/checkpoints are not mixed in.
if [ -d "../train" ];
then
    rm -rf ../train
fi
mkdir ../train
cd ../train || exit

export CUDA_VISIBLE_DEVICES="$2"

mpirun -n "$1" --allow-run-as-root \
    python3 "${BASEPATH}/../train.py" > train.log 2>&1 &
| @@ -25,7 +25,7 @@ import numpy as np | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.communication.management import init | |||||
| from mindspore.communication.management import init, get_rank | |||||
| from mindspore.nn.optim.momentum import Momentum | from mindspore.nn.optim.momentum import Momentum | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| from mindspore.train.model import Model, ParallelMode | from mindspore.train.model import Model, ParallelMode | ||||
| @@ -38,7 +38,6 @@ from src.googlenet import GoogleNet | |||||
| random.seed(1) | random.seed(1) | ||||
| np.random.seed(1) | np.random.seed(1) | ||||
| def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): | def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None): | ||||
| """Set learning rate.""" | """Set learning rate.""" | ||||
| lr_each_step = [] | lr_each_step = [] | ||||
| @@ -65,18 +64,31 @@ if __name__ == '__main__': | |||||
| parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)') | ||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | |||||
| if args_opt.device_id is not None: | |||||
| context.set_context(device_id=args_opt.device_id) | |||||
| else: | |||||
| context.set_context(device_id=cfg.device_id) | |||||
| device_target = cfg.device_target | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target) | |||||
| device_num = int(os.environ.get("DEVICE_NUM", 1)) | device_num = int(os.environ.get("DEVICE_NUM", 1)) | ||||
| if device_num > 1: | |||||
| context.reset_auto_parallel_context() | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| init() | |||||
| if device_target == "Ascend": | |||||
| if args_opt.device_id is not None: | |||||
| context.set_context(device_id=args_opt.device_id) | |||||
| else: | |||||
| context.set_context(device_id=cfg.device_id) | |||||
| if device_num > 1: | |||||
| context.reset_auto_parallel_context() | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| init() | |||||
| elif device_target == "GPU": | |||||
| init("nccl") | |||||
| if device_num > 1: | |||||
| context.reset_auto_parallel_context() | |||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| else: | |||||
| raise ValueError("Unsupport platform.") | |||||
| dataset = create_dataset(cfg.data_path, 1) | dataset = create_dataset(cfg.data_path, 1) | ||||
| batch_num = dataset.get_dataset_size() | batch_num = dataset.get_dataset_size() | ||||
| @@ -90,12 +102,19 @@ if __name__ == '__main__': | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, | opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum, | ||||
| weight_decay=cfg.weight_decay) | weight_decay=cfg.weight_decay) | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) | ||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||||
| if device_target == "Ascend": | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||||
| amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None) | |||||
| ckpt_save_dir = "./" | |||||
| else: # GPU | |||||
| model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, | |||||
| amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None) | |||||
| ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/" | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) | config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max) | ||||
| time_cb = TimeMonitor(data_size=batch_num) | time_cb = TimeMonitor(data_size=batch_num) | ||||
| ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./", config=config_ck) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck) | |||||
| loss_cb = LossMonitor() | loss_cb = LossMonitor() | ||||
| model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) | model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb]) | ||||
| print("train success") | print("train success") | ||||