| @@ -0,0 +1,25 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config""" | |||
| from src.alexnet import AlexNet | |||
| def alexnet(*args, **kwargs): | |||
| return AlexNet(*args, **kwargs) | |||
| def create_network(name, *args, **kwargs): | |||
| if name == "alexnet": | |||
| return alexnet(*args, **kwargs) | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||
| @@ -36,7 +36,7 @@ class AlexNet(nn.Cell): | |||
| """ | |||
| Alexnet | |||
| """ | |||
| def __init__(self, num_classes=10, channel=3): | |||
| def __init__(self, num_classes=10, channel=3, include_top=True): | |||
| super(AlexNet, self).__init__() | |||
| self.conv1 = conv(channel, 96, 11, stride=4) | |||
| self.conv2 = conv(96, 256, 5, pad_mode="same") | |||
| @@ -45,10 +45,12 @@ class AlexNet(nn.Cell): | |||
| self.conv5 = conv(384, 256, 3, pad_mode="same") | |||
| self.relu = nn.ReLU() | |||
| self.max_pool2d = P.MaxPool(ksize=3, strides=2) | |||
| self.flatten = nn.Flatten() | |||
| self.fc1 = fc_with_initialize(6*6*256, 4096) | |||
| self.fc2 = fc_with_initialize(4096, 4096) | |||
| self.fc3 = fc_with_initialize(4096, num_classes) | |||
| self.include_top = include_top | |||
| if self.include_top: | |||
| self.flatten = nn.Flatten() | |||
| self.fc1 = fc_with_initialize(6 * 6 * 256, 4096) | |||
| self.fc2 = fc_with_initialize(4096, 4096) | |||
| self.fc3 = fc_with_initialize(4096, num_classes) | |||
| def construct(self, x): | |||
| """define network""" | |||
| @@ -65,6 +67,8 @@ class AlexNet(nn.Cell): | |||
| x = self.conv5(x) | |||
| x = self.relu(x) | |||
| x = self.max_pool2d(x) | |||
| if not self.include_top: | |||
| return x | |||
| x = self.flatten(x) | |||
| x = self.fc1(x) | |||
| x = self.relu(x) | |||
| @@ -0,0 +1,25 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """hub config""" | |||
| from src.lenet import LeNet5 | |||
| def lenet(*args, **kwargs): | |||
| return LeNet5(*args, **kwargs) | |||
| def create_network(name, *args, **kwargs): | |||
| if name == "lenet": | |||
| return lenet(*args, **kwargs) | |||
| raise NotImplementedError(f"{name} is not implemented in the repo") | |||
| @@ -31,20 +31,25 @@ class LeNet5(nn.Cell): | |||
| >>> LeNet(num_class=10) | |||
| """ | |||
| def __init__(self, num_class=10, num_channel=1): | |||
| def __init__(self, num_class=10, num_channel=1, include_top=True): | |||
| super(LeNet5, self).__init__() | |||
| self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') | |||
| self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') | |||
| self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) | |||
| self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) | |||
| self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) | |||
| self.relu = nn.ReLU() | |||
| self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) | |||
| self.flatten = nn.Flatten() | |||
| self.include_top = include_top | |||
| if self.include_top: | |||
| self.flatten = nn.Flatten() | |||
| self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) | |||
| self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) | |||
| self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) | |||
| def construct(self, x): | |||
| x = self.max_pool2d(self.relu(self.conv1(x))) | |||
| x = self.max_pool2d(self.relu(self.conv2(x))) | |||
| if not self.include_top: | |||
| return x | |||
| x = self.flatten(x) | |||
| x = self.relu(self.fc1(x)) | |||
| x = self.relu(self.fc2(x)) | |||