Browse Source

!6636 fix mobilenetv2 script error

Merge pull request !6636 from zhaoting/hub
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
0037afee74
3 changed files with 8 additions and 6 deletions
  1. +1
    -1
      model_zoo/official/cv/mobilenetv2/eval.py
  2. +6
    -4
      model_zoo/official/cv/mobilenetv2/src/models.py
  3. +1
    -1
      model_zoo/official/cv/mobilenetv2/train.py

+ 1
- 1
model_zoo/official/cv/mobilenetv2/eval.py View File

@@ -28,7 +28,7 @@ from src.utils import switch_precision, set_context
if __name__ == '__main__': if __name__ == '__main__':
args_opt = eval_parse_args() args_opt = eval_parse_args()
config = set_config(args_opt) config = set_config(args_opt)
backbone_net, head_net, net = define_net(config)
backbone_net, head_net, net = define_net(config, args_opt.is_training)


#load the trained checkpoint file to the net for evaluation #load the trained checkpoint file to the net for evaluation
if args_opt.head_ckpt: if args_opt.head_ckpt:


+ 6
- 4
model_zoo/official/cv/mobilenetv2/src/models.py View File

@@ -119,9 +119,11 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
for param in network.get_parameters(): for param in network.get_parameters():
param.requires_grad = False param.requires_grad = False


def define_net(config):
def define_net(config, is_training):
backbone_net = MobileNetV2Backbone() backbone_net = MobileNetV2Backbone()
activation = config.activation if not args.is_training else "None"
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(backbone_net, head_net, activation=activation)
activation = config.activation if not is_training else "None"
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,
num_classes=config.num_classes,
activation=activation)
net = mobilenet_v2(backbone_net, head_net)
return backbone_net, head_net, net return backbone_net, head_net, net

+ 1
- 1
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -51,7 +51,7 @@ if __name__ == '__main__':
context_device_init(config) context_device_init(config)


# define network # define network
backbone_net, head_net, net = define_net(config)
backbone_net, head_net, net = define_net(config, args_opt.is_training)


# load the ckpt file to the network for fine tune or incremental leaning # load the ckpt file to the network for fine tune or incremental leaning
if args_opt.pretrain_ckpt: if args_opt.pretrain_ckpt:


Loading…
Cancel
Save