|
|
|
@@ -203,7 +203,7 @@ class AuxLogits(nn.Cell): |
|
|
|
|
|
|
|
|
|
|
|
class InceptionV3(nn.Cell): |
|
|
|
def __init__(self, num_classes=10, is_training=True, has_bias=False, dropout_keep_prob=0.8): |
|
|
|
def __init__(self, num_classes=10, is_training=True, has_bias=False, dropout_keep_prob=0.8, include_top=True): |
|
|
|
super(InceptionV3, self).__init__() |
|
|
|
self.is_training = is_training |
|
|
|
self.Conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2, pad_mode='valid', has_bias=has_bias) |
|
|
|
@@ -226,7 +226,9 @@ class InceptionV3(nn.Cell): |
|
|
|
self.Mixed_7c = Inception_E(2048, has_bias=has_bias) |
|
|
|
if is_training: |
|
|
|
self.aux_logits = AuxLogits(768, num_classes) |
|
|
|
self.logits = Logits(num_classes, dropout_keep_prob) |
|
|
|
self.include_top = include_top |
|
|
|
if self.include_top: |
|
|
|
self.logits = Logits(num_classes, dropout_keep_prob) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.Conv2d_1a(x) |
|
|
|
@@ -251,6 +253,8 @@ class InceptionV3(nn.Cell): |
|
|
|
x = self.Mixed_7a(x) |
|
|
|
x = self.Mixed_7b(x) |
|
|
|
x = self.Mixed_7c(x) |
|
|
|
if not self.include_top: |
|
|
|
return x |
|
|
|
logits = self.logits(x) |
|
|
|
if self.is_training: |
|
|
|
return logits, aux_logits |
|
|
|
|