From fb568c7dbad163bea1d27883d7bfdfd9f82f6c79 Mon Sep 17 00:00:00 2001 From: wukesong Date: Thu, 3 Sep 2020 21:33:29 +0800 Subject: [PATCH] modify lenet normal --- model_zoo/official/cv/lenet/src/lenet.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/model_zoo/official/cv/lenet/src/lenet.py b/model_zoo/official/cv/lenet/src/lenet.py index 003c3e0b85..dd46d98198 100644 --- a/model_zoo/official/cv/lenet/src/lenet.py +++ b/model_zoo/official/cv/lenet/src/lenet.py @@ -14,6 +14,7 @@ # ============================================================================ """LeNet.""" import mindspore.nn as nn +from mindspore.common.initializer import Normal class LeNet5(nn.Cell): @@ -22,7 +23,7 @@ class LeNet5(nn.Cell): Args: num_class (int): Num classes. Default: 10. - channel (int): Num classes. Default: 1. + num_channel (int): Num channels. Default: 1. Returns: Tensor, output tensor @@ -30,14 +31,13 @@ class LeNet5(nn.Cell): >>> LeNet(num_class=10) """ - def __init__(self, num_class=10, channel=1): + def __init__(self, num_class=10, num_channel=1): super(LeNet5, self).__init__() - self.num_class = num_class - self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') + self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid') self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') - self.fc1 = nn.Dense(16 * 5 * 5, 120) - self.fc2 = nn.Dense(120, 84) - self.fc3 = nn.Dense(84, self.num_class) + self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02)) + self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02)) + self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02)) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.flatten = nn.Flatten()