|
|
|
@@ -23,36 +23,22 @@ def _weight_variable(shape, factor=0.01): |
|
|
|
init_value = np.random.randn(*shape).astype(np.float32) * factor |
|
|
|
return Tensor(init_value) |
|
|
|
|
|
|
|
def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=False): |
|
|
|
"""Batchnorm2D wrapper.""" |
|
|
|
gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32)) |
|
|
|
beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32)) |
|
|
|
moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32)) |
|
|
|
moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32)) |
|
|
|
|
|
|
|
return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init, |
|
|
|
beta_init=beta_init, moving_mean_init=moving_mean_init, |
|
|
|
moving_var_init=moving_var_init, use_batch_statistics=use_batch_statistics) |
|
|
|
|
|
|
|
def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad', weights_update=True): |
|
|
|
"""Conv2D wrapper.""" |
|
|
|
weights = 'ones' |
|
|
|
layers = [] |
|
|
|
conv = nn.Conv2d(in_channels, out_channels, |
|
|
|
kernel_size=kernel_size, stride=stride, padding=padding, |
|
|
|
pad_mode=pad_mode, weight_init=weights, has_bias=False) |
|
|
|
pad_mode=pad_mode, has_bias=False) |
|
|
|
if not weights_update: |
|
|
|
conv.weight.requires_grad = False |
|
|
|
layers += [conv] |
|
|
|
layers += [_BatchNorm2dInit(out_channels)] |
|
|
|
layers += [nn.BatchNorm2d(out_channels)] |
|
|
|
return nn.SequentialCell(layers) |
|
|
|
|
|
|
|
|
|
|
|
def _fc(in_channels, out_channels): |
|
|
|
'''full connection layer''' |
|
|
|
weight = _weight_variable((out_channels, in_channels)) |
|
|
|
bias = _weight_variable((out_channels,)) |
|
|
|
return nn.Dense(in_channels, out_channels, weight, bias) |
|
|
|
return nn.Dense(in_channels, out_channels) |
|
|
|
|
|
|
|
|
|
|
|
class VGG16FeatureExtraction(nn.Cell): |
|
|
|
@@ -141,36 +127,38 @@ class VGG16Classfier(nn.Cell): |
|
|
|
self.relu = nn.ReLU() |
|
|
|
self.fc1 = _fc(in_channels=7*7*512, out_channels=4096) |
|
|
|
self.fc2 = _fc(in_channels=4096, out_channels=4096) |
|
|
|
self.batch_size = 32 |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.dropout = nn.Dropout(0.5) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
""" |
|
|
|
:param x: shape=(B, 512, 7, 7) |
|
|
|
:return: |
|
|
|
""" |
|
|
|
x = self.reshape(x, (self.batch_size, 7*7*512)) |
|
|
|
x = self.reshape(x, (-1, 7*7*512)) |
|
|
|
x = self.fc1(x) |
|
|
|
x = self.relu(x) |
|
|
|
x = self.dropout(x) |
|
|
|
x = self.fc2(x) |
|
|
|
x = self.relu(x) |
|
|
|
x = self.dropout(x) |
|
|
|
return x |
|
|
|
|
|
|
|
class VGG16(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
def __init__(self, num_classes): |
|
|
|
"""VGG16 construct for training backbone""" |
|
|
|
super(VGG16, self).__init__() |
|
|
|
self.feature_extraction = VGG16FeatureExtraction(weights_update=True) |
|
|
|
self.vgg16_feature_extractor = VGG16FeatureExtraction(weights_update=True) |
|
|
|
self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) |
|
|
|
self.classifier = VGG16Classfier() |
|
|
|
self.fc3 = _fc(in_channels=4096, out_channels=1000) |
|
|
|
self.fc3 = _fc(in_channels=4096, out_channels=num_classes) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
""" |
|
|
|
:param x: shape=(B, 3, 224, 224) |
|
|
|
:return: logits, shape=(B, 1000) |
|
|
|
""" |
|
|
|
feature_maps = self.feature_extraction(x) |
|
|
|
feature_maps = self.vgg16_feature_extractor(x) |
|
|
|
x = self.max_pool(feature_maps) |
|
|
|
x = self.classifier(x) |
|
|
|
x = self.fc3(x) |
|
|
|
|