| @@ -81,7 +81,7 @@ class GoogleNet(nn.Cell): | |||||
| Googlenet architecture | Googlenet architecture | ||||
| """ | """ | ||||
| def __init__(self, num_classes): | |||||
| def __init__(self, num_classes, include_top=True): | |||||
| super(GoogleNet, self).__init__() | super(GoogleNet, self).__init__() | ||||
| self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0) | self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0) | ||||
| self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") | ||||
| @@ -104,11 +104,13 @@ class GoogleNet(nn.Cell): | |||||
| self.block5a = Inception(832, 256, 160, 320, 32, 128, 128) | self.block5a = Inception(832, 256, 160, 320, 32, 128, 128) | ||||
| self.block5b = Inception(832, 384, 192, 384, 48, 128, 128) | self.block5b = Inception(832, 384, 192, 384, 48, 128, 128) | ||||
| self.mean = P.ReduceMean(keep_dims=True) | |||||
| self.dropout = nn.Dropout(keep_prob=0.8) | self.dropout = nn.Dropout(keep_prob=0.8) | ||||
| self.flatten = nn.Flatten() | |||||
| self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(), | |||||
| bias_init=weight_variable()) | |||||
| self.include_top = include_top | |||||
| if self.include_top: | |||||
| self.mean = P.ReduceMean(keep_dims=True) | |||||
| self.flatten = nn.Flatten() | |||||
| self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(), | |||||
| bias_init=weight_variable()) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| @@ -133,6 +135,8 @@ class GoogleNet(nn.Cell): | |||||
| x = self.block5a(x) | x = self.block5a(x) | ||||
| x = self.block5b(x) | x = self.block5b(x) | ||||
| if not self.include_top: | |||||
| return x | |||||
| x = self.mean(x, (2, 3)) | x = self.mean(x, (2, 3)) | ||||
| x = self.flatten(x) | x = self.flatten(x) | ||||