diff --git a/model_zoo/official/cv/vgg16/mindspore_hub_conf.py b/model_zoo/official/cv/vgg16/mindspore_hub_conf.py
new file mode 100644
index 0000000000..72af0e8fb7
--- /dev/null
+++ b/model_zoo/official/cv/vgg16/mindspore_hub_conf.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""hub config."""
+from src.vgg import vgg16 as VGG16
+
+
+def vgg16(*args, **kwargs):
+    return VGG16(*args, **kwargs)
+
+
+def create_network(name, *args, **kwargs):
+    if name == "vgg16":
+        return vgg16(*args, **kwargs)
+    raise NotImplementedError(f"{name} is not implemented in the repo")
diff --git a/model_zoo/official/cv/vgg16/src/vgg.py b/model_zoo/official/cv/vgg16/src/vgg.py
index 51c27e4162..25a46a2e82 100644
--- a/model_zoo/official/cv/vgg16/src/vgg.py
+++ b/model_zoo/official/cv/vgg16/src/vgg.py
@@ -60,6 +60,7 @@ class Vgg(nn.Cell):
         num_classes (int): Class numbers. Default: 1000.
         batch_norm (bool): Whether to do the batchnorm. Default: False.
         batch_size (int): Batch size. Default: 1.
+        include_top(bool): Whether to include the 3 fully-connected layers at the top of the network. Default: True.
 
     Returns:
         Tensor, infer output tensor.
@@ -69,10 +70,12 @@
         >>> num_classes=1000, batch_norm=False, batch_size=1)
     """
 
-    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
+    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train",
+                 include_top=True):
         super(Vgg, self).__init__()
         _ = batch_size
         self.layers = _make_layer(base, args, batch_norm=batch_norm)
+        self.include_top = include_top
         self.flatten = nn.Flatten()
         dropout_ratio = 0.5
         if not args.has_dropout or phase == "test":
@@ -91,8 +94,9 @@
 
     def construct(self, x):
         x = self.layers(x)
-        x = self.flatten(x)
-        x = self.classifier(x)
+        if self.include_top:
+            x = self.flatten(x)
+            x = self.classifier(x)
         return x
 
     def custom_init_weight(self):
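
Usage note (editor's sketch, not part of the diff): MindSpore Hub discovers models through the create_network hook that the new mindspore_hub_conf.py exposes. The snippet below only illustrates that dispatch by calling the hook directly; it assumes it is run from model_zoo/official/cv/vgg16 so that src.vgg is importable, and it passes no extra arguments because which keyword arguments src.vgg.vgg16 accepts is not shown in this diff.

# Hypothetical usage sketch -- assumes the working directory is
# model_zoo/official/cv/vgg16; nothing below is added by this PR.
from mindspore_hub_conf import create_network

net = create_network("vgg16")   # dispatches to src.vgg.vgg16; extra args would be forwarded unchanged
# Any other name falls through to the error branch:
# create_network("resnet101")  ->  raises NotImplementedError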