Merge pull request !6636 from zhaoting/hubtags/v1.0.0
| @@ -28,7 +28,7 @@ from src.utils import switch_precision, set_context | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| args_opt = eval_parse_args() | args_opt = eval_parse_args() | ||||
| config = set_config(args_opt) | config = set_config(args_opt) | ||||
| backbone_net, head_net, net = define_net(config) | |||||
| backbone_net, head_net, net = define_net(config, args_opt.is_training) | |||||
| #load the trained checkpoint file to the net for evaluation | #load the trained checkpoint file to the net for evaluation | ||||
| if args_opt.head_ckpt: | if args_opt.head_ckpt: | ||||
| @@ -119,9 +119,11 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): | |||||
| for param in network.get_parameters(): | for param in network.get_parameters(): | ||||
| param.requires_grad = False | param.requires_grad = False | ||||
| def define_net(config): | |||||
| def define_net(config, is_training): | |||||
| backbone_net = MobileNetV2Backbone() | backbone_net = MobileNetV2Backbone() | ||||
| activation = config.activation if not args.is_training else "None" | |||||
| head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) | |||||
| net = mobilenet_v2(backbone_net, head_net, activation=activation) | |||||
| activation = config.activation if not is_training else "None" | |||||
| head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, | |||||
| num_classes=config.num_classes, | |||||
| activation=activation) | |||||
| net = mobilenet_v2(backbone_net, head_net) | |||||
| return backbone_net, head_net, net | return backbone_net, head_net, net | ||||
| @@ -51,7 +51,7 @@ if __name__ == '__main__': | |||||
| context_device_init(config) | context_device_init(config) | ||||
| # define network | # define network | ||||
| backbone_net, head_net, net = define_net(config) | |||||
| backbone_net, head_net, net = define_net(config, args_opt.is_training) | |||||
| # load the ckpt file to the network for fine tune or incremental leaning | # load the ckpt file to the network for fine tune or incremental leaning | ||||
| if args_opt.pretrain_ckpt: | if args_opt.pretrain_ckpt: | ||||