You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_create_obj.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. @File : test_create_obj.py
  17. @Author:
  18. @Date : 2019-06-26
  19. @Desc : test create object instance on parse function, eg: 'construct'
  20. Support class : nn.Cell ops.Primitive
  21. Support parameter: type is define on function 'ValuePtrToPyData'
  22. (int,float,string,bool,tensor)
  23. """
  24. import logging
  25. import numpy as np
  26. import mindspore.nn as nn
  27. from mindspore import context
  28. from mindspore.ops import operations as P
  29. from mindspore.common.api import ms_function
  30. from mindspore.common.tensor import Tensor
  31. from ...ut_filter import non_graph_engine
  32. log = logging.getLogger("test")
  33. log.setLevel(level=logging.ERROR)
  34. class Net(nn.Cell):
  35. """ Net definition """
  36. def __init__(self):
  37. super(Net, self).__init__()
  38. self.softmax = nn.Softmax(0)
  39. self.axis = 0
  40. def construct(self, x):
  41. x = nn.Softmax(self.axis)(x)
  42. return x
  43. # Test: creat CELL OR Primitive instance on construct
  44. @non_graph_engine
  45. def test_create_cell_object_on_construct():
  46. """ test_create_cell_object_on_construct """
  47. log.debug("begin test_create_object_on_construct")
  48. context.set_context(mode=context.GRAPH_MODE)
  49. np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  50. input_me = Tensor(np1)
  51. net = Net()
  52. output = net(input_me)
  53. out_me1 = output.asnumpy()
  54. print(np1)
  55. print(out_me1)
  56. log.debug("finished test_create_object_on_construct")
  57. # Test: creat CELL OR Primitive instance on construct
  58. class Net1(nn.Cell):
  59. """ Net1 definition """
  60. def __init__(self):
  61. super(Net1, self).__init__()
  62. self.add = P.TensorAdd()
  63. @ms_function
  64. def construct(self, x, y):
  65. add = P.TensorAdd()
  66. result = add(x, y)
  67. return result
  68. @non_graph_engine
  69. def test_create_primitive_object_on_construct():
  70. """ test_create_primitive_object_on_construct """
  71. log.debug("begin test_create_object_on_construct")
  72. x = Tensor(np.array([[1, 2, 3], [1, 2, 3]], np.float32))
  73. y = Tensor(np.array([[2, 3, 4], [1, 1, 2]], np.float32))
  74. net = Net1()
  75. net.construct(x, y)
  76. log.debug("finished test_create_object_on_construct")
  77. # Test: creat CELL OR Primitive instance on construct use many parameter
  78. class NetM(nn.Cell):
  79. """ NetM definition """
  80. def __init__(self, name, axis):
  81. super(NetM, self).__init__()
  82. # self.relu = nn.ReLU()
  83. self.name = name
  84. self.axis = axis
  85. self.softmax = nn.Softmax(self.axis)
  86. def construct(self, x):
  87. x = self.softmax(x)
  88. return x
  89. class NetC(nn.Cell):
  90. """ NetC definition """
  91. def __init__(self, tensor):
  92. super(NetC, self).__init__()
  93. self.tensor = tensor
  94. def construct(self, x):
  95. x = NetM("test", 1)(x)
  96. return x
  97. # Test: creat CELL OR Primitive instance on construct
  98. @non_graph_engine
  99. def test_create_cell_object_on_construct_use_many_parameter():
  100. """ test_create_cell_object_on_construct_use_many_parameter """
  101. log.debug("begin test_create_object_on_construct")
  102. context.set_context(mode=context.GRAPH_MODE)
  103. np1 = np.random.randn(2, 3, 4, 5).astype(np.float32)
  104. input_me = Tensor(np1)
  105. net = NetC(input_me)
  106. output = net(input_me)
  107. out_me1 = output.asnumpy()
  108. print(np1)
  109. print(out_me1)
  110. log.debug("finished test_create_object_on_construct")