
test_layernorm.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.nn as nn
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.operations as P


class Net(Cell):
    def __init__(self):
        super(Net, self).__init__()
        # Forward LayerNorm with begin_norm_axis=1 and begin_params_axis=1.
        self.layernorm = P.LayerNorm(1, 1)

    def construct(self, x, y, z):
        # Returns (output, mean, variance).
        return self.layernorm(x, y, z)
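
# A minimal usage sketch (not from the original file): the forward cell above takes
# (x, gamma, beta) and returns the 3-tuple (output, mean, variance). The shapes here
# are illustrative assumptions only:
#
#   net = Net()
#   x = Tensor(np.ones([2, 3, 4, 3], np.float32))
#   gamma = Tensor(np.ones([3, 4, 3], np.float32))
#   beta = Tensor(np.zeros([3, 4, 3], np.float32))
#   output, mean, variance = net(x, gamma, beta)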


class LayerNormGradNet(nn.Cell):
    def __init__(self, begin_norm_axis, begin_params_axis):
        super(LayerNormGradNet, self).__init__()
        self.norm = G.LayerNormGrad(begin_norm_axis, begin_params_axis)

    def construct(self, dy, x, var, mean, gamma):
        # Returns (dx, dgamma, dbeta).
        return self.norm(dy, x, var, mean, gamma)


def LayerNormGradReference(x, dy, gamma, epsilon, begin_norm_axis, begin_params_axis):
    """NumPy reference implementation of the LayerNorm backward pass."""
    begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(x.shape)
    begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(x.shape)

    norm_axis = [i for i in range(begin_norm_axis, len(x.shape))]
    param_axis = [i for i in range(0, begin_params_axis)]
    num = 1
    for i in range(begin_norm_axis, len(x.shape)):
        num *= x.shape[i]

    mean = np.mean(x, axis=tuple(norm_axis), keepdims=True)
    var = np.var(x, axis=tuple(norm_axis), keepdims=True)

    gamma = gamma.reshape((*((1,) * begin_params_axis), *x.shape[begin_params_axis:]))
    dg = np.sum(dy * np.power(var + epsilon, -0.5) * (x - mean), axis=tuple(param_axis), keepdims=True)
    db = np.sum(dy, axis=tuple(param_axis), keepdims=True)

    sum1 = np.sum((-0.5) * dy * gamma * (x - mean) * np.power(var + epsilon, -1.5), axis=tuple(norm_axis),
                  keepdims=True)
    sum2 = np.sum(dy * gamma, axis=tuple(norm_axis), keepdims=True)
    sum3 = np.sum(-2.0 * (x - mean), axis=tuple(norm_axis), keepdims=True)

    dx1 = dy * gamma * np.power(var + epsilon, -0.5)
    dx2 = sum1 * 2.0 / num * (x - mean)
    dx3 = ((-1.0) * np.power(var + epsilon, -0.5) * sum2 + (1.0 / num) * sum1 * sum3) * (1.0 / num)
    dx = dx1 + dx2 + dx3
    return dx, dg, db, mean, var
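

# A small standalone sketch (not part of the original tests) showing how the NumPy
# reference above can be exercised on its own; the shapes, epsilon and axis values
# below are illustrative assumptions. The leading underscore keeps it out of pytest
# collection.
def _layernorm_grad_reference_example():
    x = np.random.randn(2, 3).astype(np.float32)
    dy = np.ones_like(x)
    gamma = np.ones((3,), dtype=np.float32)
    # begin_norm_axis=1, begin_params_axis=1: normalize over the last axis,
    # with per-feature gamma of shape (3,).
    dx, dg, db, mean, var = LayerNormGradReference(x, dy, gamma, 1e-7, 1, 1)
    return dx.shape, dg.shape, db.shape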


def test_basic():
    input_x = np.random.normal(0, 1, [2, 3, 4, 3]).astype(np.float32)
    gamma = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
    beta = np.random.normal(0, 1, [3, 4, 3]).astype(np.float32)
    shape_x = [2, 3, 4, 3]
    begin_norm_axis = 1

    # Expected LayerNorm output with NumPy:
    # out = (x - mean) / sqrt(var + eps) * gamma + beta
    in_rank = len(shape_x)
    if begin_norm_axis < 0:
        norm_axis = begin_norm_axis + in_rank
    else:
        norm_axis = begin_norm_axis
    norm_axes = tuple(range(norm_axis, in_rank))
    mean = np.mean(input_x, axis=norm_axes, keepdims=True)
    mean_b = np.broadcast_to(mean, shape_x)
    diff = input_x - mean_b
    square = np.square(diff)
    smean = np.mean(square, axis=norm_axes, keepdims=True)
    smean_b = np.broadcast_to(smean, shape_x)
    meps = smean_b + 1e-5
    logs = np.log(meps)
    mul = logs * (-0.5)
    rsqrt = np.exp(mul)
    out = diff * rsqrt
    bn = out * gamma + beta
    expect = (bn, mean, smean)

    net = Net()
    net_result = net(Tensor(input_x), Tensor(gamma), Tensor(beta))
    if isinstance(net_result, tuple) and len(net_result) == 3:
        result = (net_result[0].asnumpy(), net_result[1].asnumpy(), net_result[2].asnumpy())
        res0 = np.allclose(expect[0], result[0], rtol=1.e-4, atol=1.e-4, equal_nan=True)
        assert res0
        res1 = np.allclose(expect[1], result[1], rtol=1.e-4, atol=1.e-7, equal_nan=True)
        assert res1
        res2 = np.allclose(expect[2], result[2], rtol=1.e-4, atol=1.e-7, equal_nan=True)
        assert res2
    else:
        assert False


def test_layernormgrad():
    np.random.seed(0)
    begin_norm_axis = 1
    begin_params_axis = 1
    x_np = np.random.randn(4096, 3072).astype(np.float32)
    dy_np = np.random.randn(4096, 3072).astype(np.float32)
    gamma_np = np.random.randn(*x_np.shape[begin_params_axis:]).astype(np.float32)
    epsilon = 1e-11
    dx_np, dg_np, db_np, mean_np, var_np = LayerNormGradReference(x_np, dy_np, gamma_np, epsilon, begin_norm_axis,
                                                                  begin_params_axis)
    dy_ms = Tensor(dy_np)
    x_ms = Tensor(x_np)
    var_ms = Tensor(var_np)
    mean_ms = Tensor(mean_np)
    gamma_ms = Tensor(gamma_np)

    net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
    dx_ms, dg_ms, db_ms = net(x_ms, dy_ms, var_ms, mean_ms, gamma_ms)
    assert np.allclose(dx_ms.asnumpy(), dx_np, rtol=1e-6, atol=1e-6)
    assert np.allclose(dg_ms.asnumpy(), dg_np, rtol=1e-6, atol=1e-3)
    assert np.allclose(db_ms.asnumpy(), db_np, rtol=1e-6, atol=1e-3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_basic_gpu():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
    test_basic()


def test_basic_ascend():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
    test_basic()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_layernormgrad_gpu():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
    test_layernormgrad()


def test_layernormgrad_ascend():
    context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend")
    test_layernormgrad()
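

# The cases above are normally driven through the GPU/Ascend wrappers, which set the
# MindSpore context (GRAPH_MODE with graph kernel enabled) before delegating to
# test_basic()/test_layernormgrad(). Assuming a MindSpore build with graph kernel
# support and the matching device available, an individual case can be run with, e.g.:
#   pytest -v test_layernorm.py::test_basic_gpu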