|
|
@@ -14,7 +14,6 @@ |
|
|
# ============================================================================ |
|
|
# ============================================================================ |
|
|
"""VGG.""" |
|
|
"""VGG.""" |
|
|
import mindspore.nn as nn |
|
|
import mindspore.nn as nn |
|
|
from mindspore.ops import operations as P |
|
|
|
|
|
from mindspore.common.initializer import initializer |
|
|
from mindspore.common.initializer import initializer |
|
|
import mindspore.common.dtype as mstype |
|
|
import mindspore.common.dtype as mstype |
|
|
|
|
|
|
|
|
@@ -63,8 +62,7 @@ class Vgg(nn.Cell): |
|
|
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): |
|
|
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1): |
|
|
super(Vgg, self).__init__() |
|
|
super(Vgg, self).__init__() |
|
|
self.layers = _make_layer(base, batch_norm=batch_norm) |
|
|
self.layers = _make_layer(base, batch_norm=batch_norm) |
|
|
self.reshape = P.Reshape() |
|
|
|
|
|
self.shp = (batch_size, -1) |
|
|
|
|
|
|
|
|
self.flatten = nn.Flatten() |
|
|
self.classifier = nn.SequentialCell([ |
|
|
self.classifier = nn.SequentialCell([ |
|
|
nn.Dense(512 * 7 * 7, 4096), |
|
|
nn.Dense(512 * 7 * 7, 4096), |
|
|
nn.ReLU(), |
|
|
nn.ReLU(), |
|
|
@@ -74,7 +72,7 @@ class Vgg(nn.Cell): |
|
|
|
|
|
|
|
|
def construct(self, x): |
|
|
def construct(self, x): |
|
|
x = self.layers(x) |
|
|
x = self.layers(x) |
|
|
x = self.reshape(x, self.shp) |
|
|
|
|
|
|
|
|
x = self.flatten(x) |
|
|
x = self.classifier(x) |
|
|
x = self.classifier(x) |
|
|
return x |
|
|
return x |
|
|
|
|
|
|
|
|
@@ -87,20 +85,19 @@ cfg = { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vgg16(batch_size=1, num_classes=1000): |
|
|
|
|
|
|
|
|
def vgg16(num_classes=1000): |
|
|
""" |
|
|
""" |
|
|
Get Vgg16 neural network with batch normalization. |
|
|
Get Vgg16 neural network with batch normalization. |
|
|
|
|
|
|
|
|
Args: |
|
|
Args: |
|
|
batch_size (int): Batch size. Default: 1. |
|
|
|
|
|
num_classes (int): Class numbers. Default: 1000. |
|
|
num_classes (int): Class numbers. Default: 1000. |
|
|
|
|
|
|
|
|
Returns: |
|
|
Returns: |
|
|
Cell, cell instance of Vgg16 neural network with batch normalization. |
|
|
Cell, cell instance of Vgg16 neural network with batch normalization. |
|
|
|
|
|
|
|
|
Examples: |
|
|
Examples: |
|
|
>>> vgg16(batch_size=1, num_classes=1000) |
|
|
|
|
|
|
|
|
>>> vgg16(num_classes=1000) |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size) |
|
|
|
|
|
|
|
|
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True) |
|
|
return net |
|
|
return net |