Merge pull request !1390 from gengdongjie/mastertags/v0.3.0-alpha
| @@ -28,7 +28,7 @@ config = ed({ | |||||
| "image_height": 224, | "image_height": 224, | ||||
| "image_width": 224, | "image_width": 224, | ||||
| "save_checkpoint": True, | "save_checkpoint": True, | ||||
| "save_checkpoint_steps": 195, | |||||
| "save_checkpoint_steps": 1950, | |||||
| "keep_checkpoint_max": 10, | "keep_checkpoint_max": 10, | ||||
| "save_checkpoint_path": "./", | "save_checkpoint_path": "./", | ||||
| "warmup_epochs": 5, | "warmup_epochs": 5, | ||||
| @@ -45,6 +45,7 @@ Parameters for both training and inference can be set in config.py. | |||||
| "momentum": 0.9, # momentum optimizer | "momentum": 0.9, # momentum optimizer | ||||
| "weight_decay": 1e-4, # weight decay | "weight_decay": 1e-4, # weight decay | ||||
| "epoch_size": 90, # only valid for training, which is always 1 for inference | "epoch_size": 90, # only valid for training, which is always 1 for inference | ||||
| "pretrained_epoch_size": 1, # number of epochs the model had already been trained for before the pretrained checkpoint is loaded | |||||
| "buffer_size": 1000, # queue size in data preprocessing | "buffer_size": 1000, # queue size in data preprocessing | ||||
| "image_height": 224, # image height | "image_height": 224, # image height | ||||
| "image_width": 224, # image width | "image_width": 224, # image width | ||||
| @@ -68,10 +69,11 @@ Parameters for both training and inference can be set in config.py. | |||||
| ``` | ``` | ||||
| # distributed training | # distributed training | ||||
| Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] | |||||
| Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) | |||||
| # standalone training | # standalone training | ||||
| Usage: sh run_standalone_train.sh [DATASET_PATH] | |||||
| Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional) | |||||
| ``` | ``` | ||||
| @@ -81,8 +83,14 @@ Usage: sh run_standalone_train.sh [DATASET_PATH] | |||||
| # distributed training example(8 pcs) | # distributed training example(8 pcs) | ||||
| sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc | sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc | ||||
| # If you want to load a pretrained ckpt file | |||||
| sh run_distribute_train.sh rank_table_8p.json dataset/ilsvrc ./pretrained.ckpt | |||||
| # standalone training example(1 pcs) | # standalone training example(1 pcs) | ||||
| sh run_standalone_train.sh dataset/ilsvrc | sh run_standalone_train.sh dataset/ilsvrc | ||||
| # If you want to load a pretrained ckpt file | |||||
| sh run_standalone_train.sh dataset/ilsvrc ./pretrained.ckpt | |||||
| ``` | ``` | ||||
| > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). | > About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). | ||||
| @@ -24,6 +24,7 @@ config = ed({ | |||||
| "momentum": 0.9, | "momentum": 0.9, | ||||
| "weight_decay": 1e-4, | "weight_decay": 1e-4, | ||||
| "epoch_size": 90, | "epoch_size": 90, | ||||
| "pretrained_epoch_size": 1, | |||||
| "buffer_size": 1000, | "buffer_size": 1000, | ||||
| "image_height": 224, | "image_height": 224, | ||||
| "image_width": 224, | "image_width": 224, | ||||
| @@ -17,12 +17,11 @@ import math | |||||
| import numpy as np | import numpy as np | ||||
| def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): | |||||
| def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode): | |||||
| """ | """ | ||||
| generate learning rate array | generate learning rate array | ||||
| Args: | Args: | ||||
| global_step(int): total steps of the training | |||||
| lr_init(float): init learning rate | lr_init(float): init learning rate | ||||
| lr_end(float): end learning rate | lr_end(float): end learning rate | ||||
| lr_max(float): max learning rate | lr_max(float): max learning rate | ||||
| @@ -83,8 +82,6 @@ def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, st | |||||
| lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) | lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps) | ||||
| lr_each_step.append(lr) | lr_each_step.append(lr) | ||||
| current_step = global_step | |||||
| lr_each_step = np.array(lr_each_step).astype(np.float32) | |||||
| learning_rate = lr_each_step[current_step:] | |||||
| learning_rate = np.array(lr_each_step).astype(np.float32) | |||||
| return learning_rate | return learning_rate | ||||
| @@ -14,9 +14,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# != 2 ] | |||||
| if [ $# != 2 ] && [ $# != 3 ] | |||||
| then | then | ||||
| echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]" | |||||
| echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -30,6 +30,10 @@ get_real_path(){ | |||||
| PATH1=$(get_real_path $1) | PATH1=$(get_real_path $1) | ||||
| PATH2=$(get_real_path $2) | PATH2=$(get_real_path $2) | ||||
| if [ $# == 3 ] | |||||
| then | |||||
| PATH3=$(get_real_path $3) | |||||
| fi | |||||
| if [ ! -f "$PATH1" ] | if [ ! -f "$PATH1" ] | ||||
| then | then | ||||
| @@ -43,6 +47,12 @@ then | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| if [ ! -f "$PATH3" ] | |||||
| then | |||||
| echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file" | |||||
| exit 1 | |||||
| fi | |||||
| ulimit -u unlimited | ulimit -u unlimited | ||||
| export DEVICE_NUM=8 | export DEVICE_NUM=8 | ||||
| export RANK_SIZE=8 | export RANK_SIZE=8 | ||||
| @@ -60,6 +70,11 @@ do | |||||
| cd ./train_parallel$i || exit | cd ./train_parallel$i || exit | ||||
| echo "start training for rank $RANK_ID, device $DEVICE_ID" | echo "start training for rank $RANK_ID, device $DEVICE_ID" | ||||
| env > env.log | env > env.log | ||||
| python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log & | |||||
| if [ $# == 2 ] | |||||
| then | |||||
| python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &> log & | |||||
| else | |||||
| python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 --pre_trained=$PATH3 &> log & | |||||
| fi | |||||
| cd .. | cd .. | ||||
| done | done | ||||
| @@ -14,9 +14,9 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| if [ $# != 1 ] | |||||
| if [ $# != 1 ] && [ $# != 2 ] | |||||
| then | then | ||||
| echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" | |||||
| echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -29,6 +29,10 @@ get_real_path(){ | |||||
| } | } | ||||
| PATH1=$(get_real_path $1) | PATH1=$(get_real_path $1) | ||||
| if [ $# == 2 ] | |||||
| then | |||||
| PATH2=$(get_real_path $2) | |||||
| fi | |||||
| if [ ! -d "$PATH1" ] | if [ ! -d "$PATH1" ] | ||||
| then | then | ||||
| @@ -36,6 +40,12 @@ then | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| if [ ! -f "$PATH2" ] | |||||
| then | |||||
| echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file" | |||||
| exit 1 | |||||
| fi | |||||
| ulimit -u unlimited | ulimit -u unlimited | ||||
| export DEVICE_NUM=1 | export DEVICE_NUM=1 | ||||
| export DEVICE_ID=0 | export DEVICE_ID=0 | ||||
| @@ -51,5 +61,10 @@ cp *.sh ./train | |||||
| cd ./train || exit | cd ./train || exit | ||||
| echo "start training for device $DEVICE_ID" | echo "start training for device $DEVICE_ID" | ||||
| env > env.log | env > env.log | ||||
| python train.py --do_train=True --dataset_path=$PATH1 &> log & | |||||
| if [ $# == 1 ] | |||||
| then | |||||
| python train.py --do_train=True --dataset_path=$PATH1 &> log & | |||||
| else | |||||
| python train.py --do_train=True --dataset_path=$PATH1 --pre_trained=$PATH2 &> log & | |||||
| fi | |||||
| cd .. | cd .. | ||||
| @@ -28,6 +28,7 @@ from mindspore.train.model import Model, ParallelMode | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor | ||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from mindspore.communication.management import init | from mindspore.communication.management import init | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| import mindspore.common.initializer as weight_init | import mindspore.common.initializer as weight_init | ||||
| @@ -39,6 +40,7 @@ parser.add_argument('--device_num', type=int, default=1, help='Device num.') | |||||
| parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.') | parser.add_argument('--do_train', type=bool, default=True, help='Do train or not.') | ||||
| parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') | parser.add_argument('--do_eval', type=bool, default=False, help='Do eval or not.') | ||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | ||||
| parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') | |||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| device_id = int(os.getenv('DEVICE_ID')) | device_id = int(os.getenv('DEVICE_ID')) | ||||
| @@ -58,15 +60,20 @@ if __name__ == '__main__': | |||||
| net = resnet50(class_num=config.class_num) | net = resnet50(class_num=config.class_num) | ||||
| # weight init | # weight init | ||||
| for _, cell in net.cells_and_names(): | |||||
| if isinstance(cell, nn.Conv2d): | |||||
| cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), | |||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| if isinstance(cell, nn.Dense): | |||||
| cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), | |||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| if args_opt.pre_trained: | |||||
| param_dict = load_checkpoint(args_opt.pre_trained) | |||||
| load_param_into_net(net, param_dict) | |||||
| epoch_size = config.epoch_size - config.pretrained_epoch_size | |||||
| else: | |||||
| for _, cell in net.cells_and_names(): | |||||
| if isinstance(cell, nn.Conv2d): | |||||
| cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), | |||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| if isinstance(cell, nn.Dense): | |||||
| cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), | |||||
| cell.weight.default_input.shape(), | |||||
| cell.weight.default_input.dtype()).to_tensor() | |||||
| if not config.use_label_smooth: | if not config.use_label_smooth: | ||||
| config.label_smooth_factor = 0.0 | config.label_smooth_factor = 0.0 | ||||
| @@ -78,9 +85,11 @@ if __name__ == '__main__': | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | ||||
| lr = Tensor(get_lr(global_step=0, lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, | |||||
| warmup_epochs=config.warmup_epochs, total_epochs=epoch_size, steps_per_epoch=step_size, | |||||
| lr_decay_mode='cosine')) | |||||
| lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs, | |||||
| total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine') | |||||
| if args_opt.pre_trained: | |||||
| lr = lr[config.pretrained_epoch_size * step_size:] | |||||
| lr = Tensor(lr) | |||||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, | opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, | ||||
| config.weight_decay, config.loss_scale) | config.weight_decay, config.loss_scale) | ||||