|
|
|
@@ -81,7 +81,7 @@ class GoogleNet(nn.Cell): |
|
|
|
Googlenet architecture |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, num_classes): |
|
|
|
def __init__(self, num_classes, include_top=True): |
|
|
|
super(GoogleNet, self).__init__() |
|
|
|
self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0) |
|
|
|
self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") |
|
|
|
@@ -104,11 +104,13 @@ class GoogleNet(nn.Cell): |
|
|
|
self.block5a = Inception(832, 256, 160, 320, 32, 128, 128) |
|
|
|
self.block5b = Inception(832, 384, 192, 384, 48, 128, 128) |
|
|
|
|
|
|
|
self.mean = P.ReduceMean(keep_dims=True) |
|
|
|
self.dropout = nn.Dropout(keep_prob=0.8) |
|
|
|
self.flatten = nn.Flatten() |
|
|
|
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(), |
|
|
|
bias_init=weight_variable()) |
|
|
|
self.include_top = include_top |
|
|
|
if self.include_top: |
|
|
|
self.mean = P.ReduceMean(keep_dims=True) |
|
|
|
self.flatten = nn.Flatten() |
|
|
|
self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(), |
|
|
|
bias_init=weight_variable()) |
|
|
|
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
@@ -133,6 +135,8 @@ class GoogleNet(nn.Cell): |
|
|
|
|
|
|
|
x = self.block5a(x) |
|
|
|
x = self.block5b(x) |
|
|
|
if not self.include_top: |
|
|
|
return x |
|
|
|
|
|
|
|
x = self.mean(x, (2, 3)) |
|
|
|
x = self.flatten(x) |
|
|
|
|