# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
@File : test_parse.py
@Author:
@Date : 2019-01-23 17:13
@Desc :
"""
import logging
import pytest
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common.api import ms_function, _executor
from mindspore.ops._grad.grad_base import bprop_getters
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.functional import tensor_add
from ...ut_filter import non_graph_engine

# pylint: disable=W0613,W0612
# W0613: unused-argument


@pytest.fixture(name='enable_check_bprop')
def fixture_enable_check_bprop():
    context.set_context(check_bprop=True)
    yield
    context.set_context(check_bprop=False)


grad_all = C.GradOperation(get_all=True)

log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)

context.set_context(mode=context.GRAPH_MODE)


# Test case: use the parse object interface with a default parameter
class Net(nn.Cell):
    """ Net definition """

    def __init__(self, dim):
        super(Net, self).__init__()
        self.softmax1 = nn.Softmax(dim)
        self.softmax2 = nn.Softmax(dim + 1)

    def construct(self, input_data, input1=1+2+3+4):
        return self.softmax1(input_data)


@non_graph_engine
def test_parse_default_parameter_case2():
    """ test_parse_default_parameter_case2 """
    log.debug("begin test_parse_default_parameter_case2")
    net = Net(0)
    npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
    log.debug("input value is: %r", npd)
    input_data = ms.Tensor(npd)
    input_data.set_dtype(ms.float32)

    log.debug("start run")
    output = net(input_data)
    value = output.asnumpy()
    log.debug("output value = %r", value)


# Test case: use variable (star) arguments with the parse object interface
class Net1(nn.Cell):
    """ Net1 definition """

    def __init__(self):
        super(Net1, self).__init__()

    def construct(self, *args):
        x = args[0]
        return x


def test_var_parameter_case2():
    """ test_var_parameter_case2 """
    log.debug("begin test_var_parameter_case2")
    net = Net1()
    npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
    log.debug("input value is: %r", npd)
    input_data = ms.Tensor(npd)
    input_data.set_dtype(ms.float32)

    np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input1 = ms.Tensor(np1)
    np2 = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input2 = ms.Tensor(np2)

    _executor.compile(net, input_data, input1, input2)
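

# Note on the test above: _executor.compile only builds the graph for the given
# inputs, it does not execute the network, so the test checks that parsing and
# compilation of the *args signature succeed rather than any output value.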


# Test case: test the global flag
g_x = Tensor(np.ones([3, 3]).astype(np.float32))


@ms_function
def tensor_add_global(x):
    """ tensor_add_global """
    global g_x
    res = tensor_add(x, g_x)
    return res


@non_graph_engine
def test_global_flag():
    """ test_global_flag """
    log.debug("begin test_global_flag")
    x = Tensor(np.ones([3, 3]).astype(np.float32))
    res = tensor_add_global(x)
    log.debug("finished test_global_flag, ret = %r", res)


class NetWithNDarray(nn.Cell):
    """ NetWithNDarray definition """

    def __init__(self, dim):
        super(NetWithNDarray, self).__init__()
        self.softmax = nn.Softmax(dim)
        self.x = ms.Tensor(np.ones(shape=(1)).astype(np.float32))

    def construct(self, input_data):
        return self.softmax(input_data) * self.x


@non_graph_engine
def test_net_with_ndarray():
    """ test_net_with_ndarray """
    net = NetWithNDarray(0)
    input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
    net(ms.Tensor(input_data))
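

# The three tests below rely on the 'enable_check_bprop' fixture: with
# check_bprop=True, the framework validates a registered bprop's outputs against
# the primitive's inputs, so a bprop that returns the wrong number of gradients,
# a wrong dtype, or a wrong shape is expected to raise ValueError, TypeError,
# and ValueError respectively.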
def test_bprop_with_wrong_output_num(enable_check_bprop):
    class BpropWithWrongOutputNum(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')

        def __call__(self, x, y):
            return x

        def infer_shape(self, x_shape, y_shape):
            return x_shape

        def infer_dtype(self, x_type, y_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputNum)
    def get_bprop_with_wrong_output_num(self):
        """Generate bprop for BpropWithWrongOutputNum"""

        def bprop(x, y, out, dout):
            return (dout,)

        return bprop

    class BpropWithWrongOutputNumCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputNumCell, self).__init__()

        def construct(self, x, y):
            return BpropWithWrongOutputNum()(x, y)

    with pytest.raises(ValueError):
        grad_all(BpropWithWrongOutputNumCell())(Tensor(np.array(1).astype(np.int32)),
                                                Tensor(np.array(2).astype(np.int32)))


def test_bprop_with_wrong_output_type(enable_check_bprop):
    class BpropWithWrongOutputType(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputType)
    def get_bprop_with_wrong_output_type(self):
        """Generate bprop for BpropWithWrongOutputType"""

        def bprop(x, out, dout):
            return (1,)

        return bprop

    class BpropWithWrongOutputTypeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputTypeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputType()(x)

    with pytest.raises(TypeError):
        grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))


def test_bprop_with_wrong_output_shape(enable_check_bprop):
    class BpropWithWrongOutputShape(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputShape)
    def get_bprop_with_wrong_output_shape(self):
        """Generate bprop for BpropWithWrongOutputShape"""
        ones = Tensor(np.ones([2,]).astype(np.int32))

        def bprop(x, out, dout):
            return (ones,)

        return bprop

    class BpropWithWrongOutputShapeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputShapeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputShape()(x)

    with pytest.raises(ValueError):
        net = BpropWithWrongOutputShapeCell()
        net.set_grad()
        grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
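

# The cells below exercise assignment to a Parameter during differentiation:
# AssignWhenInsertGrad registers save_gradient through P.InsertGradientOf, so the
# update of self.cov_step runs inside the backward pass; Assign and AssignCheck
# assign to the Parameter directly in construct. GradNet wraps a cell and returns
# both its forward output and the gradients with respect to all inputs.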
class AssignWhenInsertGrad(nn.Cell):
    """ AssignWhenInsertGrad definition """

    def __init__(self):
        super(AssignWhenInsertGrad, self).__init__()
        self.gather = P.Gather()
        self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32))
        self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False)
        self.freq = Tensor(278, ms.int32)
        self.getG = P.InsertGradientOf(self.save_gradient)

    def save_gradient(self, dout):
        self.cov_step = self.cov_step + self.freq
        return dout

    def construct(self, x):
        self.gather(self.damping, self.cov_step, 0)
        out = P.ReLU()(x)
        out = self.getG(out)
        return out


grad_all = C.GradOperation(get_all=True)


class GradNet(nn.Cell):
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        out = self.net(*inputs)
        return out, grad_all(self.net)(*inputs)


def test_assign_in_insert_grad():
    context.set_context(mode=context.GRAPH_MODE)
    net = AssignWhenInsertGrad().to_float(ms.float16)
    input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
    net_back = GradNet(net)
    net_back(ms.Tensor(input_data))


class Assign(nn.Cell):
    """ Assign definition """

    def __init__(self):
        super(Assign, self).__init__()
        self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False)

    def construct(self, x):
        self.cov_step = self.cov_step + x
        return self.cov_step


def test_assign(enable_check_bprop):
    context.set_context(mode=context.GRAPH_MODE)
    net = Assign()
    input_data = ms.Tensor(np.array(1).astype(np.int32))
    net_back = GradNet(net)
    net_back(input_data)


class AssignCheck(nn.Cell):
    """ AssignCheck definition """

    def __init__(self):
        super(AssignCheck, self).__init__()
        self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False)

    def construct(self, x):
        self.cov_step = x
        return self.cov_step


def test_assign_check_none():
    context.set_context(mode=context.GRAPH_MODE)
    net = AssignCheck()
    with pytest.raises(TypeError):
        net(None)