
!6688 vgg16 hub support

Merge pull request !6688 from caojian05/ms_master_vgg16_hub
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
parent commit f7150a6dd8
2 changed files with 33 additions and 3 deletions
  1. model_zoo/official/cv/vgg16/mindspore_hub_conf.py (+26, -0)
  2. model_zoo/official/cv/vgg16/src/vgg.py (+7, -3)

model_zoo/official/cv/vgg16/mindspore_hub_conf.py (+26, -0)

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.vgg import vgg16 as VGG16


def vgg16(*args, **kwargs):
    return VGG16(*args, **kwargs)


def create_network(name, *args, **kwargs):
    if name == "vgg16":
        return vgg16(*args, **kwargs)
    raise NotImplementedError(f"{name} is not implemented in the repo")
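
For reference, the new entry point can be exercised directly. A minimal sketch, assuming it runs from the vgg16 model directory so that src.vgg resolves, and that the underlying vgg16 factory accepts a num_classes keyword (an assumption, not shown in this diff):

from mindspore_hub_conf import create_network

# Build VGG16 through the hub entry point; positional and keyword arguments
# are forwarded untouched to src.vgg.vgg16. num_classes here is an assumed
# keyword of that factory.
net = create_network("vgg16", num_classes=1000)

# Any other name fails loudly instead of silently returning None.
try:
    create_network("resnet50")
except NotImplementedError as err:
    print(err)  # resnet50 is not implemented in the repo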

model_zoo/official/cv/vgg16/src/vgg.py (+7, -3)

@@ -60,6 +60,7 @@ class Vgg(nn.Cell):
         num_classes (int): Class numbers. Default: 1000.
         batch_norm (bool): Whether to do the batchnorm. Default: False.
         batch_size (int): Batch size. Default: 1.
+        include_top(bool): Whether to include the 3 fully-connected layers at the top of the network. Default: True.

     Returns:
         Tensor, infer output tensor.
@@ -69,10 +70,12 @@ class Vgg(nn.Cell):
     >>>         num_classes=1000, batch_norm=False, batch_size=1)
     """

-    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
+    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train",
+                 include_top=True):
         super(Vgg, self).__init__()
         _ = batch_size
         self.layers = _make_layer(base, args, batch_norm=batch_norm)
+        self.include_top = include_top
         self.flatten = nn.Flatten()
         dropout_ratio = 0.5
         if not args.has_dropout or phase == "test":
@@ -91,8 +94,9 @@ class Vgg(nn.Cell):

     def construct(self, x):
         x = self.layers(x)
-        x = self.flatten(x)
-        x = self.classifier(x)
+        if self.include_top:
+            x = self.flatten(x)
+            x = self.classifier(x)
         return x

     def custom_init_weight(self):
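
The effect of the new flag: with include_top=False the classifier head is skipped and construct returns the raw convolutional feature map, turning the network into a backbone. A minimal sketch; cfg and base below are illustrative stand-ins (this diff only shows the constructor reading args.has_dropout, and _make_layer in the full file may require further config fields), not the repo's real VGG16 configuration:

from types import SimpleNamespace
import numpy as np
from mindspore import Tensor
from src.vgg import Vgg

# Hypothetical config: only has_dropout is visible in this diff; the real
# parsed config carries more fields that _make_layer may consume.
cfg = SimpleNamespace(has_dropout=False)

# Illustrative two-block layer spec ("M" = max pool); the real VGG16 spec
# lists all 13 convolutional layers.
base = [64, "M", 128, "M"]

backbone = Vgg(base, num_classes=10, args=cfg, phase="test", include_top=False)
features = backbone(Tensor(np.ones((1, 3, 224, 224), np.float32)))
# include_top=False: flatten and classifier are skipped, so `features`
# keeps its spatial dimensions.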

