@@ -15,6 +15,7 @@
 """Alexnet."""
 import mindspore.nn as nn
 from mindspore.common.initializer import TruncatedNormal
+from mindspore.ops import operations as P
 
 def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"):
     weight = weight_variable()
@@ -44,7 +45,7 @@ class AlexNet(nn.Cell):
         self.conv4 = conv(384, 384, 3, pad_mode="same")
         self.conv5 = conv(384, 256, 3, pad_mode="same")
         self.relu = nn.ReLU()
-        self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
+        self.max_pool2d = P.MaxPool(ksize=3, strides=2)
         self.flatten = nn.Flatten()
         self.fc1 = fc_with_initialize(6*6*256, 4096)
         self.fc2 = fc_with_initialize(4096, 4096)
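For context, this change swaps the `nn.MaxPool2d` Cell for the lower-level `P.MaxPool` primitive with the same 3x3 window and stride 2. Below is a minimal sketch, assuming an older MindSpore release where `P.MaxPool` still takes `ksize`/`strides` (as in this diff) and a backend with MaxPool kernels available in PyNative mode, that checks both variants yield the same output shape; the input shape is illustrative.

```python
# Minimal sketch, not part of the diff: compare the nn.Cell wrapper with the
# primitive it is replaced by. Assumes the older MindSpore API (ksize/strides)
# and a backend that provides MaxPool kernels in PyNative mode.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

# Illustrative input: a 1x256x13x13 feature map like the one AlexNet pools before flatten.
x = Tensor(np.random.randn(1, 256, 13, 13).astype(np.float32))

cell_pool = nn.MaxPool2d(kernel_size=3, stride=2)  # original wrapper (removed line)
prim_pool = P.MaxPool(ksize=3, strides=2)          # primitive op (added line)

print(cell_pool(x).shape)  # expected: (1, 256, 6, 6)
print(prim_pool(x).shape)  # expected: (1, 256, 6, 6)
```

Both variants default to valid padding in the MindSpore version this diff targets, so the pooled feature map stays 6x6x256 and the downstream `fc_with_initialize(6*6*256, 4096)` input size is unchanged.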