You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_create_obj.py 6.6 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_create_obj.py
  17. @Author:
  18. @Date : 2019-06-26
  19. @Desc : Test creating object instances inside a parsed function, e.g. 'construct'.
  20. Supported classes: nn.Cell, ops.Primitive
  21. Supported parameters: the types are defined in the function 'ValuePtrToPyData'
  22. (int, float, string, bool, tensor)
  23. """
  24. import logging
  25. import numpy as np
  26. import pytest
  27. import mindspore.nn as nn
  28. from mindspore import context, ops, dtype
  29. from mindspore.common.api import ms_function
  30. from mindspore.common import Tensor, Parameter
  31. from mindspore.ops import operations as P
  32. from ...ut_filter import non_graph_engine
  # Shared logger for the begin/finished markers written by the tests below.
  # The threshold is ERROR, so those log.debug(...) markers are filtered out.
  log = logging.getLogger("test")
  log.setLevel(logging.ERROR)
  class Net(nn.Cell):
      """Cell that instantiates a fresh nn.Softmax inside construct on every call."""

      def __init__(self):
          super().__init__()
          # Stored as an attribute, but construct does not use it; the Softmax
          # under test is the one created inside construct.
          self.softmax = nn.Softmax(0)
          self.axis = 0

      def construct(self, x):
          return nn.Softmax(self.axis)(x)
  44. # Test: Create Cell OR Primitive instance on construct
  @non_graph_engine
  def test_create_cell_object_on_construct():
      """
      Run Net, whose construct builds an nn.Softmax Cell inline, in graph mode.

      The input is a random float32 array of shape (2, 3, 4, 5). Input and output
      are printed for inspection; no values are asserted here.
      """
      log.debug("begin test_create_cell_object_on_construct")
      context.set_context(mode=context.GRAPH_MODE)
      np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
      input_me = Tensor(np1)
      net = Net()
      output = net(input_me)
      out_me1 = output.asnumpy()
      print(np1)
      print(out_me1)
      log.debug("finished test_create_cell_object_on_construct")
  58. # Test: Create Cell OR Primitive instance on construct
  class Net1(nn.Cell):
      """Cell whose ms_function-compiled construct creates its own P.Add primitive."""

      def __init__(self):
          super().__init__()
          # Not used by construct; the primitive under test is created there.
          self.add = P.Add()

      @ms_function
      def construct(self, x, y):
          add_op = P.Add()
          return add_op(x, y)
  @non_graph_engine
  def test_create_primitive_object_on_construct():
      """Call Net1.construct, which creates a P.Add primitive inline, on two 2x3 tensors."""
      log.debug("begin test_create_primitive_object_on_construct")
      x = Tensor(np.array([[1, 2, 3], [1, 2, 3]], np.float32))
      y = Tensor(np.array([[2, 3, 4], [1, 1, 2]], np.float32))
      net = Net1()
      # Net1.construct is wrapped with ms_function and is invoked directly;
      # unlike the sibling tests, no global context mode is set here.
      net.construct(x, y)
      log.debug("finished test_create_primitive_object_on_construct")
  78. # Test: Create a Cell or Primitive instance in construct, passing many parameters
  class NetM(nn.Cell):
      """Softmax Cell configured with a name and an axis.

      NetC.construct instantiates this class inline.
      """

      def __init__(self, name, axis):
          super().__init__()
          self.name = name
          self.axis = axis
          self.softmax = nn.Softmax(self.axis)

      def construct(self, x):
          return self.softmax(x)
  class NetC(nn.Cell):
      """Cell that creates NetM("test", 1) (a str and an int argument) in construct."""

      def __init__(self, tensor):
          super().__init__()
          # Kept on the cell only; construct never reads it.
          self.tensor = tensor

      def construct(self, x):
          return NetM("test", 1)(x)
  98. # Test: Create Cell OR Primitive instance on construct
  @non_graph_engine
  def test_create_cell_object_on_construct_use_many_parameter():
      """
      Run NetC, whose construct creates NetM("test", 1) inline, in graph mode.

      The input is a random float32 array of shape (2, 3, 4, 5). It is also passed
      to NetC's constructor. Input and output are printed; no values are asserted.
      """
      log.debug("begin test_create_cell_object_on_construct_use_many_parameter")
      context.set_context(mode=context.GRAPH_MODE)
      np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
      input_me = Tensor(np1)
      net = NetC(input_me)
      output = net(input_me)
      out_me1 = output.asnumpy()
      print(np1)
      print(out_me1)
      log.debug("finished test_create_cell_object_on_construct_use_many_parameter")
  class NetD(nn.Cell):
      """Cell that builds a P.Concat primitive from a keyword argument in construct."""

      def construct(self, x, y):
          concat_op = P.Concat(axis=1)
          return concat_op((x, y))
  117. # Test: Create Cell OR Primitive instance on construct
  @non_graph_engine
  def test_create_primitive_object_on_construct_use_kwargs():
      """Run NetD, which creates P.Concat(axis=1) inside construct, in graph mode."""
      log.debug("begin test_create_primitive_object_on_construct_use_kwargs")
      context.set_context(mode=context.GRAPH_MODE)
      rows = [[0, 1], [2, 1]]
      # Two independent float32 copies of the same 2x2 data.
      x = Tensor(np.array(rows).astype(np.float32))
      y = Tensor(np.array(rows).astype(np.float32))
      NetD()(x, y)
      log.debug("finished test_create_primitive_object_on_construct_use_kwargs")
  class NetE(nn.Cell):
      """Cell that builds a P.Conv2D from positional and keyword arguments in construct."""

      def __init__(self):
          super().__init__()
          weight_data = np.ones([16, 16, 3, 3]).astype(np.float32)
          self.w = Parameter(Tensor(weight_data), name='w')

      def construct(self, x):
          out_channel = 16
          kernel_size = 3
          # Three positional arguments followed by keyword arguments.
          # NOTE(review): the third positional `1` presumably maps to Conv2D's
          # `mode` parameter; confirm against the P.Conv2D signature.
          conv_op = P.Conv2D(out_channel, kernel_size, 1,
                             pad_mode='valid', pad=0, stride=1,
                             dilation=1, group=1)
          return conv_op(x, self.w)
  145. # Test: Create Cell OR Primitive instance on construct
  @non_graph_engine
  def test_create_primitive_object_on_construct_use_args_and_kwargs():
      """Run NetE on an all-ones float32 tensor of shape (1, 16, 16, 16) in graph mode."""
      log.debug("begin test_create_primitive_object_on_construct_use_args_and_kwargs")
      context.set_context(mode=context.GRAPH_MODE)
      ones = Tensor(np.ones([1, 16, 16, 16]).astype(np.float32))
      NetE()(ones)
      log.debug("finished test_create_primitive_object_on_construct_use_args_and_kwargs")
  155. # Test: Create Cell instance in construct
  class SubCell(nn.Cell):
      """Cell that keeps the value `t` given at init and returns ops.typeof(t)."""

      def __init__(self, t):
          super().__init__()
          self.t = t

      def construct(self):
          return ops.typeof(self.t)
  class WrapCell(nn.Cell):
      """Cell whose construct instantiates SubCell with its own input `t`.

      test_create_cell_with_tensor expects calling this with a Tensor to raise
      TypeError.
      """

      def construct(self, t):
          outer_type = ops.typeof(t)
          inner_type = SubCell(t)()
          return outer_type, inner_type
  def test_create_cell_with_tensor():
      """
      Feature: Raise an exception when a Cell whose __init__ takes a tensor input is created in construct.
      Description: WrapCell.construct instantiates SubCell with its Tensor argument.
      Expectation: TypeError.
      """
      # np.float was only a deprecated alias of the builtin float (float64) and
      # was removed in NumPy 1.24. np.float64 gives the same dtype on all versions.
      t = Tensor(np.zeros((2, 2), np.float64), dtype.float32)
      with pytest.raises(TypeError):
          print(WrapCell()(t))