|
|
|
@@ -119,9 +119,11 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): |
|
|
|
for param in network.get_parameters(): |
|
|
|
param.requires_grad = False |
|
|
|
|
|
|
|
def define_net(config): |
|
|
|
def define_net(config, is_training): |
|
|
|
backbone_net = MobileNetV2Backbone() |
|
|
|
activation = config.activation if not args.is_training else "None" |
|
|
|
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) |
|
|
|
net = mobilenet_v2(backbone_net, head_net, activation=activation) |
|
|
|
activation = config.activation if not is_training else "None" |
|
|
|
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, |
|
|
|
num_classes=config.num_classes, |
|
|
|
activation=activation) |
|
|
|
net = mobilenet_v2(backbone_net, head_net) |
|
|
|
return backbone_net, head_net, net |