You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_one_hot_net.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import mindspore as ms
  16. import mindspore.nn as nn
  17. from mindspore import Tensor, Parameter
  18. from mindspore import context
  19. from mindspore.common import dtype as mstype
  20. from mindspore.common.api import _executor
  21. from mindspore.nn.cell import Cell
  22. from mindspore.nn.optim.momentum import Momentum
  23. from mindspore.ops import composite as C
  24. from mindspore.ops import functional as F
  25. from mindspore.ops import operations as P
  26. from mindspore.train import Model
  27. from mindspore.context import ParallelMode
  28. from tests.dataset_mock import MindData
  29. from tests.ut.python.ops.test_math_ops import VirtualLoss
  30. grad_all = C.GradOperation(get_all=True)
  31. device_num = 16
  32. device_id = 2
  33. class StrategyModel():
  34. onehot_strategy = ((1, device_num), (), ())
  35. twod_strategy = ((1, device_num),)
  36. twod_strategy_m = ((device_num, 1),)
  37. scalar_twod_strategy = ((), (1, device_num))
  38. twod_scalar_strategy = ((1, device_num), ())
  39. scalar_strategy = ((),)
  40. oned_strategy = ((1,),)
  41. scalar_scalar_strategy = ((), ())
  42. twod_twod_strategy = ((1, device_num), (1, device_num))
  43. twod_twodbc_strategy = ((1, device_num), (1, 1))
  44. twodbc_twod_strategy = ((1, 1), (device_num, 1))
  45. class StrategyBatch():
  46. onehot_strategy = ((device_num, 1), (), ())
  47. twod_strategy = ((1, device_num),)
  48. twod_strategy_m = ((device_num, 1),)
  49. scalar_twod_strategy = ((), (1, device_num))
  50. twod_scalar_strategy = ((1, device_num), ())
  51. scalar_strategy = ((),)
  52. oned_strategy = ((1,),)
  53. scalar_scalar_strategy = ((), ())
  54. twod_twod_strategy = ((1, device_num), (1, device_num))
  55. twod_twodbc_strategy = ((1, device_num), (1, 1))
  56. twodbc_twod_strategy = ((1, 1), (device_num, 1))
  57. class Args():
  58. a = 1
  59. b = 2
  60. c = 3
  61. d = 4
  62. e = 5
  63. num_classes = 512
  64. emb_size = 512
