| @@ -28,22 +28,17 @@ from src.utils import switch_precision, set_context | |||
| if __name__ == '__main__': | |||
| args_opt = eval_parse_args() | |||
| config = set_config(args_opt) | |||
| set_context(config) | |||
| backbone_net, head_net, net = define_net(config, args_opt.is_training) | |||
| #load the trained checkpoint file to the net for evaluation | |||
| if args_opt.head_ckpt: | |||
| load_ckpt(backbone_net, args_opt.pretrain_ckpt) | |||
| load_ckpt(head_net, args_opt.head_ckpt) | |||
| else: | |||
| load_ckpt(net, args_opt.pretrain_ckpt) | |||
| load_ckpt(net, args_opt.pretrain_ckpt) | |||
| set_context(config) | |||
| switch_precision(net, mstype.float16, config) | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config) | |||
| step_size = dataset.get_dataset_size() | |||
| if step_size == 0: | |||
| raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \ | |||
| raise ValueError("The step_size of dataset is zero. Check if the images count of eval dataset is more \ | |||
| than batch_size in config.py") | |||
| net.set_train(False) | |||
| @@ -53,5 +48,3 @@ if __name__ == '__main__': | |||
| res = model.eval(dataset) | |||
| print(f"result:{res}\npretrain_ckpt={args_opt.pretrain_ckpt}") | |||
| if args_opt.head_ckpt: | |||
| print(f"head_ckpt={args_opt.head_ckpt}") | |||
| @@ -43,7 +43,6 @@ run_ascend() | |||
| --platform=$1 \ | |||
| --dataset_path=$2 \ | |||
| --pretrain_ckpt=$3 \ | |||
| --head_ckpt=$4 \ | |||
| &> ../eval.log & # dataset val folder path | |||
| } | |||
| @@ -69,7 +68,6 @@ run_gpu() | |||
| --platform=$1 \ | |||
| --dataset_path=$2 \ | |||
| --pretrain_ckpt=$3 \ | |||
| --head_ckpt=$4 \ | |||
| &> ../eval.log & # dataset train folder | |||
| } | |||
| @@ -95,17 +93,16 @@ run_cpu() | |||
| --platform=$1 \ | |||
| --dataset_path=$2 \ | |||
| --pretrain_ckpt=$3 \ | |||
| --head_ckpt=$4 \ | |||
| &> ../eval.log & # dataset train folder | |||
| } | |||
| if [ $# -gt 4 ] || [ $# -lt 3 ] | |||
| if [ $# -ne 3 ] | |||
| then | |||
| echo "Usage: | |||
| Ascend: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] | |||
| GPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] | |||
| CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]" | |||
| CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT]" | |||
| exit 1 | |||
| fi | |||
| @@ -123,5 +120,5 @@ elif [ $1 = "GPU" ] ; then | |||
| elif [ $1 = "Ascend" ] ; then | |||
| run_ascend "$@" | |||
| else | |||
| echo "Unsupported device_target." | |||
| echo "Unsupported platform." | |||
| fi; | |||
| @@ -43,8 +43,8 @@ run_ascend() | |||
| --visible_devices=$3 \ | |||
| --training_script=${BASEPATH}/../train.py \ | |||
| --dataset_path=$5 \ | |||
| --train_method=$6 \ | |||
| --pretrain_ckpt=$7 \ | |||
| --pretrain_ckpt=$6 \ | |||
| --freeze_layer=$7 \ | |||
| &> ../train.log & # dataset train folder | |||
| } | |||
| @@ -76,8 +76,8 @@ run_gpu() | |||
| python ${BASEPATH}/../train.py \ | |||
| --platform=$1 \ | |||
| --dataset_path=$4 \ | |||
| --train_method=$5 \ | |||
| --pretrain_ckpt=$6 \ | |||
| --pretrain_ckpt=$5 \ | |||
| --freeze_layer=$6 \ | |||
| &> ../train.log & # dataset train folder | |||
| } | |||
| @@ -102,17 +102,17 @@ run_cpu() | |||
| python ${BASEPATH}/../train.py \ | |||
| --platform=$1 \ | |||
| --dataset_path=$2 \ | |||
| --train_method=$3 \ | |||
| --pretrain_ckpt=$4 \ | |||
| --pretrain_ckpt=$3 \ | |||
| --freeze_layer=$4 \ | |||
| &> ../train.log & # dataset train folder | |||
| } | |||
| if [ $# -gt 7 ] || [ $# -lt 4 ] | |||
| then | |||
| echo "Usage: | |||
| Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH] | |||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH] | |||
| CPU: sh run_train.sh CPU [DATASET_PATH] [TRAIN_METHOD] [CKPT_PATH]" | |||
| Ascend: sh run_train.sh Ascend [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [RANK_TABLE_FILE] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] | |||
| GPU: sh run_train.sh GPU [DEVICE_NUM] [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER] | |||
| CPU: sh run_train.sh CPU [DATASET_PATH] [CKPT_PATH] [FREEZE_LAYER]" | |||
| exit 1 | |||
| fi | |||
| @@ -123,5 +123,5 @@ elif [ $1 = "GPU" ] ; then | |||
| elif [ $1 = "CPU" ] ; then | |||
| run_cpu "$@" | |||
| else | |||
| echo "Unsupported device_target." | |||
| echo "Unsupported platform." | |||
| fi; | |||
| @@ -41,11 +41,10 @@ def train_parse_args(): | |||
| train_parser.add_argument('--platform', type=str, default="Ascend", choices=("CPU", "GPU", "Ascend"), \ | |||
| help='run platform, only support CPU, GPU and Ascend') | |||
| train_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') | |||
| train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \ | |||
| help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \ | |||
| train from initialization model") | |||
| train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ | |||
| for fine tune or incremental learning') | |||
| train_parser.add_argument('--freeze_layer', type=str, default=None, choices=["none", "backbone"], \ | |||
| help="freeze the weights of network from start to which layers") | |||
| train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') | |||
| train_args = train_parser.parse_args() | |||
| train_args.is_training = True | |||
| @@ -58,8 +57,6 @@ def eval_parse_args(): | |||
| eval_parser.add_argument('--dataset_path', type=str, required=True, help='Dataset path') | |||
| eval_parser.add_argument('--pretrain_ckpt', type=str, required=True, help='Pretrained checkpoint path \ | |||
| for fine tune or incremental learning') | |||
| eval_parser.add_argument('--head_ckpt', type=str, default=None, help='Pretrained checkpoint path \ | |||
| for incremental learning') | |||
| eval_parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='If run distribute in GPU.') | |||
| eval_args = eval_parser.parse_args() | |||
| eval_args.is_training = False | |||
| @@ -99,13 +99,12 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1): | |||
| def extract_features(net, dataset_path, config): | |||
| features_folder = dataset_path + '_features' | |||
| features_folder = os.path.abspath(dataset_path) + '_features' | |||
| if not os.path.exists(features_folder): | |||
| os.makedirs(features_folder) | |||
| dataset = create_dataset(dataset_path=dataset_path, | |||
| do_train=False, | |||
| config=config, | |||
| repeat_num=1) | |||
| config=config) | |||
| step_size = dataset.get_dataset_size() | |||
| if step_size == 0: | |||
| raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \ | |||
| @@ -122,5 +121,5 @@ def extract_features(net, dataset_path, config): | |||
| features = model.predict(Tensor(image)) | |||
| np.save(features_path, features.asnumpy()) | |||
| np.save(label_path, label) | |||
| print(f"Complete the batch {i}/{step_size}") | |||
| print(f"Complete the batch {i+1}/{step_size}") | |||
| return step_size | |||
| @@ -298,8 +298,6 @@ class MobileNetV2(nn.Cell): | |||
| has_dropout (bool): Is dropout used. Default is false | |||
| inverted_residual_setting (list): Inverted residual settings. Default is None | |||
| round_nearest (list): Channel round to . Default is 8 | |||
| backbone(nn.Cell): Backbone of MobileNetV2. | |||
| head(nn.Cell): Classification head of MobileNetV2. | |||
| Returns: | |||
| Tensor, output tensor. | |||
| @@ -49,8 +49,7 @@ def context_device_init(config): | |||
| if config.run_distribute: | |||
| context.set_auto_parallel_context(device_num=config.rank_size, | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True, | |||
| all_reduce_fusion_config=[140]) | |||
| gradients_mean=True) | |||
| init() | |||
| else: | |||
| raise ValueError("Only support CPU, GPU and Ascend.") | |||
| @@ -82,6 +81,6 @@ def config_ckpoint(config, lr, step_size): | |||
| rank = get_rank() | |||
| ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(rank) + "/" | |||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) | |||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetv2", directory=ckpt_save_dir, config=config_ck) | |||
| cb += [ckpt_cb] | |||
| return cb | |||
| @@ -53,30 +53,23 @@ if __name__ == '__main__': | |||
| # define network | |||
| backbone_net, head_net, net = define_net(config, args_opt.is_training) | |||
| # load the ckpt file to the network for fine tune or incremental leaning | |||
| if args_opt.pretrain_ckpt: | |||
| if args_opt.train_method == "fine_tune": | |||
| load_ckpt(net, args_opt.pretrain_ckpt) | |||
| elif args_opt.train_method == "incremental_learn": | |||
| load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False) | |||
| elif args_opt.train_method == "train": | |||
| pass | |||
| else: | |||
| raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None") | |||
| # CPU only support "incremental_learn" | |||
| if args_opt.train_method == "incremental_learn": | |||
| if args_opt.pretrain_ckpt and args_opt.freeze_layer == "backbone": | |||
| load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False) | |||
| step_size = extract_features(backbone_net, args_opt.dataset_path, config) | |||
| net = head_net | |||
| elif args_opt.train_method in ("train", "fine_tune"): | |||
| else: | |||
| if args_opt.platform == "CPU": | |||
| raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".") | |||
| raise ValueError("CPU only support fine tune the head net, doesn't support fine tune the all net") | |||
| if args_opt.pretrain_ckpt: | |||
| load_ckpt(backbone_net, args_opt.pretrain_ckpt) | |||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config) | |||
| step_size = dataset.get_dataset_size() | |||
| if step_size == 0: | |||
| raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \ | |||
| than batch_size in config.py") | |||
| if step_size == 0: | |||
| raise ValueError("The step_size of dataset is zero. Check if the images' count of train dataset is more \ | |||
| than batch_size in config.py") | |||
| # Currently, only Ascend support switch precision. | |||
| switch_precision(net, mstype.float16, config) | |||
| @@ -99,15 +92,32 @@ if __name__ == '__main__': | |||
| total_epochs=epoch_size, | |||
| steps_per_epoch=step_size)) | |||
| if args_opt.train_method == "incremental_learn": | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay) | |||
| if args_opt.pretrain_ckpt is None or args_opt.freeze_layer == "none": | |||
| loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \ | |||
| config.weight_decay, config.loss_scale) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) | |||
| cb = config_ckpoint(config, lr, step_size) | |||
| print("============== Starting Training ==============") | |||
| model.train(epoch_size, dataset, callbacks=cb) | |||
| print("============== End Training ==============") | |||
| network = WithLossCell(net, loss) | |||
| else: | |||
| opt = Momentum(filter(lambda x: x.requires_grad, head_net.get_parameters()), lr, config.momentum, config.weight_decay) | |||
| network = WithLossCell(head_net, loss) | |||
| network = TrainOneStepCell(network, opt) | |||
| network.set_train() | |||
| features_path = args_opt.dataset_path + '_features' | |||
| idx_list = list(range(step_size)) | |||
| rank = 0 | |||
| if config.run_distribute: | |||
| rank = get_rank() | |||
| save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') | |||
| if not os.path.isdir(save_ckpt_path): | |||
| os.mkdir(save_ckpt_path) | |||
| for epoch in range(epoch_size): | |||
| random.shuffle(idx_list) | |||
| @@ -119,24 +129,8 @@ if __name__ == '__main__': | |||
| losses.append(network(feature, label).asnumpy()) | |||
| epoch_mseconds = (time.time()-epoch_start) * 1000 | |||
| per_step_mseconds = epoch_mseconds / step_size | |||
| print("epoch[{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\ | |||
| .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses)))) | |||
| print("epoch[{}/{}], iter[{}] cost: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.3f}"\ | |||
| .format(epoch + 1, epoch_size, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses)))) | |||
| if (epoch + 1) % config.save_checkpoint_epochs == 0: | |||
| rank = 0 | |||
| if config.run_distribute: | |||
| rank = get_rank() | |||
| save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') | |||
| save_checkpoint(network, os.path.join(save_ckpt_path, \ | |||
| f"mobilenetv2_head_{epoch+1}.ckpt")) | |||
| save_checkpoint(net, os.path.join(save_ckpt_path, f"mobilenetv2_{epoch+1}.ckpt")) | |||
| print("total cost {:5.4f} s".format(time.time() - start)) | |||
| elif args_opt.train_method in ("train", "fine_tune"): | |||
| loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) | |||
| opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, \ | |||
| config.weight_decay, config.loss_scale) | |||
| model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale) | |||
| cb = config_ckpoint(config, lr, step_size) | |||
| print("============== Starting Training ==============") | |||
| model.train(epoch_size, dataset, callbacks=cb) | |||
| print("============== End Training ==============") | |||