diff --git a/model_zoo/official/cv/alexnet/mindspore_hub_conf.py b/model_zoo/official/cv/alexnet/mindspore_hub_conf.py new file mode 100644 index 0000000000..b205736a9a --- /dev/null +++ b/model_zoo/official/cv/alexnet/mindspore_hub_conf.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config""" +from src.alexnet import AlexNet + +def alexnet(*args, **kwargs): + return AlexNet(*args, **kwargs) + + +def create_network(name, *args, **kwargs): + if name == "alexnet": + return alexnet(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/alexnet/src/alexnet.py b/model_zoo/official/cv/alexnet/src/alexnet.py index 0b38da0cbc..2d558dc5e6 100644 --- a/model_zoo/official/cv/alexnet/src/alexnet.py +++ b/model_zoo/official/cv/alexnet/src/alexnet.py @@ -36,7 +36,7 @@ class AlexNet(nn.Cell): """ Alexnet """ - def __init__(self, num_classes=10, channel=3): + def __init__(self, num_classes=10, channel=3, include_top=True): super(AlexNet, self).__init__() self.conv1 = conv(channel, 96, 11, stride=4) self.conv2 = conv(96, 256, 5, pad_mode="same") @@ -45,10 +45,12 @@ class AlexNet(nn.Cell): self.conv5 = conv(384, 256, 3, pad_mode="same") self.relu = nn.ReLU() self.max_pool2d = P.MaxPool(ksize=3, strides=2) - self.flatten = nn.Flatten() - self.fc1 = fc_with_initialize(6*6*256, 4096) - self.fc2 = fc_with_initialize(4096, 4096) - self.fc3 = fc_with_initialize(4096, num_classes) + self.include_top = include_top + if self.include_top: + self.flatten = nn.Flatten() + self.fc1 = fc_with_initialize(6 * 6 * 256, 4096) + self.fc2 = fc_with_initialize(4096, 4096) + self.fc3 = fc_with_initialize(4096, num_classes) def construct(self, x): """define network""" @@ -65,6 +67,8 @@ class AlexNet(nn.Cell): x = self.conv5(x) x = self.relu(x) x = self.max_pool2d(x) + if not self.include_top: + return x x = self.flatten(x) x = self.fc1(x) x = self.relu(x) diff --git a/model_zoo/official/cv/lenet/mindspore_hub_conf.py b/model_zoo/official/cv/lenet/mindspore_hub_conf.py new file mode 100644 index 0000000000..835a26451e --- /dev/null +++ b/model_zoo/official/cv/lenet/mindspore_hub_conf.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config""" +from src.lenet import LeNet5 + +def lenet(*args, **kwargs): + return LeNet5(*args, **kwargs) + + +def create_network(name, *args, **kwargs): + if name == "lenet": + return lenet(*args, **kwargs) + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/lenet/src/lenet.py b/model_zoo/official/cv/lenet/src/lenet.py index 8bec0ae5f9..cb878155f4 100644 --- a/model_zoo/official/cv/lenet/src/lenet.py +++ b/model_zoo/official/cv/lenet/src/lenet.py @@ -31,20 +31,25 @@ class LeNet5(nn.Cell): >>> LeNet(num_class=10) """ - def __init__(self, num_class=10, num_channel=1): + def __init__(self, num_class=10, num_channel=1, include_top=True): super(LeNet5, self).__init__() self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') - self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) - self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) - self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.flatten = nn.Flatten() + self.include_top = include_top + if self.include_top: + self.flatten = nn.Flatten() + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) + def construct(self, x): x = self.max_pool2d(self.relu(self.conv1(x))) x = self.max_pool2d(self.relu(self.conv2(x))) + if not self.include_top: + return x x = self.flatten(x) x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x))