class SemiAutoOneHotNet(Cell):
    """Loss graph built from individually sharded primitives for semi-auto parallel tests.

    Every primitive receives its ``shard`` strategy from ``strategy``
    (StrategyModel or StrategyBatch in this file). Attribute names such as
    ``mul*`` / ``mul_const*`` do not reflect the primitive type: several of
    them are Add, Sub or RealDiv.

    Args:
        args: holder of the scalar constants ``a``..``e`` and of
            ``num_classes`` / ``emb_size`` (see ``Args``).
        strategy: holder of the ``*_strategy`` tuples passed to ``shard``.
    """
    def __init__(self, args, strategy):
        super(SemiAutoOneHotNet, self).__init__()
        self.a = args.a
        self.b = args.b
        self.c = args.c
        self.d = args.d
        self.e = args.e
        # Casts. NOTE(review): cast1 is never used by construct.
        self.cast = P.Cast()
        self.cast.shard(strategy=strategy.twod_strategy)
        self.cast1 = P.Cast()
        self.cast1.shard(strategy=strategy.twod_strategy)
        self.cast2 = P.Cast()
        self.cast2.shard(strategy=strategy.twod_strategy)
        self.cast3 = P.Cast()
        self.cast3.shard(strategy=strategy.scalar_strategy)
        self.cast4 = P.Cast()
        self.cast4.shard(strategy=strategy.scalar_strategy)
        # float32 scalar operands for construct.
        # NOTE(review): m_const_zero and a_const_one are never used by construct.
        self.a_const = Tensor(self.a, dtype=mstype.float32)
        self.b_const = Tensor(self.b, dtype=mstype.float32)
        self.c_const = Tensor(self.c, dtype=mstype.float32)
        self.d_const = Tensor(self.d, dtype=mstype.float32)
        self.e_const = Tensor(self.e, dtype=mstype.float32)
        self.m_const_zero = Tensor(0, dtype=mstype.float32)
        self.a_const_one = Tensor(1, dtype=mstype.float32)
        self.onehot = P.OneHot()
        self.onehot.shard(strategy=strategy.onehot_strategy)
        self.exp = P.Exp()
        self.exp.shard(strategy=strategy.twod_strategy)
        self.exp2 = P.Exp()
        self.exp2.shard(strategy=strategy.twod_strategy)
        self.exp3 = P.Exp()
        self.exp3.shard(strategy=strategy.twod_strategy)
        # Tensor/scalar binary ops (mul_const2 is Add, mul_const3/4 are Sub).
        self.mul_const = P.Mul()
        self.mul_const.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const2 = P.Add()
        self.mul_const2.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const3 = P.Sub()
        self.mul_const3.shard(strategy=strategy.twod_scalar_strategy)
        self.mul_const4 = P.Sub()
        self.mul_const4.shard(strategy=strategy.scalar_twod_strategy)
        self.mul_const5 = P.Mul()
        self.mul_const5.shard(strategy=strategy.twod_scalar_strategy)
        # Element-wise binary ops (mul3/mul9 are Add, mul4 is Sub, mul5/mul8 are RealDiv).
        self.mul = P.Mul()
        self.mul.shard(strategy=strategy.twod_twod_strategy)
        self.mul2 = P.Mul()
        self.mul2.shard(strategy=strategy.twod_twod_strategy)
        self.mul3 = P.Add()
        self.mul3.shard(strategy=strategy.twod_twod_strategy)
        self.mul4 = P.Sub()
        self.mul4.shard(strategy=strategy.twod_twodbc_strategy)
        self.mul5 = P.RealDiv()
        self.mul5.shard(strategy=strategy.twod_twodbc_strategy)
        self.mul6 = P.Mul()
        self.mul6.shard(strategy=strategy.twod_twod_strategy)
        self.mul7 = P.Mul()
        self.mul7.shard(strategy=strategy.twod_scalar_strategy)
        self.mul8 = P.RealDiv()
        self.mul8.shard(strategy=strategy.scalar_scalar_strategy)
        self.mul9 = P.Add()
        self.mul9.shard(strategy=strategy.twod_scalar_strategy)
        # reduce_max keeps the reduced axis; the reduce_sum variants drop it.
        self.reduce_max = P.ReduceMax(keep_dims=True)
        self.reduce_max.shard(strategy=strategy.twod_strategy)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reduce_sum.shard(strategy=strategy.twod_strategy)
        self.reduce_sum_2 = P.ReduceSum(keep_dims=False)
        self.reduce_sum_2.shard(strategy=strategy.twod_strategy)
        self.reduce_sum_3 = P.ReduceSum(keep_dims=False)
        self.reduce_sum_3.shard(strategy=strategy.oned_strategy)
        self.reshape = P.Reshape()
        self.log = P.Log()
        self.log.shard(strategy=strategy.twod_strategy)
        # OneHot on/off values.
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.normalize = P.L2Normalize(axis=1)
        self.normalize.shard(strategy=strategy.twod_strategy_m)
        self.normalize2 = P.L2Normalize(axis=1)
        self.normalize2.shard(strategy=strategy.twod_strategy_m)
        # transpose_b=True: fc(x, w) computes x @ w^T.
        self.fc = P.MatMul(transpose_b=True)
        self.fc.shard(strategy=strategy.twodbc_twod_strategy)
        # Zero-initialised weight of shape [num_classes, emb_size].
        weight_shape = [args.num_classes, args.emb_size]
        weight_np = np.zeros(weight_shape, np.float32)
        self.weight = Parameter(Tensor(weight_np), name='model_parallel_weight')

    def construct(self, input_, label):
        """Return the scalar loss for features ``input_`` and class indices ``label``.

        Assumes ``input_`` is (batch, emb_size) and ``label`` is (batch,) of
        int32 class ids -- TODO confirm (the tests pass (16, 512) and (16,)).
        """
        # Both operands L2-normalised along axis 1, then input_n @ w^T.
        input_n = self.normalize(input_)
        w = self.normalize2(self.weight)
        fc_o = self.fc(input_n, w)
        fc_o_shape = F.shape(fc_o)
        # One-hot (1.0 / 0.0) with depth = number of logit columns.
        one_hot_float = self.onehot(label, fc_o_shape[1], self.on_value, self.off_value)
        local_label = self.cast(one_hot_float, mstype.int32)
        # exp(b + a * exp(fc_o)) - c
        exp_o = self.exp(fc_o)
        mul_const_o = self.mul_const(self.a_const, exp_o)
        mul_const2_o = self.mul_const2(self.b_const, mul_const_o)
        exp2_o = self.exp2(mul_const2_o)
        mul_const3_o = self.mul_const3(exp2_o, self.c_const)
        # 1 - one_hot (mask of non-label columns).
        mul_const4_o = self.mul_const4(F.scalar_to_array(1), local_label)
        # NOTE(review): mul6 is a Mul whose operands are masked by one_hot and by
        # (1 - one_hot) respectively, so mul6_o is identically zero. If blending
        # label / non-label logits was intended this should likely be an Add -- confirm.
        mul6_o = self.mul6(self.mul(mul_const3_o, one_hot_float),
                           self.mul2(fc_o, self.cast2(mul_const4_o, mstype.float32)))
        # Scale by d, then numerically stable softmax over the last axis:
        # subtract the row max, exp, divide by the row sum.
        mul_const5_o = self.mul_const5(mul6_o, self.d_const)
        max_o = self.reduce_max(mul_const5_o, -1)
        mul4_o = self.mul4(mul_const5_o, max_o)
        exp3_o = self.exp3(mul4_o)
        sum_o = self.reduce_sum(exp3_o, -1)
        reshape_o = self.reshape(sum_o, (F.shape(sum_o)[0], 1))
        mul5_o = self.mul5(exp3_o, reshape_o)
        # log(softmax + e); mul3 is an Add, so one_hot_float is added here.
        log_o = self.log(self.mul9(mul5_o, self.e_const))
        mul3_o = self.mul3(log_o, one_hot_float)
        # Negate, sum over the last axis, sum the remaining axis, divide by the
        # number of rows (F.shape(mul_const5_o)[0]).
        mul7_o = self.mul7(mul3_o, self.cast3(F.scalar_to_array(-1), mstype.float32))
        sum2_o = self.reduce_sum_2(mul7_o, -1)
        loss = self.mul8(self.reduce_sum_3(sum2_o, -1),
                         self.cast4(F.scalar_to_array(F.shape(mul_const5_o)[0]), mstype.float32))
        return loss
  177. class Dataset(MindData):
  178. def __init__(self, predict, label, length=3, input_num=2):
  179. super(Dataset, self).__init__(size=length)
  180. self.predict = predict
  181. self.label = label
  182. self.index = 0
  183. self.length = length
  184. self.input_num = input_num
  185. def __iter__(self):
  186. return self
  187. def __next__(self):
  188. if self.index >= self.length:
  189. raise StopIteration
  190. self.index += 1
  191. if self.input_num == 2:
  192. return (self.predict, self.label)
  193. return (self.predict,)
  194. def reset(self):
  195. self.index = 0
  196. class NetWithLoss(nn.Cell):
  197. def __init__(self, network):
  198. super(NetWithLoss, self).__init__()
  199. self.loss = VirtualLoss()
  200. self.network = network
  201. def construct(self, x, b):
  202. predict = self.network(x, b)
  203. return self.loss(predict)
  204. class GradWrap(nn.Cell):
  205. def __init__(self, network):
  206. super(GradWrap, self).__init__()
  207. self.network = network
  208. def construct(self, x, b):
  209. return grad_all(self.network)(x, b)
  210. def bn_with_initialize(out_channels):
  211. bn = nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)
  212. return bn
  213. def fc_with_initialize(input_channels, out_channels):
  214. return nn.Dense(input_channels, out_channels)
  215. class BNReshapeDenseBNNet(nn.Cell):
  216. def __init__(self):
  217. super(BNReshapeDenseBNNet, self).__init__()
  218. self.batch_norm = bn_with_initialize(2)
  219. self.reshape = P.Reshape()
  220. self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
  221. self.fc = fc_with_initialize(2 * 32 * 32, 512)
  222. self.loss = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
  223. def construct(self, x, label):
  224. x = self.batch_norm(x)
  225. x = self.reshape(x, (16, 2 * 32 * 32))
  226. x = self.fc(x)
  227. x = self.batch_norm2(x)
  228. loss = self.loss(x, label)
  229. return loss
  230. def test_bn_reshape_dense_bn_train_loss():
  231. batch_size = 16
  232. context.set_auto_parallel_context(device_num=device_num, global_rank=0)
  233. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  234. input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)
  235. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  236. net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
  237. net.set_auto_parallel()
  238. net.set_train()
  239. _executor.compile(net, input_, label)
  240. def test_semi_one_hot_net_batch():
  241. batch_size = 16
  242. context.set_auto_parallel_context(device_num=device_num, global_rank=0)
  243. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  244. input_ = Tensor(np.ones([batch_size * 1, 512]).astype(np.float32) * 0.01)
  245. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  246. net = SemiAutoOneHotNet(args=Args(), strategy=StrategyBatch())
  247. net = GradWrap(NetWithLoss(net))
  248. net.set_auto_parallel()
  249. net.set_train()
  250. _executor.compile(net, input_, label)
  251. def test_semi_one_hot_net_model():
  252. batch_size = 16
  253. learning_rate = 0.1
  254. momentum = 0.9
  255. epoch_size = 2
  256. predict = Tensor(np.ones([batch_size, 512]), dtype=ms.float32)
  257. label = Tensor(np.ones([batch_size]), dtype=ms.int32)
  258. dataset = Dataset(predict, label, 2, input_num=2)
  259. context.reset_auto_parallel_context()
  260. context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, device_num=16)
  261. context.set_context(mode=context.GRAPH_MODE)
  262. net = SemiAutoOneHotNet(args=Args(), strategy=StrategyModel())
  263. opt = Momentum(net.trainable_params(), learning_rate, momentum)
  264. model = Model(net, optimizer=opt)
  265. model.train(epoch_size, dataset, dataset_sink_mode=False)