|
|
@@ -60,6 +60,7 @@ class Vgg(nn.Cell): |
|
|
num_classes (int): Class numbers. Default: 1000. |
|
|
num_classes (int): Class numbers. Default: 1000. |
|
|
batch_norm (bool): Whether to do the batchnorm. Default: False. |
|
|
batch_norm (bool): Whether to do the batchnorm. Default: False. |
|
|
batch_size (int): Batch size. Default: 1. |
|
|
batch_size (int): Batch size. Default: 1. |
|
|
|
|
|
include_top(bool): Whether to include the 3 fully-connected layers at the top of the network. Default: True. |
|
|
|
|
|
|
|
|
Returns: |
|
|
Returns: |
|
|
Tensor, infer output tensor. |
|
|
Tensor, infer output tensor. |
|
|
@@ -69,10 +70,12 @@ class Vgg(nn.Cell): |
|
|
>>> num_classes=1000, batch_norm=False, batch_size=1) |
|
|
>>> num_classes=1000, batch_norm=False, batch_size=1) |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"): |
|
|
|
|
|
|
|
|
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train", |
|
|
|
|
|
include_top=True): |
|
|
super(Vgg, self).__init__() |
|
|
super(Vgg, self).__init__() |
|
|
_ = batch_size |
|
|
_ = batch_size |
|
|
self.layers = _make_layer(base, args, batch_norm=batch_norm) |
|
|
self.layers = _make_layer(base, args, batch_norm=batch_norm) |
|
|
|
|
|
self.include_top = include_top |
|
|
self.flatten = nn.Flatten() |
|
|
self.flatten = nn.Flatten() |
|
|
dropout_ratio = 0.5 |
|
|
dropout_ratio = 0.5 |
|
|
if not args.has_dropout or phase == "test": |
|
|
if not args.has_dropout or phase == "test": |
|
|
@@ -91,8 +94,9 @@ class Vgg(nn.Cell): |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
x = self.layers(x) |
|
|
x = self.layers(x) |
|
|
x = self.flatten(x) |
|
|
|
|
|
x = self.classifier(x) |
|
|
|
|
|
|
|
|
if self.include_top: |
|
|
|
|
|
x = self.flatten(x) |
|
|
|
|
|
x = self.classifier(x) |
|
|
return x |
|
|
return x |
|
|
|
|
|
|
|
|
def custom_init_weight(self): |
|
|
def custom_init_weight(self): |
|
|
|