| @@ -14,27 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """LeNet.""" | """LeNet.""" | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): | |||||
| """weight initial for conv layer""" | |||||
| weight = weight_variable() | |||||
| return nn.Conv2d(in_channels, out_channels, | |||||
| kernel_size=kernel_size, stride=stride, padding=padding, | |||||
| weight_init=weight, has_bias=False, pad_mode="valid") | |||||
| def fc_with_initialize(input_channels, out_channels): | |||||
| """weight initial for fc layer""" | |||||
| weight = weight_variable() | |||||
| bias = weight_variable() | |||||
| return nn.Dense(input_channels, out_channels, weight, bias) | |||||
| def weight_variable(): | |||||
| """weight initial""" | |||||
| return TruncatedNormal(0.02) | |||||
| class LeNet5(nn.Cell): | class LeNet5(nn.Cell): | ||||
| @@ -43,6 +22,7 @@ class LeNet5(nn.Cell): | |||||
| Args: | Args: | ||||
| num_class (int): Num classes. Default: 10. | num_class (int): Num classes. Default: 10. | ||||
| channel (int): Num classes. Default: 1. | |||||
| Returns: | Returns: | ||||
| Tensor, output tensor | Tensor, output tensor | ||||
| @@ -53,26 +33,20 @@ class LeNet5(nn.Cell): | |||||
| def __init__(self, num_class=10, channel=1): | def __init__(self, num_class=10, channel=1): | ||||
| super(LeNet5, self).__init__() | super(LeNet5, self).__init__() | ||||
| self.num_class = num_class | self.num_class = num_class | ||||
| self.conv1 = conv(channel, 6, 5) | |||||
| self.conv2 = conv(6, 16, 5) | |||||
| self.fc1 = fc_with_initialize(16 * 5 * 5, 120) | |||||
| self.fc2 = fc_with_initialize(120, 84) | |||||
| self.fc3 = fc_with_initialize(84, self.num_class) | |||||
| self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') | |||||
| self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') | |||||
| self.fc1 = nn.Dense(16 * 5 * 5, 120) | |||||
| self.fc2 = nn.Dense(120, 84) | |||||
| self.fc3 = nn.Dense(84, self.num_class) | |||||
| self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | ||||
| self.flatten = nn.Flatten() | self.flatten = nn.Flatten() | ||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.conv1(x) | |||||
| x = self.relu(x) | |||||
| x = self.max_pool2d(x) | |||||
| x = self.conv2(x) | |||||
| x = self.relu(x) | |||||
| x = self.max_pool2d(x) | |||||
| x = self.max_pool2d(self.relu(self.conv1(x))) | |||||
| x = self.max_pool2d(self.relu(self.conv2(x))) | |||||
| x = self.flatten(x) | x = self.flatten(x) | ||||
| x = self.fc1(x) | |||||
| x = self.relu(x) | |||||
| x = self.fc2(x) | |||||
| x = self.relu(x) | |||||
| x = self.relu(self.fc1(x)) | |||||
| x = self.relu(self.fc2(x)) | |||||
| x = self.fc3(x) | x = self.fc3(x) | ||||
| return x | return x | ||||