@@ -28,22 +28,17 @@ from src.utils import switch_precision, set_context
 if __name__ == '__main__':
     args_opt = eval_parse_args()
     config = set_config(args_opt)
+    set_context(config)
     backbone_net, head_net, net = define_net(config, args_opt.is_training)

-    #load the trained checkpoint file to the net for evaluation
-    if args_opt.head_ckpt:
-        load_ckpt(backbone_net, args_opt.pretrain_ckpt)
-        load_ckpt(head_net, args_opt.head_ckpt)
-    else:
-        load_ckpt(net, args_opt.pretrain_ckpt)
+    load_ckpt(net, args_opt.pretrain_ckpt)

-    set_context(config)
     switch_precision(net, mstype.float16, config)

     dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config)
     step_size = dataset.get_dataset_size()
     if step_size == 0:
-        raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
+        raise ValueError("The step_size of dataset is zero. Check if the images count of eval dataset is more \
            than batch_size in config.py")

     net.set_train(False)

@@ -53,5 +48,3 @@ if __name__ == '__main__':
     res = model.eval(dataset)
     print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}")
-    if args_opt.head_ckpt:
-        print(f"head_ckpt={args_opt.head_ckpt}")
@@ -43,7 +43,6 @@ run_ascend()
         --platform=$1 \
         --dataset_path=$2 \
         --pretrain_ckpt=$3 \
-        --head_ckpt=$4 \
         &> ../eval.log & # dataset val folder path
 }
@@ -69,7 +68,6 @@ run_gpu()
         --platform=$1 \
         --dataset_path=$2 \
         --pretrain_ckpt=$3 \
-        --head_ckpt=$4 \
         &> ../eval.log & # dataset train folder
 }
@@ -95,17 +93,16 @@ run_cpu()
         --platform=$1 \
         --dataset_path=$2 \
         --pretrain_ckpt=$3 \
-        --head_ckpt=$4 \
         &> ../eval.log & # dataset train folder
 }

-if [ $# -gt 4 ] || [ $# -lt 3 ]
+if [ $# -ne 3 ]
 then
     echo "Usage:
           Ascend: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
           GPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]
-          CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]"
+          CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]"
     exit 1
 fi
@@ -123,5 +120,5 @@ elif [ $1 = "GPU" ] ; then
elif [ $1 = "Ascend" ] ; then
    run_ascend "$@"
else
-    echo "Unsupported device_target."
+    echo "Unsupported platform."
fi;
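With head_ckpt gone, run_eval.sh takes exactly three arguments on every platform, which is what the tightened [ $# -ne 3 ] check enforces; for example (paths illustrative): sh run_eval.sh GPU /path/to/val mobilenetv2.ckpt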
@@ -43,8 +43,8 @@ run_ascend()
         --visible_devices=$3 \
         --training_script=${BASEPATH}/../train.py \
         --dataset_path=$5 \
-        --train_method=$6 \
-        --pretrain_ckpt=$7 \
+        --pretrain_ckpt=$6 \
+        --freeze_layer=$7 \
         &> ../train.log & # dataset train folder
 }
@@ -76,8 +76,8 @@ run_gpu()
     python ${BASEPATH}/../train.py \
         --platform=$1 \
         --dataset_path=$4 \
-        --train_method=$5 \
-        --pretrain_ckpt=$6 \
+        --pretrain_ckpt=$5 \
+        --freeze_layer=$6 \
         &> ../train.log & # dataset train folder
 }
@@ -102,17 +102,17 @@ run_cpu()
     python ${BASEPATH}/../train.py \
         --platform=$1 \
         --dataset_path=$2 \
-        --train_method=$3 \
-        --pretrain_ckpt=$4 \
+        --pretrain_ckpt=$3 \
+        --freeze_layer=$4 \
         &> ../train.log & # dataset train folder
 }

 if [ $# -gt 7 ] || [ $# -lt 4 ]
 then
     echo "Usage:
-          Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]
-          GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]
-          CPU: sh run_train.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]"
+          Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
+          GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]
+          CPU: sh run_train.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]"
     exit 1
 fi
@@ -123,5 +123,5 @@ elif [ $1 = "GPU" ] ; then
elif [ $1 = "CPU" ] ; then
    run_cpu "$@"
else
-    echo "Unsupported device_target."
+    echo "Unsupported platform."
fi;
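freeze_layer simply takes over the positional slot that train_method used to occupy, so a head-only CPU run now looks like, e.g. (paths illustrative): sh run_train.sh CPU /path/to/train mobilenetv2.ckpt backbone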
@@ -41,11 +41,10 @@ def train_parse_args():
     train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \
         help='run platform, only support CPU, GPU and Ascend')
     train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
-    train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \
-        help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \
-        train from initialization model")
     train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \
         for fine tune or incremental learning')
+    train_parser.add_argument('--freeze_layer', type=str, default=None, choices=["none", "backbone"], \
+        help="freeze the weights of network from start to which layers")
     train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute')
     train_args = train_parser.parse_args()
     train_args.is_training = True
@@ -58,8 +57,6 @@ def eval_parse_args():
     eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path')
     eval_parser.add_argument('--pretrain_ckpt', type=str, required=True, help='Pretrained checkpoint path \
         for fine tune or incremental learning')
-    eval_parser.add_argument('--head_ckpt', type=str, default=None, help='Pretrained checkpoint path \
-        for incremental learning')
     eval_parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='If run distribute in GPU.')
     eval_args = eval_parser.parse_args()
     eval_args.is_training = False
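The old three-way train_method is now encoded by two orthogonal flags. A sketch of the mapping, assuming it mirrors the branch logic in train.py below (resolve_mode is a hypothetical helper, not part of the diff):

    # Hypothetical helper showing how the old train_method values map onto
    # the new (pretrain_ckpt, freeze_layer) pair.
    def resolve_mode(pretrain_ckpt, freeze_layer):
        if pretrain_ckpt and freeze_layer == "backbone":
            return "incremental_learn"   # frozen backbone, train the head only
        if pretrain_ckpt:
            return "fine_tune"           # warm-start the whole network
        return "train"                   # train from randomly initialized weights

    assert resolve_mode(None, None) == "train"
    assert resolve_mode("mobilenetv2.ckpt", "none") == "fine_tune"
    assert resolve_mode("mobilenetv2.ckpt", "backbone") == "incremental_learn"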
@@ -99,13 +99,12 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1):

 def extract_features(net, dataset_path, config):
-    features_folder = dataset_path + '_features'
+    features_folder = os.path.abspath(dataset_path) + '_features'
     if not os.path.exists(features_folder):
         os.makedirs(features_folder)
     dataset = create_dataset(dataset_path=dataset_path,
                              do_train=False,
-                             config=config,
-                             repeat_num=1)
+                             config=config)
     step_size = dataset.get_dataset_size()
     if step_size == 0:
         raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
@@ -122,5 +121,5 @@ def extract_features(net, dataset_path, config):
             features = model.predict(Tensor(image))
             np.save(features_path, features.asnumpy())
             np.save(label_path, label)
-        print(f"Complete the batch {i}/{step_size}")
+        print(f"Complete the batch {i+1}/{step_size}")
     return step_size
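For orientation, extract_features caches one pair of plain .npy files per batch, which the head-training loop in train.py later reads back. A minimal consumer sketch; the feature_<i>.npy / label_<i>.npy naming is an assumption inferred from the loop, only the np.save/np.load usage is from the diff:

    # Hypothetical reader for the per-batch feature cache.
    import os
    import numpy as np

    def load_feature_batch(features_folder, i):
        feature = np.load(os.path.join(features_folder, f"feature_{i}.npy"))
        label = np.load(os.path.join(features_folder, f"label_{i}.npy"))
        return feature, label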
@@ -298,8 +298,6 @@ class MobileNetV2(nn.Cell):
         has_dropout (bool): Is dropout used. Default is false
         inverted_residual_setting (list): Inverted residual settings. Default is None
         round_nearest (list): Channel round to . Default is 8
-        backbone(nn.Cell): Backbone of MobileNetV2.
-        head(nn.Cell): Classification head of MobileNetV2.

     Returns:
         Tensor, output tensor.
@@ -49,8 +49,7 @@ def context_device_init(config):
         if config.run_distribute:
             context.set_auto_parallel_context(device_num=config.rank_size,
                                               parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              gradients_mean=True,
-                                              all_reduce_fusion_config=[140])
+                                              gradients_mean=True)
             init()
     else:
         raise ValueError("Only support CPU, GPU and Ascend.")
@@ -82,6 +81,6 @@ def config_ckpoint(config, lr, step_size):
             rank = get_rank()
         ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(rank) + "/"
-        ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck)
+        ckpt_cb = ModelCheckpoint(prefix="mobilenetv2", directory=ckpt_save_dir, config=config_ck)
         cb += [ckpt_cb]
     return cb
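The lower-case prefix matches the file names written by the manual save path in train.py (mobilenetv2_<epoch>.ckpt). A minimal sketch of the callback being configured here; the CheckpointConfig values below are illustrative only, since the repo's actual config_ck is built outside this hunk:

    # Illustrative values; config_ck is constructed elsewhere in config_ckpoint.
    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    config_ck = CheckpointConfig(save_checkpoint_steps=500, keep_checkpoint_max=10)
    ckpt_cb = ModelCheckpoint(prefix="mobilenetv2",   # lower-case prefix after this change
                              directory="./ckpt_0/",
                              config=config_ck)
    # MindSpore then emits files named mobilenetv2-<epoch>_<step>.ckpt.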
@@ -53,30 +53,23 @@ if __name__ == '__main__':
     # define network
     backbone_net, head_net, net = define_net(config, args_opt.is_training)

-    # load the ckpt file to the network for fine tune or incremental leaning
-    if args_opt.pretrain_ckpt:
-        if args_opt.train_method == "fine_tune":
-            load_ckpt(net, args_opt.pretrain_ckpt)
-        elif args_opt.train_method == "incremental_learn":
-            load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
-        elif args_opt.train_method == "train":
-            pass
-        else:
-            raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None")
-
-    # CPU only support "incremental_learn"
-    if args_opt.train_method == "incremental_learn":
+    if args_opt.pretrain_ckpt and args_opt.freeze_layer == "backbone":
+        load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False)
         step_size = extract_features(backbone_net, args_opt.dataset_path, config)
-        net = head_net
-
-    elif args_opt.train_method in ("train", "fine_tune"):
+    else:
         if args_opt.platform == "CPU":
-            raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".")
+            raise ValueError("CPU only support fine tune the head net, doesn't support fine tune the all net")
+        if args_opt.pretrain_ckpt:
+            load_ckpt(backbone_net, args_opt.pretrain_ckpt)
         dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config)
         step_size = dataset.get_dataset_size()
-        if step_size == 0:
-            raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \
-                than batch_size in config.py")
+
+    if step_size == 0:
+        raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \
+            than batch_size in config.py")

     # Currently, only Ascend support switch precision.
     switch_precision(net, mstype.float16, config)
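The freeze_layer == "backbone" branch above is the standard freeze-then-cache pattern: mark the trunk's parameters non-trainable, then run it forward once over the whole dataset so the head can train on cached activations. A condensed sketch, assuming load_ckpt(trainable=False) clears requires_grad as sketched earlier:

    import mindspore.nn as nn

    def freeze(cell: nn.Cell):
        for param in cell.get_parameters():
            param.requires_grad = False  # the Momentum optimizer below filters on requires_grad

    # freeze(backbone_net) + extract_features(backbone_net, ...) runs the expensive
    # convolutional trunk exactly once per image; only the small head is optimized.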
@@ -99,15 +92,32 @@ if __name__ == '__main__':
                        total_epochs=epoch_size,
                        steps_per_epoch=step_size))

-    if args_opt.train_method == "incremental_learn":
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
+    if args_opt.pretrain_ckpt is None or args_opt.freeze_layer == "none":
+        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
+        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
+            config.weight_decay, config.loss_scale)
+        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
+        cb = config_ckpoint(config, lr, step_size)
+        print("============== Starting Training ==============")
+        model.train(epoch_size, dataset, callbacks=cb)
+        print("============== End Training ==============")
-        network = WithLossCell(net, loss)
+    else:
+        opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()), lr, config.momentum, config.weight_decay)
+        network = WithLossCell(head_net, loss)
         network = TrainOneStepCell(network, opt)
         network.set_train()

         features_path = args_opt.dataset_path + '_features'
         idx_list = list(range(step_size))
+        rank = 0
+        if config.run_distribute:
+            rank = get_rank()
+        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
+        if not os.path.isdir(save_ckpt_path):
+            os.mkdir(save_ckpt_path)

         for epoch in range(epoch_size):
             random.shuffle(idx_list)
@@ -119,24 +129,8 @@ if __name__ == '__main__':
                 losses.append(network(feature, label).asnumpy())
             epoch_mseconds = (time.time()-epoch_start) * 1000
             per_step_mseconds = epoch_mseconds / step_size
-            print("epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
-            .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
+            print("epoch[{}/{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\
+            .format(epoch + 1, epoch_size, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))))
             if (epoch + 1) % config.save_checkpoint_epochs == 0:
-                rank = 0
-                if config.run_distribute:
-                    rank = get_rank()
-                save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
-                save_checkpoint(network, os.path.join(save_ckpt_path, \
-                    f"mobilenetv2_head_{epoch+1}.ckpt"))
+                save_checkpoint(net, os.path.join(save_ckpt_path, f"mobilenetv2_{epoch+1}.ckpt"))
         print("total cost {:5.4f} s".format(time.time() - start))
-
-    elif args_opt.train_method in ("train", "fine_tune"):
-        loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \
-            config.weight_decay, config.loss_scale)
-        model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale)
-        cb = config_ckpoint(config, lr, step_size)
-        print("============== Starting Training ==============")
-        model.train(epoch_size, dataset, callbacks=cb)
-        print("============== End Training ==============")