You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

resnet50_expand_loss.py 13 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import os
  16. import numpy as np
  17. import mindspore.common.dtype as mstype
  18. import mindspore.context as context
  19. import mindspore.nn as nn
  20. import mindspore.ops.functional as F
  21. from mindspore import Tensor
  22. from mindspore.common.initializer import TruncatedNormal
  23. from mindspore.communication.management import init
  24. from mindspore.nn.loss.loss import _Loss
  25. from mindspore.nn.optim.momentum import Momentum
  26. from mindspore.ops import operations as P
  27. from mindspore.parallel import set_algo_parameters
  28. from mindspore.train.callback import Callback
  29. from mindspore.train.model import Model
  30. from mindspore.context import ParallelMode
  31. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  32. context.set_context(device_id=int(os.getenv('DEVICE_ID')))
  33. init()
  34. context.set_auto_parallel_context(gradients_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
  35. np.random.seed(10)
  36. def weight_variable():
  37. return TruncatedNormal(0.01)
  38. def _conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
  39. init_value = weight_variable()
  40. return nn.Conv2d(in_channels, out_channels,
  41. kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
  42. def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
  43. init_value = weight_variable()
  44. return nn.Conv2d(in_channels, out_channels,
  45. kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
  46. def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'):
  47. init_value = weight_variable()
  48. return nn.Conv2d(in_channels, out_channels,
  49. kernel_size=7, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value)
  50. def _fused_bn(channels, momentum=0.9):
  51. return nn.BatchNorm2d(channels, momentum=momentum)
  52. class BasicBlock(nn.Cell):
  53. expansion = 1
  54. def __init__(self,
  55. in_channels,
  56. out_channels,
  57. stride=1,
  58. momentum=0.1):
  59. super(BasicBlock, self).__init__()
  60. self.conv1 = _conv3x3(in_channels, out_channels, stride=stride)
  61. self.bn1 = _fused_bn(out_channels, momentum=momentum)
  62. self.conv2 = _conv3x3(out_channels, out_channels)
  63. self.bn2 = _fused_bn(out_channels, momentum=momentum)
  64. self.relu = P.ReLU()
  65. self.down_sample_layer = None
  66. self.downsample = (in_channels != out_channels)
  67. if self.downsample:
  68. self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channels,
  69. out_channels,
  70. stride=stride,
  71. padding=0),
  72. _fused_bn(out_channels,
  73. momentum=momentum)])
  74. self.add = P.TensorAdd()
  75. def construct(self, x):
  76. identity = x
  77. x = self.conv1(x)
  78. x = self.relu(x)
  79. x = self.conv2(x)
  80. if self.downsample:
  81. identity = self.down_sample_layer(identity)
  82. out = self.add(x, identity)
  83. out = self.relu(out)
  84. return out
  85. class ResidualBlock(nn.Cell):
  86. expansion = 4
  87. def __init__(self,
  88. in_channels,
  89. out_channels,
  90. stride=1):
  91. super(ResidualBlock, self).__init__()
  92. out_chls = out_channels // self.expansion
  93. self.conv1 = _conv1x1(in_channels, out_chls, stride=1)
  94. self.conv2 = _conv3x3(out_chls, out_chls, stride=stride)
  95. self.conv3 = _conv1x1(out_chls, out_channels, stride=1)
  96. self.relu = P.ReLU()
  97. self.downsample = (in_channels != out_channels)
  98. self.stride = stride
  99. if self.downsample:
  100. self.conv_down_sample = _conv1x1(in_channels, out_channels,
  101. stride=stride)
  102. elif self.stride != 1:
  103. self.maxpool_down = nn.MaxPool2d(kernel_size=1, stride=2, pad_mode='same')
  104. self.add = P.TensorAdd()
  105. def construct(self, x):
  106. identity = x
  107. out = self.conv1(x)
  108. out = self.relu(out)
  109. out = self.conv2(out)
  110. out = self.relu(out)
  111. out = self.conv3(out)
  112. if self.downsample:
  113. identity = self.conv_down_sample(identity)
  114. elif self.stride != 1:
  115. identity = self.maxpool_down(identity)
  116. out = self.add(out, identity)
  117. out = self.relu(out)
  118. return out
  119. class ResNet(nn.Cell):
  120. def __init__(self,
  121. block,
  122. layer_nums,
  123. in_channels,
  124. out_channels,
  125. strides=None,
  126. num_classes=100):
  127. super(ResNet, self).__init__()
  128. if strides is None:
  129. strides = [1, 2, 2, 2]
  130. if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
  131. raise ValueError("the length of "
  132. "layer_num, inchannel, outchannel list must be 4!")
  133. self.conv1 = _conv7x7(3, 64, stride=2)
  134. self.bn1 = _fused_bn(64)
  135. self.relu = P.ReLU()
  136. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
  137. self.layer1 = self._make_layer(block,
  138. layer_nums[0],
  139. in_channel=in_channels[0],
  140. out_channel=out_channels[0],
  141. stride=strides[0])
  142. self.layer2 = self._make_layer(block,
  143. layer_nums[1],
  144. in_channel=in_channels[1],
  145. out_channel=out_channels[1],
  146. stride=strides[1])
  147. self.layer3 = self._make_layer(block,
  148. layer_nums[2],
  149. in_channel=in_channels[2],
  150. out_channel=out_channels[2],
  151. stride=strides[2])
  152. self.layer4 = self._make_layer(block,
  153. layer_nums[3],
  154. in_channel=in_channels[3],
  155. out_channel=out_channels[3],
  156. stride=strides[3])
  157. self.mean = P.ReduceMean(keep_dims=True)
  158. self.end_point = nn.Dense(2048, num_classes, has_bias=True,
  159. weight_init=weight_variable(),
  160. bias_init=weight_variable()).add_flags_recursive(fp16=True)
  161. self.squeeze = P.Squeeze()
  162. self.cast = P.Cast()
  163. def _make_layer(self, block, layer_num, in_channel, out_channel, stride):
  164. layers = []
  165. resblk = block(in_channel, out_channel, stride=1)
  166. layers.append(resblk)
  167. for _ in range(1, layer_num - 1):
  168. resblk = block(out_channel, out_channel, stride=1)
  169. layers.append(resblk)
  170. resblk = block(out_channel, out_channel, stride=stride)
  171. layers.append(resblk)
  172. return nn.SequentialCell(layers)
  173. def construct(self, x):
  174. x = self.conv1(x)
  175. x = self.relu(x)
  176. c1 = self.maxpool(x)
  177. c2 = self.layer1(c1)
  178. c3 = self.layer2(c2)
  179. c4 = self.layer3(c3)
  180. c5 = self.layer4(c4)
  181. out = self.mean(c5, (2, 3))
  182. out = self.squeeze(out)
  183. out = self.end_point(out)
  184. return out
  185. def resnet50(class_num=10):
  186. return ResNet(ResidualBlock,
  187. [3, 4, 6, 3],
  188. [64, 256, 512, 1024],
  189. [256, 512, 1024, 2048],
  190. [2, 2, 2, 1],
  191. class_num)
  192. class SoftmaxCrossEntropyExpand(_Loss):
  193. def __init__(self, sparse=False):
  194. super(SoftmaxCrossEntropyExpand, self).__init__()
  195. self.exp = P.Exp()
  196. self.sum = P.ReduceSum(keep_dims=True)
  197. self.onehot = P.OneHot()
  198. self.on_value = Tensor(1.0, mstype.float32)
  199. self.off_value = Tensor(0.0, mstype.float32)
  200. self.div = P.Div()
  201. self.log = P.Log()
  202. self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
  203. self.mul = P.Mul()
  204. self.mul2 = P.Mul()
  205. self.cast = P.Cast()
  206. self.mean = P.ReduceMean(keep_dims=False)
  207. self.sparse = sparse
  208. self.max = P.ReduceMax(keep_dims=True)
  209. self.sub = P.Sub()
  210. self.eps = Tensor(1e-24, mstype.float32)
  211. def construct(self, logit, label):
  212. logit = self.cast(logit, mstype.float32)
  213. logit_max = self.max(logit, -1)
  214. exp = self.exp(self.sub(logit, logit_max))
  215. exp_sum = self.sum(exp, -1)
  216. softmax_result = self.div(exp, exp_sum)
  217. if self.sparse:
  218. label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
  219. softmax_result_log = self.log(softmax_result + self.eps)
  220. loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
  221. loss = self.mul2(F.scalar_to_array(-1.0), loss)
  222. loss = self.mean(loss, -1)
  223. return loss
  224. rank_id = int(os.environ["RANK_ID"])
  225. device_num = int(os.environ["RANK_SIZE"])
  226. class DataGenerator():
  227. def get_parallel_blocks(self, input_, strategy):
  228. blocks = [input_]
  229. i = 0
  230. for stra in strategy:
  231. temp = []
  232. while blocks:
  233. block = blocks.pop(0)
  234. temp.extend(np.split(block, stra, axis=i))
  235. blocks.extend(temp)
  236. i += 1
  237. return blocks
  238. def generate_data(self, shape):
  239. data = np.arange(np.prod(shape)).reshape(shape)
  240. return data
  241. def input_data(self, shape):
  242. data = (self.generate_data(shape)).astype(np.float32)
  243. stra = [1] * len(shape)
  244. stra[0] = device_num
  245. datas = self.get_parallel_blocks(data, stra)
  246. return Tensor(data), Tensor(datas[rank_id])
  247. def label_data(self, shape):
  248. data = (self.generate_data(shape) * 1000 / np.prod(shape)).astype(np.int32)
  249. stra = [1] * len(shape)
  250. stra[0] = device_num
  251. datas = self.get_parallel_blocks(data, stra)
  252. return Tensor(data), Tensor(datas[rank_id])
  253. class Dataset():
  254. def __init__(self, predict, label, length=1, input_num=2, repeat_count=1):
  255. self.predict = predict
  256. self.label = label
  257. self.index = 0
  258. self.length = length
  259. self.input_num = input_num
  260. self.repeat_count = repeat_count
  261. def __iter__(self):
  262. return self
  263. def __next__(self):
  264. if self.index >= self.length:
  265. raise StopIteration
  266. self.index += 1
  267. if self.input_num == 2:
  268. return (self.predict, self.label)
  269. return (self.predict,)
  270. def reset(self):
  271. self.index = 0
  272. def get_dataset_size(self):
  273. return self.length
  274. def get_repeat_count(self):
  275. return self.repeat_count
  276. class ModelCallback(Callback):
  277. def __init__(self):
  278. super(ModelCallback, self).__init__()
  279. self.loss_list = []
  280. def epoch_end(self, run_context):
  281. cb_params = run_context.original_args()
  282. result = cb_params.net_outputs
  283. self.loss_list.append(result.asnumpy().mean())
  284. def test_train_feed(num_classes=65536):
  285. set_algo_parameters(elementwise_op_strategy_follow=True)
  286. parallel_callback = ModelCallback()
  287. data_gen = DataGenerator()
  288. _, input_part = data_gen.input_data((32 * 8, 3, 224, 224))
  289. _, label_part = data_gen.label_data((32 * 8,))
  290. dataset = Dataset(input_part, label_part)
  291. net = resnet50(num_classes)
  292. loss = SoftmaxCrossEntropyExpand(sparse=True)
  293. opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
  294. model = Model(net, loss_fn=loss, optimizer=opt)
  295. model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback)
  296. loss_value = np.array(parallel_callback.loss_list)
  297. expect_out = [11.11153, 11.090023, 11.050361, 10.994822, 10.924148]
  298. print(loss_value)
  299. assert np.allclose(loss_value, expect_out, 0.0001, 0.0001)