| @@ -14,6 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """LeNet.""" | """LeNet.""" | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.common.initializer import Normal | |||||
| class LeNet5(nn.Cell): | class LeNet5(nn.Cell): | ||||
| @@ -22,7 +23,7 @@ class LeNet5(nn.Cell): | |||||
| Args: | Args: | ||||
| num_class (int): Num classes. Default: 10. | num_class (int): Num classes. Default: 10. | ||||
| channel (int): Num classes. Default: 1. | |||||
| num_channel (int): Num channels. Default: 1. | |||||
| Returns: | Returns: | ||||
| Tensor, output tensor | Tensor, output tensor | ||||
| @@ -30,14 +31,13 @@ class LeNet5(nn.Cell): | |||||
| >>> LeNet(num_class=10) | >>> LeNet(num_class=10) | ||||
| """ | """ | ||||
| def __init__(self, num_class=10, channel=1): | |||||
| def __init__(self, num_class=10, num_channel=1): | |||||
| super(LeNet5, self).__init__() | super(LeNet5, self).__init__() | ||||
| self.num_class = num_class | |||||
| self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') | |||||
| self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') | |||||
| self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') | self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') | ||||
| self.fc1 = nn.Dense(16 * 5 * 5, 120) | |||||
| self.fc2 = nn.Dense(120, 84) | |||||
| self.fc3 = nn.Dense(84, self.num_class) | |||||
| self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) | |||||
| self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) | |||||
| self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) | |||||
| self.relu = nn.ReLU() | self.relu = nn.ReLU() | ||||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | ||||
| self.flatten = nn.Flatten() | self.flatten = nn.Flatten() | ||||