You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_fix_precision.py 6.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import numpy as np
  2. import mindspore.context as context
  3. from mindspore import Tensor, Parameter
  4. from mindspore.nn import Cell, Composite
  5. from mindspore.ops import operations as P
  6. import mindspore.ops.composite as C
  7. import logging
  8. from mindspore._checkparam import ParamValidator as validator
  9. from mindspore.ops import Primitive
  10. from mindspore._checkparam import Rel
  11. from mindspore.common.initializer import initializer
  12. from mindspore.nn.composite_ops.composite_ops import InplaceAssign
  13. log = logging.getLogger("ME")
  14. log.setLevel(level=logging.DEBUG)
  15. context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target="Ascend")
  16. class DtypeTest(Composite):
  17. def __init__(self, fix_precision = "float16"):
  18. super(DtypeTest, self).__init__()
  19. self.sum = P.ReduceSum()
  20. self.pow = P.Pow()
  21. self.sum.add_prim_attr("fix_precision", fix_precision)
  22. self.pow.add_prim_attr("fix_precision", fix_precision)
  23. def construct(self, x):
  24. res = self.sum(x, (0,))
  25. res = self.pow(res, 2.0)
  26. return res
  27. class Net(Cell):
  28. def __init__(self, fix_precision = "float16"):
  29. super(Net, self).__init__()
  30. self.net = DtypeTest(fix_precision)
  31. def construct(self, x):
  32. return P.Neg()(self.net(x))
