| @@ -28,7 +28,7 @@ from src.utils import switch_precision, set_context | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| args_opt = eval_parse_args() | args_opt = eval_parse_args() | ||||
| config = set_config(args_opt) | config = set_config(args_opt) | ||||
| backbone_net, head_net, net = define_net(args_opt, config) | |||||
| backbone_net, head_net, net = define_net(config) | |||||
| #load the trained checkpoint file to the net for evaluation | #load the trained checkpoint file to the net for evaluation | ||||
| if args_opt.head_ckpt: | if args_opt.head_ckpt: | ||||
| @@ -42,6 +42,10 @@ if __name__ == '__main__': | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config) | dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, config=config) | ||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| if step_size == 0: | |||||
| raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \ | |||||
| than batch_size in config.py") | |||||
| net.set_train(False) | net.set_train(False) | ||||
| loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') | ||||
| @@ -103,9 +103,9 @@ run_cpu() | |||||
| if [ $# -gt 4 ] || [ $# -lt 3 ] | if [ $# -gt 4 ] || [ $# -lt 3 ] | ||||
| then | then | ||||
| echo "Usage: | echo "Usage: | ||||
| Ascend: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] | |||||
| GPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] | |||||
| CPU: sh run_infer.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]" | |||||
| Ascend: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] | |||||
| GPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [PRETRAIN_CKPT] | |||||
| CPU: sh run_eval.sh [PLATFORM] [DATASET_PATH] [BACKBONE_CKPT] [HEAD_CKPT]" | |||||
| exit 1 | exit 1 | ||||
| fi | fi | ||||
| @@ -36,7 +36,6 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1): | |||||
| config(struct): the config of train and eval in diffirent platform. | config(struct): the config of train and eval in diffirent platform. | ||||
| repeat_num(int): the repeat times of dataset. Default: 1. | repeat_num(int): the repeat times of dataset. Default: 1. | ||||
| Returns: | Returns: | ||||
| dataset | dataset | ||||
| """ | """ | ||||
| @@ -96,11 +95,7 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1): | |||||
| # apply dataset repeat operation | # apply dataset repeat operation | ||||
| ds = ds.repeat(repeat_num) | ds = ds.repeat(repeat_num) | ||||
| step_size = ds.get_dataset_size() | |||||
| if step_size == 0: | |||||
| raise ValueError("The step_size of dataset is zero. Check if the images of train dataset is more than batch_\ | |||||
| size in config.py") | |||||
| return ds, step_size | |||||
| return ds | |||||
| def extract_features(net, dataset_path, config): | def extract_features(net, dataset_path, config): | ||||
| @@ -112,12 +107,16 @@ def extract_features(net, dataset_path, config): | |||||
| config=config, | config=config, | ||||
| repeat_num=1) | repeat_num=1) | ||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| if step_size == 0: | |||||
| raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \ | |||||
| than batch_size in config.py") | |||||
| model = Model(net) | model = Model(net) | ||||
| for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)): | for i, data in enumerate(dataset.create_dict_iterator(output_numpy=True)): | ||||
| features_path = os.path.join(features_folder, f"feature_{i}.npy") | features_path = os.path.join(features_folder, f"feature_{i}.npy") | ||||
| label_path = os.path.join(features_folder, f"label_{i}.npy") | label_path = os.path.join(features_folder, f"label_{i}.npy") | ||||
| if not os.path.exists(features_path or not os.path.exists(label_path)): | |||||
| if not os.path.exists(features_path) or not os.path.exists(label_path): | |||||
| image = data["image"] | image = data["image"] | ||||
| label = data["label"] | label = data["label"] | ||||
| features = model.predict(Tensor(image)) | features = model.predict(Tensor(image)) | ||||
| @@ -284,8 +284,11 @@ class MobileNetV2(nn.Cell): | |||||
| MobileNetV2 architecture. | MobileNetV2 architecture. | ||||
| Args: | Args: | ||||
| backbone(nn.Cell): | |||||
| head(nn.Cell): | |||||
| class_num (Cell): number of classes. | |||||
| width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. | |||||
| has_dropout (bool): Is dropout used. Default is false | |||||
| inverted_residual_setting (list): Inverted residual settings. Default is None | |||||
| round_nearest (list): Channel round to . Default is 8 | |||||
| Returns: | Returns: | ||||
| Tensor, output tensor. | Tensor, output tensor. | ||||
| @@ -310,14 +313,11 @@ class MobileNetV2(nn.Cell): | |||||
| class MobileNetV2Combine(nn.Cell): | class MobileNetV2Combine(nn.Cell): | ||||
| """ | """ | ||||
| MobileNetV2 architecture. | |||||
| MobileNetV2Combine architecture. | |||||
| Args: | Args: | ||||
| class_num (Cell): number of classes. | |||||
| width_mult (int): Channels multiplier for round to 8/16 and others. Default is 1. | |||||
| has_dropout (bool): Is dropout used. Default is false | |||||
| inverted_residual_setting (list): Inverted residual settings. Default is None | |||||
| round_nearest (list): Channel round to . Default is 8 | |||||
| backbone (Cell): the features extract layers. | |||||
| head (Cell): the fully connected layers. | |||||
| Returns: | Returns: | ||||
| Tensor, output tensor. | Tensor, output tensor. | ||||
| @@ -326,7 +326,7 @@ class MobileNetV2Combine(nn.Cell): | |||||
| """ | """ | ||||
| def __init__(self, backbone, head): | def __init__(self, backbone, head): | ||||
| super(MobileNetV2Combine, self).__init__() | |||||
| super(MobileNetV2Combine, self).__init__(auto_prefix=False) | |||||
| self.backbone = backbone | self.backbone = backbone | ||||
| self.head = head | self.head = head | ||||
| @@ -119,20 +119,9 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): | |||||
| for param in network.get_parameters(): | for param in network.get_parameters(): | ||||
| param.requires_grad = False | param.requires_grad = False | ||||
| def define_net(args, config): | |||||
| def define_net(config): | |||||
| backbone_net = MobileNetV2Backbone() | backbone_net = MobileNetV2Backbone() | ||||
| head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) | head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) | ||||
| net = mobilenet_v2(backbone_net, head_net) | net = mobilenet_v2(backbone_net, head_net) | ||||
| # load the ckpt file to the network for fine tune or incremental leaning | |||||
| if args.pretrain_ckpt: | |||||
| if args.train_method == "fine_tune": | |||||
| load_ckpt(net, args.pretrain_ckpt) | |||||
| elif args.train_method == "incremental_learn": | |||||
| load_ckpt(backbone_net, args.pretrain_ckpt, trainable=False) | |||||
| elif args.train_method == "train": | |||||
| pass | |||||
| else: | |||||
| raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None") | |||||
| return backbone_net, head_net, net | return backbone_net, head_net, net | ||||
| @@ -35,7 +35,7 @@ from src.config import set_config | |||||
| from src.args import train_parse_args | from src.args import train_parse_args | ||||
| from src.utils import context_device_init, switch_precision, config_ckpoint | from src.utils import context_device_init, switch_precision, config_ckpoint | ||||
| from src.models import CrossEntropyWithLabelSmooth, define_net | |||||
| from src.models import CrossEntropyWithLabelSmooth, define_net, load_ckpt | |||||
| set_seed(1) | set_seed(1) | ||||
| @@ -50,7 +50,18 @@ if __name__ == '__main__': | |||||
| context_device_init(config) | context_device_init(config) | ||||
| # define network | # define network | ||||
| backbone_net, head_net, net = define_net(args_opt, config) | |||||
| backbone_net, head_net, net = define_net(config) | |||||
| # load the ckpt file to the network for fine tune or incremental leaning | |||||
| if args_opt.pretrain_ckpt: | |||||
| if args_opt.train_method == "fine_tune": | |||||
| load_ckpt(net, args_opt.pretrain_ckpt) | |||||
| elif args_opt.train_method == "incremental_learn": | |||||
| load_ckpt(backbone_net, args_opt.pretrain_ckpt, trainable=False) | |||||
| elif args_opt.train_method == "train": | |||||
| pass | |||||
| else: | |||||
| raise ValueError("must input the usage of pretrain_ckpt when the pretrain_ckpt isn't None") | |||||
| # CPU only support "incremental_learn" | # CPU only support "incremental_learn" | ||||
| if args_opt.train_method == "incremental_learn": | if args_opt.train_method == "incremental_learn": | ||||
| @@ -60,7 +71,11 @@ if __name__ == '__main__': | |||||
| elif args_opt.train_method in ("train", "fine_tune"): | elif args_opt.train_method in ("train", "fine_tune"): | ||||
| if args_opt.platform == "CPU": | if args_opt.platform == "CPU": | ||||
| raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".") | raise ValueError("Currently, CPU only support \"incremental_learn\", not \"fine_tune\" or \"train\".") | ||||
| dataset, step_size = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config) | |||||
| dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, config=config) | |||||
| step_size = dataset.get_dataset_size() | |||||
| if step_size == 0: | |||||
| raise ValueError("The step_size of dataset is zero. Check if the images count of train dataset is more \ | |||||
| than batch_size in config.py") | |||||
| # Currently, only Ascend support switch precision. | # Currently, only Ascend support switch precision. | ||||
| switch_precision(net, mstype.float16, config) | switch_precision(net, mstype.float16, config) | ||||