
resnetv1_5.py

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


def weight_variable(shape):
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones * 0.01)


def weight_variable_0(shape):
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def weight_variable_1(shape):
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


# Note: with pad_mode="same", MindSpore requires the `padding` argument to
# stay 0; it only takes effect under pad_mode="pad".
def conv3x3(in_channels, out_channels, stride=1, padding=0):
    """3x3 convolution."""
    weight_shape = (out_channels, in_channels, 3, 3)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="same")


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution."""
    weight_shape = (out_channels, in_channels, 1, 1)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="same")


def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution."""
    weight_shape = (out_channels, in_channels, 7, 7)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="same")


def bn_with_initialize(out_channels):
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_1(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma,
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def bn_with_initialize_last(out_channels):
    # Zero-initializing gamma on the last BatchNorm of each residual block
    # makes the block start out close to an identity mapping, a common trick
    # for stabilizing early training of deep ResNets.
    shape = (out_channels,)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_0(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma,
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def fc_with_initialize(input_channels, out_channels):
    weight_shape = (out_channels, input_channels)
    bias_shape = (out_channels,)
    weight = weight_variable(weight_shape)
    bias = weight_variable_0(bias_shape)
    return nn.Dense(input_channels, out_channels, weight, bias)


class ResidualBlock(nn.Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = P.TensorAdd()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.add(out, identity)
        out = self.relu(out)
        return out


class ResidualBlockWithDown(nn.Cell):
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlockWithDown, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.downSample = down_sample

        # 1x1 projection shortcut; construct() applies it unconditionally,
        # so self.downSample is stored but never consulted.
        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.TensorAdd()

    def construct(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.conv_down_sample(identity)
        identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)
        return out
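

# In this v1.5 variant, both bottleneck blocks above place the stride-2
# downsampling on the 3x3 convolution (conv2); the original v1 bottleneck
# strides in the first 1x1 convolution instead.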


class MakeLayer0(nn.Cell):
    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer0, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
        self.b = block(out_channels, out_channels, stride=stride)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        return x


class MakeLayer1(nn.Cell):
    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer1, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        return x


class MakeLayer2(nn.Cell):
    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer2, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)
        self.e = block(out_channels, out_channels, stride=1)
        self.f = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        x = self.e(x)
        x = self.f(x)
        return x


class MakeLayer3(nn.Cell):
    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        return x
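

# Stacking 3, 4, 6, and 3 bottleneck blocks per stage (MakeLayer0-3) is
# the standard ResNet-50 configuration.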


class ResNet(nn.Cell):
    def __init__(self, block, num_classes=100, batch_size=32):
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        self.conv1 = conv7x7(3, 64, stride=2, padding=0)
        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="SAME")

        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)

        self.pool = P.ReduceMean(keep_dims=True)
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x, (-2, -1))
        x = self.flatten(x)
        x = self.fc(x)
        return x


def resnet50(batch_size, num_classes):
    return ResNet(ResidualBlock, num_classes, batch_size)
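

# A minimal smoke-test sketch, assuming a working MindSpore install; the
# batch size of 2 and the 224x224 input resolution are illustrative choices,
# not dictated by the network itself.
if __name__ == "__main__":
    net = resnet50(batch_size=2, num_classes=100)
    # Dummy NCHW input: 2 images, 3 channels, 224x224 pixels.
    dummy = Tensor(np.ones((2, 3, 224, 224)).astype(np.float32))
    logits = net(dummy)
    print(logits.shape)  # expected: (2, 100)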