class FusedBatchNorm(Composite):
    """Batch normalization assembled from elementary primitives.

    Normalizes ``x`` with the mean and biased variance computed over axes
    (3, 2, 0), which are N, H and W per the NCHW comment in ``construct``.
    It then applies the per-channel ``scale`` and ``b``. Updated moving
    statistics are also computed and passed, together with the result,
    through ``InplaceAssign``.

    Args:
        mode: validated to be 0 or 1 and stored as ``self.mode``. It is not
            read anywhere in ``construct``.
        epsilon: added to the variance before ``Rsqrt``. Validated with
            ``check_number_range(..., 0, 1, Rel.INC_RIGHT)``, presumably
            the range (0, 1].
        momentum: weight given to the *previous* moving statistic in the
            moving-average update. Validated with
            ``check_number_range(..., 0, 1, Rel.INC_BOTH)``, presumably the
            range [0, 1].
        fix_precision: stored as the "fix_precision" attribute on every
            ``Primitive`` found in ``vars(self)`` at the end of ``__init__``.
    """
    def __init__(self,
                 mode=0,
                 epsilon=1e-5,
                 momentum=0.1,
                 fix_precision = "float16"):
        super(FusedBatchNorm, self).__init__()
        self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN)
        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT)
        self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH)
        # One primitive instance per use site in construct.
        self.reduce1 = P.ReduceSum()
        self.reduce2 = P.ReduceSum()
        self.reshape1 = P.Reshape()
        self.reshape2 = P.Reshape()
        self.reshape3 = P.Reshape()
        self.reshape4 = P.Reshape()
        self.pow1 = P.Pow()
        self.pow2 = P.Pow()
        self.mul1 = P.Mul()
        self.mul2 = P.Mul()
        self.mul3 = P.Mul()
        self.mul4 = P.Mul()
        self.mul5 = P.Mul()
        self.mul6 = P.Mul()
        self.neg = P.Neg()
        self.sub1 = P.Sub()
        # NOTE(review): sub2 is never used in construct.
        self.sub2 = P.Sub()
        self.rsqrt = P.Rsqrt()
        self.add1 = P.TensorAdd()
        self.add2 = P.TensorAdd()
        self.add3 = P.TensorAdd()
        self.add4 = P.TensorAdd()
        # NOTE(review): these two are only referenced by commented-out lines in
        # construct. The live code creates fresh InplaceAssign() instances instead.
        self.inplaceAssign1 = InplaceAssign()
        self.inplaceAssign2 = InplaceAssign()
        # Tag every Primitive attribute with the requested precision.
        # NOTE(review): primitives created inline in construct (Shape, Fill,
        # DType and the two InplaceAssign() calls) are not in vars(self), so
        # they do not receive this attribute -- confirm that is intended.
        for _, value in vars(self).items():
            if isinstance(value, Primitive):
                value.add_prim_attr("fix_precision", fix_precision)
    def construct(self, x, scale, b, moving_mean, moving_variance):
        # Returns scale * (x - mean) * rsqrt(var + eps) + b using batch statistics.
        # x must be rank >= 4 (shape[3] is read below). scale and b are reshaped
        # to (1, shape[1], 1, 1), so they presumably hold shape[1] elements -- confirm.
        axes = (3, 2, 0) # NCHW
        # axes = (2, 1, 0) # NHWC
        shape = P.Shape()(x)
        # Number of elements reduced per channel = product of the reduced dims.
        value_num = 1
        for axis in axes:
            value_num *= shape[axis]
        # value_num = 4.0 * 4.0 * 16.0 # NCHW
        # 1 / value_num as a shape-(1,) tensor with x's dtype.
        avg_num = 1.0 / P.Fill()(P.DType()(x), (1, ), value_num)
        data_square = self.pow1(x, 2.0)
        # cal mean: E[x] and E[x^2] over the reduced axes
        data_sum =self.reduce1(x, axes)
        data_square_sum = self.reduce2(data_square, axes)
        data_mean = self.mul1(data_sum, avg_num)
        #data_mean = data_sum * avg_num
        data_square_mean = self.mul2(data_square_sum, avg_num)
        data_mean_square = self.pow2(data_mean, 2.0)
        # cal variance: biased estimate E[x^2] - E[x]^2
        data_variance = self.sub1(data_square_mean, data_mean_square)
        # Returns hat_z * momentum + z * (1 - momentum). momentum weights the
        # previous moving value hat_z, not the new batch statistic z.
        def update_by_moving_average(hat_z, z, momentum):
            run = self.mul5(hat_z, momentum)
            now = self.mul6(z, 1.0 - momentum)
            return self.add4(run, now)
        _moving_mean = update_by_moving_average(moving_mean, data_mean, self.momentum)
        _moving_variance = update_by_moving_average(moving_variance,
                                                    data_variance, self.momentum)
        # var + eps
        veps_no_bc = self.add1(data_variance, self.epsilon)
        # rsqrt(var + eps)
        rsveps_no_bc = self.rsqrt(veps_no_bc)
        # -mean
        mean2_no_bc = self.neg(data_mean)
        # Per-channel statistics are reshaped to (1, C, 1, 1) to broadcast over x.
        mid_shape = (1, shape[1], 1, 1)
        # scale * (x + mean) / sqrt(var + eps) + b, where "mean" here is the
        # negated mean, i.e. (x - mean)
        # NOTE(review): the trailing "broadcast result error" note below is the
        # original author's; its exact meaning is unclear -- confirm.
        dmean = self.add2(x, self.reshape1(mean2_no_bc, mid_shape)) # broadcast result error
        dmsve = self.mul3(dmean, self.reshape2(rsveps_no_bc, mid_shape))
        dmsveg = self.mul4(dmsve, self.reshape3(scale, mid_shape))
        outs = self.add3(dmsveg, self.reshape4(b, mid_shape))
        #outs = self.inplaceAssign1(moving_mean, _moving_mean, outs)
        #outs = self.inplaceAssign2(moving_variance, _moving_variance, outs)
        # NOTE(review): InplaceAssign is project code. Presumably it writes its
        # second argument into the first and returns the third (outs), updating
        # the moving statistics while passing the normalized output through -- confirm.
        outs = InplaceAssign()(moving_mean, _moving_mean, outs)
        outs = InplaceAssign()(moving_variance, _moving_variance, outs)
        return outs
  113. class Net_BN(Cell):
  114. def __init__(self, fix_precision = "float16"):
  115. super(Net_BN, self).__init__()
  116. self.bn = FusedBatchNorm(fix_precision=fix_precision)
  117. self.gamma = Parameter(initializer('ones', [4]), name='gamma')
  118. self.beta = Parameter(initializer('zeros', [4]), name='beta')
  119. self.mean = Parameter(initializer('ones', [4]), name='mean')
  120. self.variance = Parameter(initializer('zeros', [4]), name='variance')
  121. def construct(self, x):
  122. return self.bn(x, self.gamma, self.beta, self.mean, self.variance)
  123. def test_dtype_test_float16():
  124. x = np.array([1.0, 2.0, 3.0]).astype(np.float16)
  125. net = Net()
  126. result = net(Tensor(x))
  127. print("=======================================")
  128. print("x: {}".format(x))
  129. print("result: {}".format(result))
  130. print("=======================================")
  131. def test_dtype_test_float32():
  132. x = np.array([1.0, 2.0, 3.0]).astype(np.float16)
  133. net = Net(fix_precision = "float32")
  134. result = net(Tensor(x))
  135. print("=======================================")
  136. print("x: {}".format(x))
  137. print("result: {}".format(result))
  138. print("=======================================")
  139. def test_composite_bn():
  140. x = np.random.normal(0, 1, [16, 4, 4, 4]).astype(np.float32)
  141. net = Net_BN(fix_precision = "float16")
  142. #net1 = Net1()
  143. output = net(Tensor(x))
  144. # output1 = net1(Tensor(x))
  145. print("=======================================")
  146. print("x:\n{}".format(x))
  147. print("output:\n{}".format(output))
  148. print("=======================================")
  149. # print("x:\n{}".format(x))
  150. # print("output1:\n{}".format(output1))
  151. # print("=======================================")
  152. #test_dtype_test_float16()
  153. #test_dtype_test_float32()
  154. test_composite_bn()