@@ -14,13 +14,20 @@
 # limitations under the License.
 # ============================================================================
-echo "=============================================================================================================="
+echo "================================================================================================================="
 echo "Please run the script as: "
-echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDSPORE_HCCL_CONFIG_PATH"
-echo "for example: sh run_distribute_train.sh 8 150 coco /data/hccl.json"
+echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE DATASET MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "for example: sh run_distribute_train.sh 8 350 coco /data/hccl.json /opt/ssd-300.ckpt(optional) 200(optional)"
 echo "It is better to use absolute path."
 echo "The learning rate is 0.4 as default, if you want other lr, please change the value in this script."
-echo "=============================================================================================================="
+echo "================================================================================================================="
+
+if [ $# != 4 ] && [ $# != 6 ]
+then
+    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [DATASET] \
+[MINDSPORE_HCCL_CONFIG_PATH] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+    exit 1
+fi
 
 # Before starting distributed training, first create mindrecord files.
 python train.py --only_create_dataset=1
@@ -30,6 +37,8 @@ echo "After running the script, the network runs in the background. The log will
 export RANK_SIZE=$1
 EPOCH_SIZE=$2
 DATASET=$3
+PRE_TRAINED=$5
+PRE_TRAINED_EPOCH_SIZE=$6
 export MINDSPORE_HCCL_CONFIG_PATH=$4
@@ -43,12 +52,29 @@ do
     export RANK_ID=$i
     echo "start training for rank $i, device $DEVICE_ID"
     env > env.log
-    python ../train.py \
-    --distribute=1 \
-    --lr=0.4 \
-    --dataset=$DATASET \
-    --device_num=$RANK_SIZE \
-    --device_id=$DEVICE_ID \
-    --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
+    if [ $# == 4 ]
+    then
+        python ../train.py \
+        --distribute=1 \
+        --lr=0.4 \
+        --dataset=$DATASET \
+        --device_num=$RANK_SIZE \
+        --device_id=$DEVICE_ID \
+        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
+    fi
+    if [ $# == 6 ]
+    then
+        python ../train.py \
+        --distribute=1 \
+        --lr=0.4 \
+        --dataset=$DATASET \
+        --device_num=$RANK_SIZE \
+        --device_id=$DEVICE_ID \
+        --pre_trained=$PRE_TRAINED \
+        --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
+        --epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
+    fi
     cd ../
 done

@@ -88,6 +88,7 @@ def main():
     parser.add_argument("--epoch_size", type=int, default=70, help="Epoch size, default is 70.")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
     parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
+    parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
     parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
     parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
     args_opt = parser.parse_args()
@@ -150,17 +151,20 @@ def main():
     ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
     ckpoint_cb = ModelCheckpoint(prefix="ssd", directory=None, config=ckpt_config)
 
-    lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=args_opt.lr,
-                       warmup_epochs=max(args_opt.epoch_size // 20, 1),
-                       total_epochs=args_opt.epoch_size,
-                       steps_per_epoch=dataset_size))
-    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
-    net = TrainingWrapper(net, opt, loss_scale)
     if args_opt.pre_trained:
+        if args_opt.pre_trained_epoch_size <= 0:
+            raise KeyError("pre_trained_epoch_size must be greater than 0.")
         param_dict = load_checkpoint(args_opt.pre_trained)
         load_param_into_net(net, param_dict)
+    lr = Tensor(get_lr(global_step=args_opt.pre_trained_epoch_size * dataset_size,
+                       lr_init=0, lr_end=0, lr_max=args_opt.lr,
+                       warmup_epochs=max(350 // 20, 1),
+                       total_epochs=350,
+                       steps_per_epoch=dataset_size))
+    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, 0.9, 0.0001, loss_scale)
+    net = TrainingWrapper(net, opt, loss_scale)
 
     callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
     model = Model(net)
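
In this hunk, checkpoint loading now happens before the optimizer and TrainingWrapper are built, and the learning-rate schedule is constructed with global_step=args_opt.pre_trained_epoch_size * dataset_size, so a resumed run continues the curve where the checkpoint stopped instead of repeating warmup. Note the schedule length is now pinned to 350 epochs (warmup max(350 // 20, 1) = 17 epochs) regardless of --epoch_size. The repo's get_lr is defined elsewhere; the sketch below only matches the call signature in the diff and assumes a linear-warmup-then-cosine-decay shape, which may differ from the real implementation.

```python
# A plausible sketch of get_lr, NOT the repo's implementation: linear
# warmup from lr_init to lr_max, then cosine decay to lr_end, with
# global_step slicing off steps a pre-trained checkpoint already consumed.
import math
import numpy as np

def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs,
           total_epochs, steps_per_epoch):
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    lr_each_step = []
    for step in range(total_steps):
        if step < warmup_steps:
            # linear warmup from lr_init up to lr_max
            lr = lr_init + (lr_max - lr_init) * step / warmup_steps
        else:
            # cosine decay from lr_max down to lr_end
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            lr = lr_end + (lr_max - lr_end) * 0.5 * (1 + math.cos(math.pi * progress))
        lr_each_step.append(lr)
    # a resumed run starts partway along the same precomputed curve
    return np.array(lr_each_step[global_step:], dtype=np.float32)
```

Under this reading, --pre_trained_epoch_size=200 drops the first 200 * dataset_size rates, landing well past the 17-epoch warmup and directly on the decay portion of the curve.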
@@ -14,18 +14,27 @@
 # limitations under the License.
 # ============================================================================
-echo "=============================================================================================================="
+echo "======================================================================================================================================================="
 echo "Please run the script as: "
-echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH"
-echo "for example: sh run_distribute_train.sh 8 100 /data/Mindrecord_train /data /data/train.txt /data/hccl.json"
+echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH MINDSPORE_HCCL_CONFIG_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "For example: sh run_distribute_train.sh 8 150 /data/Mindrecord_train /data /data/train.txt /data/hccl.json /opt/yolov3-150.ckpt(optional) 100(optional)"
 echo "It is better to use absolute path."
 echo "The learning rate is 0.005 as default, if you want other lr, please change the value in this script."
-echo "=============================================================================================================="
+echo "======================================================================================================================================================="
+
+if [ $# != 6 ] && [ $# != 8 ]
+then
+    echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [MINDRECORD_DIR] [IMAGE_DIR] [ANNO_PATH] [MINDSPORE_HCCL_CONFIG_PATH] \
+[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+    exit 1
+fi
 
 EPOCH_SIZE=$2
 MINDRECORD_DIR=$3
 IMAGE_DIR=$4
 ANNO_PATH=$5
+PRE_TRAINED=$7
+PRE_TRAINED_EPOCH_SIZE=$8
 
 # Before starting distributed training, first create mindrecord files.
 python train.py --only_create_dataset=1 --mindrecord_dir=$MINDRECORD_DIR --image_dir=$IMAGE_DIR \
@@ -51,14 +60,34 @@ do
     export RANK_ID=$i
     echo "start training for rank $i, device $DEVICE_ID"
     env > env.log
-    taskset -c $cmdopt python ../train.py \
-    --distribute=1 \
-    --lr=0.005 \
-    --device_num=$RANK_SIZE \
-    --device_id=$DEVICE_ID \
-    --mindrecord_dir=$MINDRECORD_DIR \
-    --image_dir=$IMAGE_DIR \
-    --epoch_size=$EPOCH_SIZE \
-    --anno_path=$ANNO_PATH > log.txt 2>&1 &
+    if [ $# == 6 ]
+    then
+        taskset -c $cmdopt python ../train.py \
+        --distribute=1 \
+        --lr=0.005 \
+        --device_num=$RANK_SIZE \
+        --device_id=$DEVICE_ID \
+        --mindrecord_dir=$MINDRECORD_DIR \
+        --image_dir=$IMAGE_DIR \
+        --epoch_size=$EPOCH_SIZE \
+        --anno_path=$ANNO_PATH > log.txt 2>&1 &
+    fi
+    if [ $# == 8 ]
+    then
+        taskset -c $cmdopt python ../train.py \
+        --distribute=1 \
+        --lr=0.005 \
+        --device_num=$RANK_SIZE \
+        --device_id=$DEVICE_ID \
+        --mindrecord_dir=$MINDRECORD_DIR \
+        --image_dir=$IMAGE_DIR \
+        --epoch_size=$EPOCH_SIZE \
+        --pre_trained=$PRE_TRAINED \
+        --pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
+        --anno_path=$ANNO_PATH > log.txt 2>&1 &
+    fi
     cd ../
 done

@@ -14,10 +14,25 @@
 # limitations under the License.
 # ============================================================================
-echo "=============================================================================================================="
+echo "========================================================================================================================================="
 echo "Please run the script as: "
-echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH"
-echo "for example: sh run_standalone_train.sh 0 50 ./Mindrecord_train ./dataset ./dataset/train.txt"
-echo "=============================================================================================================="
+echo "sh run_standalone_train.sh DEVICE_ID EPOCH_SIZE MINDRECORD_DIR IMAGE_DIR ANNO_PATH PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
+echo "for example: sh run_standalone_train.sh 0 50 ./Mindrecord_train ./dataset ./dataset/train.txt /opt/yolov3-50.ckpt(optional) 30(optional)"
+echo "========================================================================================================================================="
 
-python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
+if [ $# != 5 ] && [ $# != 7 ]
+then
+    echo "Usage: sh run_standalone_train.sh [DEVICE_ID] [EPOCH_SIZE] [MINDRECORD_DIR] [IMAGE_DIR] [ANNO_PATH] \
+[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
+    exit 1
+fi
+
+if [ $# == 5 ]
+then
+    python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5
+fi
+
+if [ $# == 7 ]
+then
+    python train.py --device_id=$1 --epoch_size=$2 --mindrecord_dir=$3 --image_dir=$4 --anno_path=$5 --pre_trained=$6 --pre_trained_epoch_size=$7
+fi

@@ -71,6 +71,7 @@ def main():
     parser.add_argument("--epoch_size", type=int, default=10, help="Epoch size, default is 10")
     parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
     parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained checkpoint file path")
+    parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size")
     parser.add_argument("--save_checkpoint_epochs", type=int, default=5, help="Save checkpoint epochs, default is 5.")
     parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
     parser.add_argument("--mindrecord_dir", type=str, default="./Mindrecord_train",
@@ -133,14 +134,19 @@ def main():
     ckpt_config = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs)
     ckpoint_cb = ModelCheckpoint(prefix="yolov3", directory=None, config=ckpt_config)
 
-    lr = Tensor(get_lr(learning_rate=args_opt.lr, start_step=0, global_step=args_opt.epoch_size * dataset_size,
-                       decay_step=1000, decay_rate=0.95, steps=True))
-    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
-    net = TrainingWrapper(net, opt, loss_scale)
     if args_opt.pre_trained:
+        if args_opt.pre_trained_epoch_size <= 0:
+            raise KeyError("pre_trained_epoch_size must be greater than 0.")
         param_dict = load_checkpoint(args_opt.pre_trained)
         load_param_into_net(net, param_dict)
+    total_epoch_size = 60
+    if args_opt.distribute:
+        total_epoch_size = 160
+    lr = Tensor(get_lr(learning_rate=args_opt.lr, start_step=args_opt.pre_trained_epoch_size * dataset_size,
+                       global_step=total_epoch_size * dataset_size,
+                       decay_step=1000, decay_rate=0.95, steps=True))
+    opt = nn.Adam(filter(lambda x: x.requires_grad, net.get_parameters()), lr, loss_scale=loss_scale)
+    net = TrainingWrapper(net, opt, loss_scale)
 
     callback = [TimeMonitor(data_size=dataset_size), LossMonitor(), ckpoint_cb]
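
Here the resume logic is the same, but the schedule is a staircase exponential decay: start_step skips the rates a checkpoint has already consumed, and the schedule length is pinned to 60 epochs (160 when --distribute is set) rather than --epoch_size. Again, get_lr lives elsewhere in the repo; the sketch below matches the call signature in the diff, but its body is an assumption (reading steps=True as "decay in discrete steps").

```python
# A plausible sketch of get_lr, NOT the repo's implementation: staircase
# exponential decay, where start_step skips the part of the schedule a
# pre-trained checkpoint has already trained through.
import numpy as np

def get_lr(learning_rate, start_step, global_step, decay_step, decay_rate,
           steps=False):
    lr_each_step = []
    for step in range(global_step):
        # steps=True: drop by decay_rate once every decay_step steps;
        # otherwise decay smoothly with a fractional exponent.
        exponent = step // decay_step if steps else step / decay_step
        lr_each_step.append(learning_rate * decay_rate ** exponent)
    return np.array(lr_each_step[start_step:], dtype=np.float32)
```

Under this reading, resuming from a 100-epoch checkpoint with dataset_size = 1000 starts at 0.005 * 0.95 ** 100 (about 3e-5) instead of restarting from 0.005.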