You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

googlenet.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """GoogleNet"""
  16. import mindspore.nn as nn
  17. from mindspore.common.initializer import TruncatedNormal
  18. from mindspore.ops import operations as P
  19. def weight_variable():
  20. """Weight variable."""
  21. return TruncatedNormal(0.02)
  22. class Conv2dBlock(nn.Cell):
  23. """
  24. Basic convolutional block
  25. Args:
  26. in_channles (int): Input channel.
  27. out_channels (int): Output channel.
  28. kernel_size (int): Input kernel size. Default: 1
  29. stride (int): Stride size for the first convolutional layer. Default: 1.
  30. padding (int): Implicit paddings on both sides of the input. Default: 0.
  31. pad_mode (str): Padding mode. Optional values are "same", "valid", "pad". Default: "same".
  32. Returns:
  33. Tensor, output tensor.
  34. """
  35. def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, pad_mode="same"):
  36. super(Conv2dBlock, self).__init__()
  37. self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
  38. padding=padding, pad_mode=pad_mode, weight_init=weight_variable())
  39. self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
  40. self.relu = nn.ReLU()
  41. def construct(self, x):
  42. x = self.conv(x)
  43. x = self.bn(x)
  44. x = self.relu(x)
  45. return x
  46. class Inception(nn.Cell):
  47. """
  48. Inception Block
  49. """
  50. def __init__(self, in_channels, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
  51. super(Inception, self).__init__()
  52. self.b1 = Conv2dBlock(in_channels, n1x1, kernel_size=1)
  53. self.b2 = nn.SequentialCell([Conv2dBlock(in_channels, n3x3red, kernel_size=1),
  54. Conv2dBlock(n3x3red, n3x3, kernel_size=3, padding=0)])
  55. self.b3 = nn.SequentialCell([Conv2dBlock(in_channels, n5x5red, kernel_size=1),
  56. Conv2dBlock(n5x5red, n5x5, kernel_size=3, padding=0)])
  57. self.maxpool = P.MaxPoolWithArgmax(ksize=3, strides=1, padding="same")
  58. self.b4 = Conv2dBlock(in_channels, pool_planes, kernel_size=1)
  59. self.concat = P.Concat(axis=1)
  60. def construct(self, x):
  61. branch1 = self.b1(x)
  62. branch2 = self.b2(x)
  63. branch3 = self.b3(x)
  64. cell, argmax = self.maxpool(x)
  65. branch4 = self.b4(cell)
  66. _ = argmax
  67. return self.concat((branch1, branch2, branch3, branch4))
  68. class GoogleNet(nn.Cell):
  69. """
  70. Googlenet architecture
  71. """
  72. def __init__(self, num_classes):
  73. super(GoogleNet, self).__init__()
  74. self.conv1 = Conv2dBlock(3, 64, kernel_size=7, stride=2, padding=0)
  75. self.maxpool1 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
  76. self.conv2 = Conv2dBlock(64, 64, kernel_size=1)
  77. self.conv3 = Conv2dBlock(64, 192, kernel_size=3, padding=0)
  78. self.maxpool2 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
  79. self.block3a = Inception(192, 64, 96, 128, 16, 32, 32)
  80. self.block3b = Inception(256, 128, 128, 192, 32, 96, 64)
  81. self.maxpool3 = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="same")
  82. self.block4a = Inception(480, 192, 96, 208, 16, 48, 64)
  83. self.block4b = Inception(512, 160, 112, 224, 24, 64, 64)
  84. self.block4c = Inception(512, 128, 128, 256, 24, 64, 64)
  85. self.block4d = Inception(512, 112, 144, 288, 32, 64, 64)
  86. self.block4e = Inception(528, 256, 160, 320, 32, 128, 128)
  87. self.maxpool4 = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="same")
  88. self.block5a = Inception(832, 256, 160, 320, 32, 128, 128)
  89. self.block5b = Inception(832, 384, 192, 384, 48, 128, 128)
  90. self.mean = P.ReduceMean(keep_dims=True)
  91. self.dropout = nn.Dropout(keep_prob=0.8)
  92. self.flatten = nn.Flatten()
  93. self.classifier = nn.Dense(1024, num_classes, weight_init=weight_variable(),
  94. bias_init=weight_variable())
  95. def construct(self, x):
  96. x = self.conv1(x)
  97. x, argmax = self.maxpool1(x)
  98. x = self.conv2(x)
  99. x = self.conv3(x)
  100. x, argmax = self.maxpool2(x)
  101. x = self.block3a(x)
  102. x = self.block3b(x)
  103. x, argmax = self.maxpool3(x)
  104. x = self.block4a(x)
  105. x = self.block4b(x)
  106. x = self.block4c(x)
  107. x = self.block4d(x)
  108. x = self.block4e(x)
  109. x, argmax = self.maxpool4(x)
  110. x = self.block5a(x)
  111. x = self.block5b(x)
  112. x = self.mean(x, (2, 3))
  113. x = self.flatten(x)
  114. x = self.classifier(x)
  115. _ = argmax
  116. return x