You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.py 15 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ResNet."""
  16. import math
  17. import numpy as np
  18. import mindspore
  19. from mindspore import ParameterTuple
  20. import mindspore.nn as nn
  21. from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits, L1Loss
  22. from mindspore.nn import Momentum
  23. from mindspore.ops import operations as P
  24. from mindspore.ops import composite as C
  25. from mindspore.ops import functional as F
  26. from mindspore.common.initializer import HeNormal
  27. from mindspore.common.initializer import Normal
  28. from mindspore import Tensor
  29. from .stn import STN
  30. def _weight_variable(shape, factor=0.01):
  31. init_value = np.random.randn(*shape).astype(np.float32) * factor
  32. return Tensor(init_value)
  33. def _conv3x3(in_channel, out_channel, stride=1):
  34. n = 3*3*out_channel
  35. normal = Normal(math.sqrt(2. / n))
  36. return nn.Conv2d(in_channel, out_channel,
  37. kernel_size=3, stride=stride, padding=1, pad_mode='pad', weight_init=normal)
  38. def _conv1x1(in_channel, out_channel, stride=1):
  39. n = 1*1*out_channel
  40. normal = Normal(math.sqrt(2. / n))
  41. return nn.Conv2d(in_channel, out_channel,
  42. kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=normal)
  43. def _conv7x7(in_channel, out_channel, stride=1):
  44. n = 7*7*out_channel
  45. normal = Normal(math.sqrt(2. / n))
  46. return nn.Conv2d(in_channel, out_channel,
  47. kernel_size=7, stride=stride, padding=3, pad_mode='pad', weight_init=normal)
  48. def _bn(channel):
  49. return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
  50. gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
  51. def _bn1(channel):
  52. return nn.BatchNorm1d(channel, eps=1e-4, momentum=0.9,
  53. gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
  54. def _bn1_kaiming(channel):
  55. return nn.BatchNorm1d(channel, eps=1e-4, momentum=0.9,
  56. gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
  57. def _bn2_kaiming(channel):
  58. return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
  59. gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1, use_batch_statistics=None)
  60. def _bn_last(channel):
  61. return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
  62. gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)
  63. def _fc(in_channel, out_channel):
  64. he_normal = HeNormal()
  65. return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=he_normal, bias_init='zeros')
  66. class ResidualBlock(nn.Cell):
  67. """
  68. ResNet V1 residual block definition.
  69. Args:
  70. in_channel (int): Input channel.
  71. out_channel (int): Output channel.
  72. stride (int): Stride size for the first convolutional layer. Default: 1.
  73. Returns:
  74. Tensor, output tensor.
  75. Examples:
  76. >>> ResidualBlock(3, 256, stride=2)
  77. """
  78. expansion = 4
  79. def __init__(self,
  80. in_channel,
  81. channel,
  82. out_channel,
  83. stride=1):
  84. super(ResidualBlock, self).__init__()
  85. self.conv1 = _conv1x1(in_channel, channel, stride=1)
  86. self.bn1 = _bn(channel)
  87. self.conv2 = _conv3x3(channel, channel, stride=stride)
  88. self.bn2 = _bn(channel)
  89. self.conv3 = _conv1x1(channel, out_channel, stride=1)
  90. self.bn3 = _bn(out_channel)
  91. self.relu = nn.ReLU()
  92. self.down_sample = False
  93. if stride != 1 or in_channel != out_channel:
  94. self.down_sample = True
  95. self.down_sample_layer = None
  96. if self.down_sample:
  97. self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride),
  98. _bn(out_channel)])
  99. self.add = P.TensorAdd()
  100. def construct(self, x):
  101. '''construct'''
  102. identity = x
  103. out = self.conv1(x)
  104. out = self.bn1(out)
  105. out = self.relu(out)
  106. out = self.conv2(out)
  107. out = self.bn2(out)
  108. out = self.relu(out)
  109. out = self.conv3(out)
  110. out = self.bn3(out)
  111. if self.down_sample:
  112. identity = self.down_sample_layer(identity)
  113. out = self.add(out, identity)
  114. out = self.relu(out)
  115. return out
  116. class HardAttn(nn.Cell):
  117. '''LPD module'''
  118. def __init__(self, in_channels):
  119. super(HardAttn, self).__init__()
  120. self.relu = nn.ReLU()
  121. self.fc1 = _fc(128*128, 32)
  122. self.bn1 = _bn1(32)
  123. self.fc2 = _fc(32, 4)
  124. self.bn2 = _bn1(4)
  125. self.reshape = P.Reshape()
  126. self.shape = P.Shape()
  127. self.reduce_mean = P.ReduceMean()
  128. def construct(self, x):
  129. '''construct'''
  130. x = self.reduce_mean(x, 1)
  131. x_size = self.shape(x)
  132. x = self.reshape(x, (x_size[0], 128*128))
  133. x = self.fc1(x)
  134. x = self.bn1(x)
  135. x = self.relu(x)
  136. x = self.fc2(x)
  137. x = self.bn2(x)
  138. x = self.reshape(x, (x_size[0], 4))
  139. return x
  140. class ResNet(nn.Cell):
  141. """
  142. ResNet architecture.
  143. Args:
  144. block (Cell): Block for network.
  145. layer_nums (list): Numbers of block in different layers.
  146. in_channels (list): Input channel in each layer.
  147. out_channels (list): Output channel in each layer.
  148. strides (list): Stride size in each layer.
  149. num_classes (int): The number of classes that the training images are belonging to.
  150. Returns:
  151. Tensor, output tensor.
  152. Examples:
  153. >>> ResNet(ResidualBlock,
  154. >>> [3, 4, 6, 3],
  155. >>> [64, 256, 512, 1024],
  156. >>> [256, 512, 1024, 2048],
  157. >>> [1, 2, 2, 2],
  158. >>> 10)
  159. """
  160. def __init__(self,
  161. block,
  162. layer_nums,
  163. in_channels,
  164. channels,
  165. out_channels,
  166. strides,
  167. num_classes, is_train):
  168. super(ResNet, self).__init__()
  169. if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
  170. raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")
  171. self.ha3 = HardAttn(2048)
  172. self.is_train = is_train
  173. self.conv1 = _conv7x7(3, 64, stride=2)
  174. self.bn1 = _bn(64)
  175. self.relu = nn.ReLU()
  176. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
  177. self.layer1 = self._make_layer(block,
  178. layer_nums[0],
  179. in_channel=in_channels[0],
  180. channel=channels[0],
  181. out_channel=out_channels[0],
  182. stride=strides[0])
  183. self.layer2 = self._make_layer(block,
  184. layer_nums[1],
  185. in_channel=in_channels[1],
  186. channel=channels[1],
  187. out_channel=out_channels[1],
  188. stride=strides[1])
  189. self.layer3 = self._make_layer(block,
  190. layer_nums[2],
  191. in_channel=in_channels[2],
  192. channel=channels[2],
  193. out_channel=out_channels[2],
  194. stride=strides[2])
  195. self.layer4 = self._make_layer(block,
  196. layer_nums[3],
  197. in_channel=in_channels[3],
  198. channel=channels[3],
  199. out_channel=out_channels[3],
  200. stride=strides[3])
  201. self.max = P.ReduceMax(keep_dims=True)
  202. self.flatten = nn.Flatten()
  203. self.global_bn = _bn2_kaiming(out_channels[3])
  204. self.partial_bn = _bn2_kaiming(out_channels[3])
  205. normal = Normal(0.001)
  206. self.global_fc = nn.Dense(out_channels[3], num_classes, has_bias=False, weight_init=normal, bias_init='zeros')
  207. self.partial_fc = nn.Dense(out_channels[3], num_classes, has_bias=False, weight_init=normal, bias_init='zeros')
  208. self.theta_0 = Tensor(np.zeros((128, 4)), mindspore.float32)
  209. self.theta_6 = Tensor(np.zeros((128, 4))+0.6, mindspore.float32)
  210. self.STN = STN(128, 128)
  211. self.concat = P.Concat(axis=1)
  212. self.shape = P.Shape()
  213. self.tanh = P.Tanh()
  214. self.slice = P.Slice()
  215. self.split = P.Split(1, 4)
  216. def _make_layer(self, block, layer_num, in_channel, channel, out_channel, stride):
  217. """
  218. Make stage network of ResNet.
  219. Args:
  220. block (Cell): Resnet block.
  221. layer_num (int): Layer number.
  222. in_channel (int): Input channel.
  223. out_channel (int): Output channel.
  224. stride (int): Stride size for the first convolutional layer.
  225. Returns:
  226. SequentialCell, the output layer.
  227. Examples:
  228. >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
  229. """
  230. layers = []
  231. resnet_block = block(in_channel, channel, out_channel, stride=stride)
  232. layers.append(resnet_block)
  233. for _ in range(1, layer_num):
  234. resnet_block = block(out_channel, channel, out_channel, stride=1)
  235. layers.append(resnet_block)
  236. return nn.SequentialCell(layers)
  237. def stn(self, x, stn_theta):
  238. '''stn'''
  239. x_size = self.shape(x)
  240. theta = self.tanh(stn_theta)
  241. theta1, theta5, theta6, theta3 = self.split(theta)
  242. theta_0 = self.slice(self.theta_0, (0, 0), (x_size[0], 4))
  243. theta2, theta4, _, _ = self.split(theta_0)
  244. theta = self.concat((theta1, theta2, theta3, theta4, theta5, theta6))
  245. flip_feature = self.STN(x, theta)
  246. return flip_feature, theta5
  247. def construct(self, x):
  248. '''construct'''
  249. stn_theta = self.ha3(x)
  250. x_p, theta = self.stn(x, stn_theta)
  251. x = self.conv1(x)
  252. x = self.bn1(x)
  253. x = self.relu(x)
  254. c1 = self.maxpool(x)
  255. c2 = self.layer1(c1)
  256. c3 = self.layer2(c2)
  257. c4 = self.layer3(c3)
  258. c5 = self.layer4(c4)
  259. out = self.max(c5, (2, 3))
  260. out = self.global_bn(out)
  261. global_f = self.flatten(out)
  262. x_p = self.conv1(x_p)
  263. x_p = self.bn1(x_p)
  264. x_p = self.relu(x_p)
  265. c1_p = self.maxpool(x_p)
  266. c2_p = self.layer1(c1_p)
  267. c3_p = self.layer2(c2_p)
  268. c4_p = self.layer3(c3_p)
  269. c5_p = self.layer4(c4_p)
  270. out_p = self.max(c5_p, (2, 3))
  271. out_p = self.partial_bn(out_p)
  272. partial_f = self.flatten(out_p)
  273. global_out = self.global_fc(global_f)
  274. partial_out = self.partial_fc(partial_f)
  275. return global_f, partial_f, global_out, partial_out, theta
  276. class NetWithLossClass(nn.Cell):
  277. '''net with loss'''
  278. def __init__(self, network, is_train=True):
  279. super(NetWithLossClass, self).__init__(auto_prefix=False)
  280. self.loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  281. self.l1_loss = L1Loss()
  282. self.network = network
  283. self.is_train = is_train
  284. self.concat = P.Concat(axis=1)
  285. def construct(self, x, label1, label2):
  286. '''construct'''
  287. global_f, partial_f, global_out, partial_out, theta = self.network(x)
  288. if not self.is_train:
  289. out = self.concat((global_f, partial_f))
  290. return out
  291. loss_global = self.loss(global_out, label1)
  292. loss_partial = self.loss(partial_out, label1)
  293. loss_theta = self.l1_loss(theta, label2)
  294. loss = loss_global + loss_partial + loss_theta
  295. return loss
  296. class TrainStepWrap(nn.Cell):
  297. '''train step wrap'''
  298. def __init__(self, network, lr, momentum, is_train=True):
  299. super(TrainStepWrap, self).__init__(auto_prefix=False)
  300. self.network = network
  301. self.weights = ParameterTuple(network.trainable_params())
  302. self.optimizer = Momentum(self.weights, lr, momentum)
  303. self.grad = C.GradOperation(get_by_list=True)
  304. self.is_train = is_train
  305. def construct(self, x, labels1, labels2):
  306. '''construct'''
  307. weights = self.weights
  308. loss = self.network(x, labels1, labels2)
  309. if not self.is_train:
  310. return loss
  311. grads = self.grad(self.network, weights)(x, labels1, labels2)
  312. return F.depend(loss, self.optimizer(grads))
  313. class TestStepWrap(nn.Cell):
  314. """
  315. Predict method
  316. """
  317. def __init__(self, network):
  318. super(TestStepWrap, self).__init__(auto_prefix=False)
  319. self.network = network
  320. self.sigmoid = P.Sigmoid()
  321. def construct(self, x, labels):
  322. '''construct'''
  323. logits_global, _, _, _, = self.network(x)
  324. pred_probs = self.sigmoid(logits_global)
  325. return logits_global, pred_probs, labels
  326. def resnet50(class_num=10, is_train=True):
  327. """
  328. Get ResNet50 neural network.
  329. Args:
  330. class_num (int): Class number.
  331. Returns:
  332. Cell, cell instance of ResNet50 neural network.
  333. Examples:
  334. >>> net = resnet50(10)
  335. """
  336. return ResNet(ResidualBlock,
  337. [3, 4, 6, 3],
  338. [64, 256, 512, 1024],
  339. [64, 128, 256, 512],
  340. [256, 512, 1024, 2048],
  341. [1, 2, 2, 1],
  342. class_num, is_train)
  343. def resnet101(class_num=1001):
  344. """
  345. Get ResNet101 neural network.
  346. Args:
  347. class_num (int): Class number.
  348. Returns:
  349. Cell, cell instance of ResNet101 neural network.
  350. Examples:
  351. >>> net = resnet101(1001)
  352. """
  353. return ResNet(ResidualBlock,
  354. [3, 4, 23, 3],
  355. [64, 256, 512, 1024],
  356. [256, 512, 1024, 2048],
  357. [1, 2, 2, 2],
  358. class_num)