You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_parse.py 9.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_parse.py
  17. @Author:
  18. @Date : 2019-01-23 17:13
  19. @Desc :
  20. """
  21. import logging
  22. import pytest
  23. import numpy as np
  24. import mindspore as ms
  25. import mindspore.nn as nn
  26. from mindspore import Tensor
  27. from mindspore import context
  28. from mindspore.ops import composite as C
  29. from mindspore.ops import operations as P
  30. from mindspore.common.api import ms_function, _executor
  31. from mindspore.ops._grad.grad_base import bprop_getters
  32. from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
  33. from mindspore.ops.functional import tensor_add
  34. from ...ut_filter import non_graph_engine
  35. # pylint: disable=W0613,W0612
  36. # W0613: unused-argument
# Gradient helper used by the bprop tests below: returns gradients
# with respect to all inputs of the wrapped network.
grad_all = C.GradOperation(get_all=True)

# Module-level test logger; only errors are emitted during normal runs.
log = logging.getLogger("test")
log.setLevel(level=logging.ERROR)

# Every case in this file exercises graph-mode parsing.
context.set_context(mode=context.GRAPH_MODE)
# Test case: use the parse obj interface use default parameter
class Net(nn.Cell):
    """Net definition.

    `construct` carries a Tensor-valued default argument (`input1`), which is
    the parse feature this case exercises; the argument itself is never read.
    """

    def __init__(self, dim):
        super(Net, self).__init__()
        self.softmax1 = nn.Softmax(dim)
        # softmax2 is not used by construct; presumably kept to exercise
        # attribute parsing — NOTE(review): confirm it is intentional.
        self.softmax2 = nn.Softmax(dim + 1)

    def construct(self, input_data, input1=ms.Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))):
        return self.softmax1(input_data)
  50. @non_graph_engine
  51. def test_parse_defalut_parameter_case2():
  52. """ test_parse_defalut_parameter_case2 """
  53. log.debug("begin test_parse_defalut_parameter_case2")
  54. net = Net(0)
  55. npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  56. log.debug("input value is: %r", npd)
  57. input_data = ms.Tensor(npd)
  58. input_data.set_dtype(ms.float32)
  59. log.debug("start run")
  60. output = net(input_data)
  61. value = output.asnumpy()
  62. log.debug("output value = %r", value)
# Test case: use the variable parameter for parse object
class Net1(nn.Cell):
    """Net1 definition.

    `construct` takes *args and returns the first element; the star-args
    signature is the parse feature this case exercises.
    """

    def __init__(self):
        super(Net1, self).__init__()

    def construct(self, *args):
        x = args[0]
        return x
  71. def test_var_parameter_case2():
  72. """ test_var_parameter_case2 """
  73. log.debug("begin test_var_parameter_case2")
  74. net = Net1()
  75. npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  76. log.debug("input value is: %r", npd)
  77. input_data = ms.Tensor(npd)
  78. input_data.set_dtype(ms.float32)
  79. np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  80. input1 = ms.Tensor(np1)
  81. np2 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  82. input2 = ms.Tensor(np2)
  83. _executor.compile(net, input_data, input1, input2)
# Test case: test the global flag
g_x = Tensor(np.ones([3, 3]).astype(np.float32))


@ms_function
def tensor_add_global(x):
    """tensor_add_global: adds the module-level tensor g_x to x.

    The `global` statement inside an @ms_function body is the parse
    feature this case exercises.
    """
    global g_x
    res = tensor_add(x, g_x)
    return res
  92. @non_graph_engine
  93. def test_global_flag():
  94. """ test_global_flag """
  95. log.debug("begin test_global_flag")
  96. x = Tensor(np.ones([3, 3]).astype(np.float32))
  97. res = tensor_add_global(x)
  98. log.debug("finished test_global_flag, ret = %r", res)
class NetWithNDarray(nn.Cell):
    """NetWithNDarray definition.

    Scales a Softmax output by a Tensor attribute built from a numpy
    ndarray at construction time; parsing that attribute is the feature
    this case exercises.
    """

    def __init__(self, dim):
        super(NetWithNDarray, self).__init__()
        self.softmax = nn.Softmax(dim)
        self.x = ms.Tensor(np.ones(shape=(1)).astype(np.float32))

    def construct(self, input_data):
        return self.softmax(input_data) * self.x
  107. @non_graph_engine
  108. def test_net_with_ndarray():
  109. """ test_net_with_ndarray """
  110. net = NetWithNDarray(0)
  111. input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  112. net(ms.Tensor(input_data))
  113. def test_bprop_with_wrong_output_num():
  114. context.set_context(check_bprop=True)
  115. class BpropWithWrongOutputNum(PrimitiveWithInfer):
  116. @prim_attr_register
  117. def __init__(self):
  118. super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
  119. def __call__(self, x, y):
  120. return x
  121. def infer_shape(self, x_shape, yshape):
  122. return x_shape
  123. def infer_dtype(self, x_type, y_type):
  124. return x_type
  125. @bprop_getters.register(BpropWithWrongOutputNum)
  126. def get_bprop_with_wrong_output_num(self):
  127. """Generate bprop for BpropWithWrongOutputNum"""
  128. def bprop(x, y, out, dout):
  129. return (dout,)
  130. return bprop
  131. class BpropWithWrongOutputNumCell(nn.Cell):
  132. def __init__(self):
  133. super(BpropWithWrongOutputNumCell, self).__init__()
  134. def construct(self, x, y):
  135. return BpropWithWrongOutputNum()(x, y)
  136. with pytest.raises(ValueError):
  137. grad_all(BpropWithWrongOutputNumCell())(1, 2)
def test_bprop_with_wrong_output_type():
    """A bprop that returns a plain int where a Tensor gradient is expected
    must raise TypeError when check_bprop checking is enabled."""
    context.set_context(check_bprop=True)

    class BpropWithWrongOutputType(PrimitiveWithInfer):
        """One-input identity primitive whose bprop yields a non-Tensor."""

        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputType)
    def get_bprop_with_wrong_output_type(self):
        """Generate bprop for BpropWithWrongOutputType"""

        def bprop(x, out, dout):
            # Wrong on purpose: the gradient is the int 1, not a Tensor.
            return (1,)

        return bprop

    class BpropWithWrongOutputTypeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputTypeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputType()(x)

    with pytest.raises(TypeError):
        grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
def test_bprop_with_wrong_output_shape():
    """A bprop that returns a gradient whose shape differs from the input's
    must raise ValueError when check_bprop checking is enabled."""
    context.set_context(check_bprop=True)

    class BpropWithWrongOutputShape(PrimitiveWithInfer):
        """One-input identity primitive whose bprop yields a mis-shaped Tensor."""

        @prim_attr_register
        def __init__(self):
            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')

        def __call__(self, x):
            return x

        def infer_shape(self, x_shape):
            return x_shape

        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropWithWrongOutputShape)
    def get_bprop_with_wrong_output_shape(self):
        """Generate bprop for BpropWithWrongOutputShape"""
        # Wrong on purpose: a shape-[2] gradient for a [64, 10] input.
        ones = Tensor(np.ones([2,]).astype(np.int32))

        def bprop(x, out, dout):
            return (ones,)

        return bprop

    class BpropWithWrongOutputShapeCell(nn.Cell):
        def __init__(self):
            super(BpropWithWrongOutputShapeCell, self).__init__()

        def construct(self, x):
            return BpropWithWrongOutputShape()(x)

    with pytest.raises(ValueError):
        net = BpropWithWrongOutputShapeCell()
        net.set_grad()
        grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
class AssignWhenInsertGrad(nn.Cell):
    """Assigns to a Parameter from inside an InsertGradientOf hook.

    (Original docstring said "NetWithNDarray definition" — copy/paste
    leftover; this cell tests Parameter assignment during backprop.)
    """

    def __init__(self):
        super(AssignWhenInsertGrad, self).__init__()
        self.gather = P.GatherV2()
        self.damping = Tensor(np.array([0.03, 0.03]).astype(np.float32))
        # Non-trainable counter mutated from the gradient hook below.
        self.cov_step = ms.Parameter(0, name="cov_step", requires_grad=False)
        self.freq = Tensor(278, ms.int32)
        # InsertGradientOf invokes save_gradient on the backward pass.
        self.getG = P.InsertGradientOf(self.save_gradient)

    def save_gradient(self, dout):
        # Side effect during backprop: advance the counter, pass dout through.
        self.cov_step = self.cov_step + self.freq
        return dout

    def construct(self, x):
        # The gather result is discarded; it only reads the mutable Parameter.
        self.gather(self.damping, self.cov_step, 0)
        out = P.ReLU()(x)
        out = self.getG(out)
        return out
# NOTE(review): re-binds grad_all to the same value as the definition near
# the top of the file — redundant, kept as-is.
grad_all = C.GradOperation(get_all=True)


class GradNet(nn.Cell):
    """Wraps a net; returns (forward output, gradients w.r.t. all inputs)."""

    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net

    def construct(self, *inputs):
        out = self.net(*inputs)
        return out, grad_all(self.net)(*inputs)
  216. def test_assign_in_insert_grad():
  217. context.set_context(mode=context.GRAPH_MODE)
  218. net = AssignWhenInsertGrad().to_float(ms.float16)
  219. input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  220. net_back = GradNet(net)
  221. net_back(ms.Tensor(input_data))
class Assign(nn.Cell):
    """Accumulates the input into a non-trainable Parameter and returns it.

    (Original docstring said "NetWithNDarray definition" — copy/paste
    leftover; this cell tests Parameter assignment inside construct.)
    """

    def __init__(self):
        super(Assign, self).__init__()
        self.cov_step = ms.Parameter(0.0, name="cov_step", requires_grad=False)

    def construct(self, x):
        # Parameter assignment inside construct is the feature under test.
        self.cov_step = self.cov_step + x
        return self.cov_step
  230. def test_assign():
  231. context.set_context(mode=context.GRAPH_MODE)
  232. net = Assign()
  233. input_data = ms.Tensor(np.array(1).astype(np.int32))
  234. net_back = GradNet(net)
  235. net_back(input_data)