You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_layernorm.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import copy
  16. import numpy as np
  17. import pytest
  18. import mindspore.context as context
  19. from mindspore import Tensor
  20. import mindspore.nn as nn
  21. from mindspore.ops.operations import _grad_ops as G
  22. import mindspore.ops.operations as P
  23. class LayerNormNet(nn.Cell):
  24. def __init__(self, begin_norm_axis, begin_params_axis):
  25. super(LayerNormNet, self).__init__()
  26. self.layernorm = P.LayerNorm(begin_norm_axis, begin_params_axis)
  27. def construct(self, x, gamma, beta):
  28. return self.layernorm(x, gamma, beta)
  29. class LayerNormGradNet(nn.Cell):
  30. def __init__(self, begin_norm_axis, begin_params_axis):
  31. super(LayerNormGradNet, self).__init__()
  32. self.layernorm_grad = G.LayerNormGrad(begin_norm_axis, begin_params_axis)
  33. def construct(self, dy, x, var, mean, gamma):
  34. return self.layernorm_grad(dy, x, var, mean, gamma)
  35. def get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
  36. context.set_context(enable_graph_kernel=enable_graph_kernel)
  37. net = LayerNormNet(begin_norm_axis, begin_params_axis)
  38. output = net(x, gamma, beta)
  39. return output
  40. def get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, enable_graph_kernel=False):
  41. context.set_context(enable_graph_kernel=enable_graph_kernel)
  42. net = LayerNormGradNet(begin_norm_axis, begin_params_axis)
  43. output = net(x, dy, var, mean, gamma)
  44. return output
  45. def get_rtol_atol(dtype):
  46. if dtype == np.float16:
  47. return 1.e-3, 1.e-3
  48. return 1.e-4, 1.e-4
  49. def compare_result(expect, output, dtype):
  50. rtol, atol = get_rtol_atol(dtype)
  51. if isinstance(expect, (list, tuple)):
  52. assert isinstance(output, (list, tuple)) and len(expect) == len(output)
  53. expect_list = list(expect)
  54. output_list = list(output)
  55. for e, o in zip(expect_list, output_list):
  56. assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True)
  57. else:
  58. assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True)
  59. def test_layernorm(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
  60. begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
  61. begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
  62. assert 0 <= begin_norm_axis < len(shape)
  63. assert 0 <= begin_params_axis < len(shape)
  64. normalized_shape = shape[begin_params_axis:]
  65. np.random.seed(0)
  66. # input tensors
  67. x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
  68. gamma = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
  69. beta = Tensor(np.random.normal(0, 1, normalized_shape).astype(dtype))
  70. expect = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, False)
  71. output = get_layernorm_output(x, gamma, beta, begin_norm_axis, begin_params_axis, True)
  72. compare_result(expect, output, dtype)
  73. def test_layernorm_grad(shape, dtype, begin_norm_axis=-1, begin_params_axis=-1):
  74. begin_norm_axis = begin_norm_axis if begin_norm_axis >= 0 else begin_norm_axis + len(shape)
  75. begin_params_axis = begin_params_axis if begin_params_axis >= 0 else begin_params_axis + len(shape)
  76. assert 0 <= begin_norm_axis < len(shape)
  77. assert 0 <= begin_params_axis < len(shape)
  78. norm_axis = [i for i in range(begin_norm_axis, len(shape))]
  79. norm_shape = copy.deepcopy(shape)
  80. for i, _ in enumerate(norm_shape):
  81. if i in norm_axis:
  82. norm_shape[i] = 1
  83. params_shape = shape[begin_params_axis:]
  84. np.random.seed(0)
  85. # input tensors
  86. dy = Tensor(np.random.normal(0, 1, shape).astype(dtype))
  87. x = Tensor(np.random.normal(0, 1, shape).astype(dtype))
  88. var = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
  89. mean = Tensor(np.random.normal(0, 1, norm_shape).astype(dtype))
  90. gamma = Tensor(np.random.normal(0, 1, params_shape).astype(dtype))
  91. expect = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, False)
  92. output = get_layernorm_grad_output(x, dy, var, mean, gamma, begin_norm_axis, begin_params_axis, True)
  93. compare_result(expect, output, dtype)
  94. @pytest.mark.level0
  95. @pytest.mark.platform_x86_gpu_training
  96. @pytest.mark.env_onecard
  97. def test_layernorm_gpu():
  98. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  99. test_layernorm([4, 32, 32], np.float32, -1, -1)
  100. @pytest.mark.level0
  101. @pytest.mark.platform_arm_ascend_training
  102. @pytest.mark.platform_x86_ascend_training
  103. @pytest.mark.env_onecard
  104. def test_layernorm_ascend():
  105. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  106. test_layernorm([4, 32, 32], np.float16, -1, -1)
  107. test_layernorm([4, 32, 32], np.float32, -1, -1)
  108. @pytest.mark.level0
  109. @pytest.mark.platform_x86_gpu_training
  110. @pytest.mark.env_onecard
  111. def test_layernorm_grad_gpu():
  112. context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
  113. test_layernorm_grad([4, 32, 32], np.float32, -1, -1)
  114. @pytest.mark.level0
  115. @pytest.mark.platform_arm_ascend_training
  116. @pytest.mark.platform_x86_ascend_training
  117. @pytest.mark.env_onecard
  118. def test_layernorm_grad_ascend():
  119. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  120. test_layernorm_grad([2, 16, 32], np.float16, -1, -1)
  121. test_layernorm_grad([4, 32, 32], np.float32, -1, -1)