You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_one_hot_net.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from mindspore.ops import operations as P
  15. from mindspore.ops import functional as F
  16. from mindspore import Tensor, Parameter
  17. from mindspore.common import dtype as mstype
  18. import mindspore.nn as nn
  19. import numpy as np
  20. from mindspore.nn.cell import Cell
  21. from tests.dataset_mock import MindData
  22. from mindspore.nn.optim.momentum import Momentum
  23. from mindspore.train import Model, ParallelMode
  24. from tests.ut.python.ops.test_math_ops import VirtualLoss
  25. from mindspore.ops import composite as C
  26. import mindspore as ms
  27. from mindspore.common.api import _executor
  28. from mindspore import context
  29. device_num=16
  30. device_id = 2
  31. class StrategyModel():
  32. onehot_strategy = ((1, device_num),(),())
  33. twod_strategy = ((1, device_num), )
  34. twod_strategy_m = ((device_num, 1), )
  35. scalar_twod_strategy = ((), (1, device_num))
  36. twod_scalar_strategy = ((1, device_num), ())
  37. scalar_strategy = ((), )
  38. oned_strategy = ((1, ), )
  39. scalar_scalar_strategy = ((), ())
  40. twod_twod_strategy = ((1, device_num), (1, device_num))
  41. twod_twodbc_strategy = ((1, device_num), (1, 1))
  42. twodbc_twod_strategy = ((1, 1), (device_num, 1))
  43. class StrategyBatch():
  44. onehot_strategy = ((device_num, 1),(),())
  45. twod_strategy = ((1, device_num), )
  46. twod_strategy_m = ((device_num, 1), )
  47. scalar_twod_strategy = ((), (1, device_num))
  48. twod_scalar_strategy = ((1, device_num), ())
  49. scalar_strategy = ((), )
  50. oned_strategy = ((1, ), )
  51. scalar_scalar_strategy = ((), ())
  52. twod_twod_strategy = ((1, device_num), (1, device_num))
  53. twod_twodbc_strategy = ((1, device_num), (1, 1))
  54. twodbc_twod_strategy = ((1, 1), (device_num, 1))
  55. class Args():
  56. a = 1
  57. b = 2
  58. c = 3
  59. d = 4
  60. e = 5
  61. num_classes = 512
  62. emb_size = 512
  63. class SemiAutoOneHotNet(Cell):
  64. def __init__(self, args, strategy):
  65. super(SemiAutoOneHotNet, self).__init__()
  66. self.a = args.a
  67. self.b = args.b
  68. self.c = args.c
  69. self.d = args.d
  70. self.e = args.e
  71. self.cast = P.Cast()
  72. self.cast.set_strategy(strategy=strategy.twod_strategy)
  73. self.cast1 = P.Cast()
  74. self.cast1.set_strategy(strategy=strategy.twod_strategy)
  75. self.cast2 = P.Cast()
  76. self.cast2.set_strategy(strategy=strategy.twod_strategy)
  77. self.cast3 = P.Cast()
  78. self.cast3.set_strategy(strategy=strategy.scalar_strategy)
  79. self.cast4 = P.Cast()
  80. self.cast4.set_strategy(strategy=strategy.scalar_strategy)
  81. self.a_const = Tensor(self.a, dtype=mstype.float32)
  82. self.b_const = Tensor(self.b, dtype=mstype.float32)
  83. self.c_const = Tensor(self.c, dtype=mstype.float32)
  84. self.d_const = Tensor(self.d, dtype=mstype.float32)
  85. self.e_const = Tensor(self.e, dtype=mstype.float32)
  86. self.m_const_zero = Tensor(0, dtype=mstype.float32)
  87. self.a_const_one = Tensor(1, dtype=mstype.float32)
  88. self.onehot = P.OneHot()
  89. self.onehot.set_strategy(strategy=strategy.onehot_strategy)
  90. self.exp = P.Exp()
  91. self.exp.set_strategy(strategy=strategy.twod_strategy)
  92. self.exp2 = P.Exp()
  93. self.exp2.set_strategy(strategy=strategy.twod_strategy)
  94. self.exp3 = P.Exp()
  95. self.exp3.set_strategy(strategy=strategy.twod_strategy)
  96. self.mul_const = P.Mul()
  97. self.mul_const.set_strategy(strategy=strategy.scalar_twod_strategy)
  98. self.mul_const2 = P.TensorAdd()
  99. self.mul_const2.set_strategy(strategy=strategy.scalar_twod_strategy)
  100. self.mul_const3 = P.Sub()
  101. self.mul_const3.set_strategy(strategy=strategy.twod_scalar_strategy)
  102. self.mul_const4 = P.Sub()
  103. self.mul_const4.set_strategy(strategy=strategy.scalar_twod_strategy)
  104. self.mul_const5 = P.Mul()
  105. self.mul_const5.set_strategy(strategy=strategy.twod_scalar_strategy)
  106. self.mul = P.Mul()
  107. self.mul.set_strategy(strategy=strategy.twod_twod_strategy)
  108. self.mul2 = P.Mul()
  109. self.mul2.set_strategy(strategy=strategy.twod_twod_strategy)
  110. self.mul3 = P.TensorAdd()
  111. self.mul3.set_strategy(strategy=strategy.twod_twod_strategy)
  112. self.mul4 = P.Sub()
  113. self.mul4.set_strategy(strategy=strategy.twod_twodbc_strategy)
  114. self.mul5 = P.RealDiv()
  115. self.mul5.set_strategy(strategy=strategy.twod_twodbc_strategy)
  116. self.mul6 = P.Mul()
  117. self.mul6.set_strategy(strategy=strategy.twod_twod_strategy)
  118. self.mul7 = P.Mul()
  119. self.mul7.set_strategy(strategy=strategy.twod_scalar_strategy)
  120. self.mul8 = P.RealDiv()
  121. self.mul8.set_strategy(strategy=strategy.scalar_scalar_strategy)
  122. self.mul9 = P.TensorAdd()
  123. self.mul9.set_strategy(strategy=strategy.twod_scalar_strategy)
  124. self.reduce_max = P.ReduceMax(keep_dims=True)
  125. self.reduce_max.set_strategy(strategy=strategy.twod_strategy)
  126. self.reduce_sum = P.ReduceSum(keep_dims=False)
  127. self.reduce_sum.set_strategy(strategy=strategy.twod_strategy)
  128. self.reduce_sum_2 = P.ReduceSum(keep_dims=False)
  129. self.reduce_sum_2.set_strategy(strategy=strategy.twod_strategy)
  130. self.reduce_sum_3 = P.ReduceSum(keep_dims=False)
  131. self.reduce_sum_3.set_strategy(strategy=strategy.oned_strategy)
  132. self.reshape = P.Reshape()
  133. self.log = P.Log()
  134. self.log.set_strategy(strategy=strategy.twod_strategy)
  135. self.on_value = Tensor(1.0, mstype.float32)
  136. self.off_value = Tensor(0.0, mstype.float32)
  137. self.normalize = P.L2Normalize(axis=1)
  138. self.normalize.set_strategy(strategy=strategy.twod_strategy_m)
  139. self.normalize2 = P.L2Normalize(axis=1)
  140. self.normalize2.set_strategy(strategy=strategy.twod_strategy_m)
  141. self.fc = P.MatMul(transpose_b=True)
  142. self.fc.set_strategy(strategy=strategy.twodbc_twod_strategy)
  143. weight_shape = [args.num_classes, args.emb_size]
  144. weight_np = np.zeros(weight_shape, np.float32)
  145. self.weight = Parameter(Tensor(weight_np), name='model_parallel_weight')
  146. def construct(self, input, label):
  147. input_n = self.normalize(input)
  148. w = self.normalize2(self.weight)
  149. fc_o = self.fc(input_n, w)
  150. fc_o_shape = F.shape(fc_o)
  151. one_hot_float = self.onehot(label, fc_o_shape[1],self.on_value, self.off_value)
  152. local_label = self.cast(one_hot_float, mstype.int32)
  153. exp_o = self.exp(fc_o)
  154. mul_const_o = self.mul_const(self.a_const, exp_o)
  155. mul_const2_o = self.mul_const2(self.b_const, mul_const_o)
  156. exp2_o = self.exp2(mul_const2_o)
  157. mul_const3_o = self.mul_const3(exp2_o, self.c_const)
  158. mul_const4_o = self.mul_const4(F.scalar_to_array(1), local_label)
  159. mul6_o = self.mul6(self.mul(mul_const3_o, one_hot_float), self.mul2(fc_o, self.cast2(mul_const4_o, mstype.float32)))
  160. mul_const5_o = self.mul_const5(mul6_o, self.d_const)
  161. max_o = self.reduce_max(mul_const5_o, -1)
  162. mul4_o = self.mul4(mul_const5_o, max_o)
  163. exp3_o = self.exp3(mul4_o)
  164. sum_o = self.reduce_sum(exp3_o, -1)
  165. reshape_o = self.reshape(sum_o, (F.shape(sum_o)[0], 1))
  166. mul5_o = self.mul5(exp3_o, reshape_o)
  167. log_o = self.log(self.mul9(mul5_o, self.e_const))
  168. mul3_o = self.mul3(log_o, one_hot_float)
  169. mul7_o = self.mul7(mul3_o, self.cast3(F.scalar_to_array(-1), mstype.float32))
  170. sum2_o = self.reduce_sum_2(mul7_o, -1)
  171. loss = self.mul8(self.reduce_sum_3(sum2_o, -1), self.cast4(F.scalar_to_array(F.shape(mul_const5_o)[0]), mstype.float32))
  172. return loss
  173. class Dataset(MindData):
  174. def __init__(self, predict, label, length=3, input_num=2):
  175. super(Dataset, self).__init__(size=length)
  176. self.predict = predict
  177. self.label = label
  178. self.index = 0
  179. self.length = length
  180. self.input_num = input_num
  181. def __iter__(self):
  182. return self
  183. def __next__(self):
  184. if self.index >= self.length:
  185. raise StopIteration
  186. self.index += 1
  187. if self.input_num == 2:
  188. return self.predict, self.label
  189. else:
  190. return self.predict,
  191. def reset(self):
  192. self.index = 0
  193. class NetWithLoss(nn.Cell):
  194. def __init__(self, network):
  195. super(NetWithLoss, self).__init__()
  196. self.loss = VirtualLoss()
  197. self.network = network
  198. def construct(self, x, b):
  199. predict = self.network(x, b)
  200. return self.loss(predict)
  201. class GradWrap(nn.Cell):
  202. def __init__(self, network):
  203. super(GradWrap, self).__init__()
  204. self.network = network
  205. def construct(self, x, b):
  206. return C.grad_all(self.network)(x, b)
  207. def bn_with_initialize(out_channels):
  208. bn = nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)
  209. return bn
  210. def fc_with_initialize(input_channels, out_channels):
  211. return nn.Dense(input_channels, out_channels)
  212. class BNReshapeDenseBNNet(nn.Cell):
  213. def __init__(self):
  214. super(BNReshapeDenseBNNet, self).__init__()
  215. self.batch_norm = bn_with_initialize(2)
  216. self.reshape = P.Reshape()
  217. self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
  218. self.fc = fc_with_initialize(2 * 32 * 32, 512)
  219. self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
  220. def construct(self, x, label):
  221. x = self.batch_norm(x)
  222. x = self.reshape(x, (16, 2*32*32))
  223. x = self.fc(x)
  224. x = self.batch_norm2(x)
  225. loss = self.loss(x, label)
  226. return loss
  227. def test_bn_reshape_dense_bn_train_loss():
  228. batch_size = 16
  229. device_num = 16
  230. context.set_auto_parallel_context(device_num=device_num, global_rank=0)
  231. input = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
  232. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  233. net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
  234. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  235. _executor.compile(net, input, label)
  236. def test_semi_one_hot_net_batch():
  237. batch_size = 16
  238. context.set_auto_parallel_context(device_num=device_num, global_rank=0)
  239. input = Tensor(np.ones([batch_size * 1, 512]).astype(np.float32) * 0.01)
  240. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  241. net = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
  242. net = GradWrap(NetWithLoss(net))
  243. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  244. _executor.compile(net, input, label)
  245. def test_semi_one_hot_net_model():
  246. batch_size = 16
  247. learning_rate = 0.1
  248. momentum = 0.9
  249. epoch_size = 2
  250. predict = Tensor(np.ones([batch_size, 512]), dtype=ms.float32)
  251. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  252. dataset = Dataset(predict, label, 2, input_num=2)
  253. net = SemiAutoOneHotNet(args=Args(), strategy=StrategyModel())
  254. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  255. context.reset_auto_parallel_context()
  256. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=16)
  257. context.set_context(mode=context.GRAPH_MODE)
  258. model = Model(net, optimizer=opt)
  259. model.train(epoch_size, dataset, dataset_sink_mode=False)