You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parse.py 7.4 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_parse.py
  17. @Author:
  18. @Date : 2019-01-23 17:13
  19. @Desc :
  20. """
  21. import logging
  22. import pytest
  23. import numpy as np
  24. import mindspore as ms
  25. import mindspore.nn as nn
  26. from mindspore import Tensor
  27. from mindspore import context
  28. from mindspore.ops import composite as C
  29. from mindspore.common.api import ms_function, _executor
  30. from mindspore.ops._grad.grad_base import bprop_getters
  31. from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
  32. from mindspore.ops.functional import tensor_add
  33. from ...ut_filter import non_graph_engine
# pylint: disable=W0613,W0612
# W0613: unused-argument
# W0612: unused-variable
  36. grad_all = C.GradOperation(get_all=True)
  37. log = logging.getLogger("test")
  38. log.setLevel(level=logging.ERROR)
  39. context.set_context(mode=context.GRAPH_MODE)
  40. # Test case: use the parse obj interface use default parameter
  41. class Net(nn.Cell):
  42. """ Net definition """
  43. def __init__(self, dim):
  44. super(Net, self).__init__()
  45. self.softmax1 = nn.Softmax(dim)
  46. self.softmax2 = nn.Softmax(dim + 1)
  47. def construct(self, input_data, input1=ms.Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))):
  48. return self.softmax1(input_data)
  49. @non_graph_engine
  50. def test_parse_defalut_parameter_case2():
  51. """ test_parse_defalut_parameter_case2 """
  52. log.debug("begin test_parse_defalut_parameter_case2")
  53. net = Net(0)
  54. npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  55. log.debug("input value is: %r", npd)
  56. input_data = ms.Tensor(npd)
  57. input_data.set_dtype(ms.float32)
  58. log.debug("start run")
  59. output = net(input_data)
  60. value = output.asnumpy()
  61. log.debug("output value = %r", value)
  62. # Test case: use the variable parameter for parse object
  63. class Net1(nn.Cell):
  64. """ Net1 definition """
  65. def __init__(self):
  66. super(Net1, self).__init__()
  67. def construct(self, *args):
  68. x = args[0]
  69. return x
  70. def test_var_parameter_case2():
  71. """ test_var_parameter_case2 """
  72. log.debug("begin test_var_parameter_case2")
  73. net = Net1()
  74. npd = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  75. log.debug("input value is: %r", npd)
  76. input_data = ms.Tensor(npd)
  77. input_data.set_dtype(ms.float32)
  78. np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  79. input1 = ms.Tensor(np1)
  80. np2 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  81. input2 = ms.Tensor(np2)
  82. _executor.compile(net, input_data, input1, input2)
  83. # Test case: test the global flag
  84. g_x = Tensor(np.ones([3, 3]).astype(np.float32))
  85. @ms_function
  86. def tensor_add_global(x):
  87. """ tensor_add_global """
  88. global g_x
  89. res = tensor_add(x, g_x)
  90. return res
  91. @non_graph_engine
  92. def test_global_flag():
  93. """ test_global_flag """
  94. log.debug("begin test_global_flag")
  95. x = Tensor(np.ones([3, 3]).astype(np.float32))
  96. res = tensor_add_global(x)
  97. log.debug("finished test_global_flag, ret = %r", res)
  98. class NetWithNDarray(nn.Cell):
  99. """ NetWithNDarray definition """
  100. def __init__(self, dim):
  101. super(NetWithNDarray, self).__init__()
  102. self.softmax = nn.Softmax(dim)
  103. self.x = ms.Tensor(np.ones(shape=(1)).astype(np.float32))
  104. def construct(self, input_data):
  105. return self.softmax(input_data) * self.x
  106. @non_graph_engine
  107. def test_net_with_ndarray():
  108. """ test_net_with_ndarray """
  109. net = NetWithNDarray(0)
  110. input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
  111. net(ms.Tensor(input_data))
  112. def test_bprop_with_wrong_output_num():
  113. context.set_context(check_bprop=True)
  114. class BpropWithWrongOutputNum(PrimitiveWithInfer):
  115. @prim_attr_register
  116. def __init__(self):
  117. super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
  118. def __call__(self, x, y):
  119. return x
  120. def infer_shape(self, x_shape, yshape):
  121. return x_shape
  122. def infer_dtype(self, x_type, y_type):
  123. return x_type
  124. @bprop_getters.register(BpropWithWrongOutputNum)
  125. def get_bprop_with_wrong_output_num(self):
  126. """Generate bprop for BpropWithWrongOutputNum"""
  127. def bprop(x, y, out, dout):
  128. return (dout,)
  129. return bprop
  130. class BpropWithWrongOutputNumCell(nn.Cell):
  131. def __init__(self):
  132. super(BpropWithWrongOutputNumCell, self).__init__()
  133. def construct(self, x, y):
  134. return BpropWithWrongOutputNum()(x, y)
  135. with pytest.raises(ValueError):
  136. grad_all(BpropWithWrongOutputNumCell())(1, 2)
  137. def test_bprop_with_wrong_output_type():
  138. context.set_context(check_bprop=True)
  139. class BpropWithWrongOutputType(PrimitiveWithInfer):
  140. @prim_attr_register
  141. def __init__(self):
  142. super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
  143. def __call__(self, x):
  144. return x
  145. def infer_shape(self, x_shape):
  146. return x_shape
  147. def infer_dtype(self, x_type):
  148. return x_type
  149. @bprop_getters.register(BpropWithWrongOutputType)
  150. def get_bprop_with_wrong_output_type(self):
  151. """Generate bprop for BpropWithWrongOutputType"""
  152. def bprop(x, out, dout):
  153. return (1,)
  154. return bprop
  155. class BpropWithWrongOutputTypeCell(nn.Cell):
  156. def __init__(self):
  157. super(BpropWithWrongOutputTypeCell, self).__init__()
  158. def construct(self, x):
  159. return BpropWithWrongOutputType()(x)
  160. with pytest.raises(TypeError):
  161. grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
  162. def test_bprop_with_wrong_output_shape():
  163. context.set_context(check_bprop=True)
  164. class BpropWithWrongOutputShape(PrimitiveWithInfer):
  165. @prim_attr_register
  166. def __init__(self):
  167. super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
  168. def __call__(self, x):
  169. return x
  170. def infer_shape(self, x_shape):
  171. return x_shape
  172. def infer_dtype(self, x_type):
  173. return x_type
  174. @bprop_getters.register(BpropWithWrongOutputShape)
  175. def get_bprop_with_wrong_output_shape(self):
  176. """Generate bprop for BpropWithWrongOutputShape"""
  177. ones = Tensor(np.ones([2,]).astype(np.int32))
  178. def bprop(x, out, dout):
  179. return (ones,)
  180. return bprop
  181. class BpropWithWrongOutputShapeCell(nn.Cell):
  182. def __init__(self):
  183. super(BpropWithWrongOutputShapeCell, self).__init__()
  184. def construct(self, x):
  185. return BpropWithWrongOutputShape()(x)
  186. with pytest.raises(ValueError):
  187. net = BpropWithWrongOutputShapeCell()
  188. net.set_grad()
  189. grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))