Browse Source

add googlenet include top for hub

tags/v1.0.0
huzhifeng 5 years ago
parent
commit
60759e8522
1 changed files with 9 additions and 5 deletions
  1. +9
    -5
      model_zoo/official/cv/googlenet/src/googlenet.py

+ 9
- 5
model_zoo/official/cv/googlenet/src/googlenet.py View File

@@ -81,7 +81,7 @@ class GoogleNet(nn.Cell):
Googlenet architecture
"""

def __init__(self, num_classes):
def __init__(self, num_classes, include_top=True):
super(GoogleNet, self).__init__()
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
@@ -104,11 +104,13 @@ class GoogleNet(nn.Cell):
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)

self.mean = P.ReduceMean(keep_dims=True)
self.dropout = nn.Dropout(keep_prob=0.8)
self.flatten = nn.Flatten()
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
bias_init=weight_variable())
self.include_top = include_top
if self.include_top:
self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
bias_init=weight_variable())


def construct(self, x):
@@ -133,6 +135,8 @@ class GoogleNet(nn.Cell):

x = self.block5a(x)
x = self.block5b(x)
if not self.include_top:
return x

x = self.mean(x, (2, 3))
x = self.flatten(x)


Loading…
Cancel
Save