You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_initializer.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_initializer """
  16. import math
  17. from functools import reduce
  18. import numpy as np
  19. from scipy import stats
  20. import pytest as py
  21. import mindspore.common.initializer as init
  22. import mindspore as ms
  23. from mindspore import context
  24. from mindspore.nn import Conv2d
  25. from ..ut_filter import non_graph_engine
  26. # pylint: disable=W0212
  27. # W0212: protected-access
  class InitTwo(init.Initializer):
      """Custom initializer used by the tests: sets the array to the constant 2."""

      def _initialize(self, arr):
          """Hand ``arr`` and the constant 2 to the initializer module's assignment helper."""
          fill_value = 2
          init._assignment(arr, fill_value)
  def _check_value(tensor, value_min, value_max):
      """Verify that every element of ``tensor`` lies in ``[value_min, value_max]``.

      Args:
          tensor: Object exposing ``asnumpy()``; the returned array is flattened
              and scanned element by element.
          value_min: Inclusive lower bound.
          value_max: Inclusive upper bound.

      Raises:
          TypeError: For the first element found outside the closed interval.
              The exception type is kept as TypeError for backward compatibility.
      """
      for ele in tensor.asnumpy().flatten():
          if value_min <= ele <= value_max:
              continue
          # Format with '%s' rather than '%d': '%d' truncates floats (0.5 would be
          # reported as 0) and cannot format NaN at all, which hides the real
          # offending value in the failure message.
          raise TypeError('value_min = %s, ele = %s, value_max = %s'
                          % (value_min, ele, value_max))
  def _check_uniform(tensor, boundary_a, boundary_b):
      """Run a Kolmogorov-Smirnov test of the tensor's values against a uniform law.

      The values from ``tensor.asnumpy()`` are flattened and compared with the
      uniform distribution starting at ``boundary_a`` with width
      ``boundary_b - boundary_a``. The p-value is printed.

      Returns:
          Whether the p-value is greater than 0.0001.
      """
      flat_samples = tensor.asnumpy().reshape((-1))
      width = boundary_b - boundary_a
      # The distribution arguments are passed as (start, width).
      ks_result = stats.kstest(flat_samples, 'uniform', (boundary_a, width))
      p = ks_result[1]
      print("p-value is %f"%p)
      return p > 0.0001
  def test_init_Initializer():
      """A custom Initializer subclass yields a (2, 2) tensor whose elements are all 2."""
      expected_shape = (2, 2)
      filled = init.initializer(InitTwo(), [2, 2], ms.int32)
      assert filled.shape() == expected_shape
      _check_value(filled, 2, 2)
  def test_init_tensor():
      """Passing an existing Tensor as the init argument gives back the requested shape."""
      expected_shape = (1, 2, 3)
      source = ms.Tensor(np.zeros([1, 2, 3]))
      result = init.initializer(source, [1, 2, 3], ms.float32)
      assert result.shape() == expected_shape
  def test_init_zero_default_dtype():
      """Zero() without an explicit dtype produces a float32 tensor of zeros."""
      zeros = init.initializer(init.Zero(), [2, 2])
      assert zeros.dtype() == ms.float32
      _check_value(zeros, 0, 0)
  def test_init_zero():
      """Zero() with an explicit float32 dtype produces only zeros."""
      zeros = init.initializer(init.Zero(), [2, 2], ms.float32)
      _check_value(zeros, 0, 0)
  def test_init_zero_alias_default_dtype():
      """The 'zeros' alias without an explicit dtype produces a float32 tensor of zeros."""
      zeros = init.initializer('zeros', [1, 2])
      assert zeros.dtype() == ms.float32
      _check_value(zeros, 0, 0)
  def test_init_zero_alias():
      """The 'zeros' alias with an explicit float32 dtype produces only zeros."""
      zeros = init.initializer('zeros', [1, 2], ms.float32)
      _check_value(zeros, 0, 0)
  def test_init_one():
      """One() produces a tensor whose elements are all 1."""
      ones = init.initializer(init.One(), [2, 2], ms.float32)
      _check_value(ones, 1, 1)
  def test_init_one_alias():
      """The 'ones' alias produces a tensor whose elements are all 1."""
      ones = init.initializer('ones', [1, 2], ms.float32)
      _check_value(ones, 1, 1)
  def test_init_uniform():
      """Uniform(scale=10) values all fall inside [-10, 10]."""
      scale = 10
      uniform_init = init.Uniform(scale=scale)
      sampled = init.initializer(uniform_init, [5, 4], ms.float32)
      lower, upper = -scale, scale
      _check_value(sampled, lower, upper)
  def test_init_uniform_alias():
      """Values produced through the 'uniform' alias all fall inside [-100, 100]."""
      # The alias is used without arguments; 100 is only the bound this test checks.
      scale = 100
      sampled = init.initializer('uniform', [5, 4], ms.float32)
      lower, upper = -scale, scale
      _check_value(sampled, lower, upper)
  def test_init_normal():
      """Normal() yields an ms.Tensor instance."""
      created = init.initializer(init.Normal(), [5, 4], ms.float32)
      is_tensor = isinstance(created, ms.Tensor)
      assert is_tensor, 'tensor init failed!'
  def test_init_truncated_normal():
      """TruncatedNormal() yields an ms.Tensor instance."""
      created = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32)
      is_tensor = isinstance(created, ms.Tensor)
      assert is_tensor, 'tensor init failed!'
  def test_init_normal_alias():
      """The 'normal' alias yields an ms.Tensor instance."""
      created = init.initializer('normal', [5, 4], ms.float32)
      is_tensor = isinstance(created, ms.Tensor)
      assert is_tensor, 'tensor init failed!'
  def test_init_truncatednormal_alias():
      """The 'truncatednormal' alias yields an ms.Tensor instance."""
      created = init.initializer('truncatednormal', [5, 4], ms.float32)
      is_tensor = isinstance(created, ms.Tensor)
      assert is_tensor, 'tensor init failed!'
  def test_init_abnormal():
      """A list given as the init argument is rejected with TypeError."""
      with py.raises(TypeError):
          unsupported_init = ['']
          init.initializer(unsupported_init, [5, 4], ms.float32)
  def test_init_xavier_uniform():
      """ test_init_xavier_uniform

      Builds XavierUniform tensors (explicit gain, no gain argument, and the
      'xavier_uniform' alias; 2-D and 4-D shapes) and checks with a KS test that
      each one is consistent with a uniform distribution on
      [-boundary, boundary], where boundary = gain * sqrt(2 / (n_in + n_out)) * sqrt(3).
      """
      gain = 1.2
      # (tensor, expected gain) pairs. A list is used instead of a dict keyed by
      # tensors so the test does not depend on Tensor hashing/equality and can
      # never silently drop a case. The test expects a gain of 1 wherever no
      # gain argument is passed.
      cases = [
          (init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32), gain),
          (init.initializer(init.XavierUniform(), [20, 22], ms.float32), 1),
          (init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32), gain),
          (init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32), 1),
          (init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32), 1),
          (init.initializer('xavier_uniform', [20, 22], ms.float32), 1),
      ]
      for tensor, gain_value in cases:
          shape = tensor.asnumpy().shape
          # Product of the dims after the first two; the initial value 1 covers
          # shapes that have no such dims.
          s = reduce(lambda x, y: x * y, shape[2:], 1)
          n_in = shape[1] * s
          n_out = shape[0] * s
          std = gain_value * math.sqrt(2 / (n_in + n_out))
          boundary = std * math.sqrt(3)
          assert _check_uniform(tensor, -boundary, boundary)
  def test_init_xavier_uniform_error():
      """XavierUniform on the one-dimensional shape [6] raises ValueError."""
      with py.raises(ValueError):
          one_dim_shape = [6]
          init.initializer(init.XavierUniform(), one_dim_shape, ms.float32)
  def test_init_he_uniform():
      """ test_init_he_uniform

      Builds HeUniform tensors (class and 'he_uniform' alias; 2-D and 4-D shapes)
      and checks with a KS test that each one is consistent with a uniform
      distribution on [-limit, limit], where limit = sqrt(2 / n_in) * sqrt(3).
      """
      candidates = [
          init.initializer(init.HeUniform(), [20, 22], ms.float32),
          init.initializer(init.HeUniform(), [20, 22, 5, 5], ms.float32),
          init.initializer('he_uniform', [20, 22, 5, 5], ms.float32),
          init.initializer('he_uniform', [20, 22], ms.float32),
      ]
      for candidate in candidates:
          dims = candidate.asnumpy().shape
          # Product of the dims after the first two; the initial value 1 covers
          # shapes that have no such dims.
          trailing = reduce(lambda acc, dim: acc * dim, dims[2:], 1)
          fan_in = dims[1] * trailing
          limit = math.sqrt(2 / fan_in) * math.sqrt(3)
          assert _check_uniform(candidate, -limit, limit)
  def test_init_he_uniform_error():
      """HeUniform on the one-dimensional shape [6] raises ValueError."""
      with py.raises(ValueError):
          one_dim_shape = [6]
          init.initializer(init.HeUniform(), one_dim_shape, ms.float32)
  def test_conv2d_abnormal_kernel_negative():
      """Building a Conv2d model with kernel_size=-7 raises ValueError."""
      weights = np.random.randn(64, 3, 7, 7).astype(np.float32)
      with py.raises(ValueError):
          weight_tensor = ms.Tensor(weights)
          conv = Conv2d(in_channels=3, out_channels=64, kernel_size=-7, stride=3,
                        padding=0, weight_init=weight_tensor)
          ms.Model(conv)
  @non_graph_engine
  def test_conv2d_abnormal_kernel_normal():
      """Run predict on a Conv2d model whose weights come from a random float32 Tensor."""
      weights = np.random.randn(64, 3, 7, 7).astype(np.float32)
      batch = np.random.randn(32, 3, 224, 112).astype(np.float32)
      context.set_context(mode=context.GRAPH_MODE)
      conv = Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
                    padding=0, weight_init=ms.Tensor(weights))
      net = ms.Model(conv)
      net.predict(ms.Tensor(batch))
  @non_graph_engine
  def test_conv2d_abnormal_kernel_truncated_normal():
      """Run predict on a Conv2d model initialised through the "truncatednormal" alias."""
      # The model input itself is a TruncatedNormal-initialised tensor of shape [64, 3, 7, 7].
      batch = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32)
      context.set_context(mode=context.GRAPH_MODE)
      conv = Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
                    padding=0, weight_init="truncatednormal")
      net = ms.Model(conv)
      net.predict(batch)