You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_hypermap.py 4.7 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_hypermap """
  16. import numpy as np
  17. from mindspore import Tensor
  18. from mindspore.common.api import ms_function
  19. from mindspore.ops import Primitive
  20. from mindspore.ops import _constants
  21. from mindspore.ops import composite as C
  22. from mindspore.ops import functional as F
  23. from mindspore.ops import operations as P
  24. from ...ut_filter import non_graph_engine
  25. # pylint: disable=W0613
  26. # W0613: unused-argument
  # Element-wise tensor addition primitive.
  tensor_add = P.Add()
  # Scalar addition primitive, constructed from its internal constant name.
  scala_add = Primitive(_constants.kScalarAdd)
  # Type-dispatched graph; its overloads are registered below by argument type.
  add = C.MultitypeFuncGraph('add')


  @add.register("Number", "Number")
  def add_scala(x, y):
      """Overload of `add` for two numbers: sum them with the scalar primitive."""
      summed = scala_add(x, y)
      return summed


  @add.register("Tensor", "Tensor")
  def add_tensor(x, y):
      """Overload of `add` for two tensors: sum them with the tensor primitive."""
      summed = tensor_add(x, y)
      return summed
  # HyperMap with `add` bound up front. The tests below call it with plain
  # numbers, tensors, and tuples of either.
  hyper_add = C.HyperMap(add)


  @ms_function
  def mainf(x, y):
      """Compiled entry point: apply the bound `add` HyperMap to `x` and `y`."""
      mapped = hyper_add(x, y)
      return mapped
  @non_graph_engine
  def test_hypermap_tensor():
      """Run the bound HyperMap on two 2x2 float32 tensors and print the result."""
      def make_input():
          return Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))

      lhs, rhs = make_input(), make_input()
      result = mainf(lhs, rhs)
      print("test_hypermap_tensor:", result)
  def test_hypermap_scalar():
      """Run the bound HyperMap on two Python numbers and print the result."""
      outcome = mainf(1, 2)
      print("test_hypermap_scalar", outcome)
  def test_hypermap_tuple():
      """Run the bound HyperMap on two tuples of numbers and print the result."""
      left = (1, 1)
      right = (2, 2)
      print("test_hypermap_tuple", mainf(left, right))
  @non_graph_engine
  def test_hypermap_tuple_tensor():
      """Run the bound HyperMap on two tuples of tensors and print the result."""
      values = [[1.2, 2.1], [2.2, 3.2]]
      tensor1 = Tensor(np.array(values, dtype=np.float32))
      tensor2 = Tensor(np.array(values, dtype=np.float32))
      # Each tuple repeats the same tensor object in both positions.
      first = (tensor1, tensor1)
      second = (tensor2, tensor2)
      print("test_hypermap_tuple_tensor", mainf(first, second))
  @non_graph_engine
  def test_hypermap_tuple_mix():
      """Run the bound HyperMap on (tensor, number) tuples and print the result."""
      tensor1, tensor2 = [
          Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
          for _ in range(2)
      ]
      print("test_hypermap_tuple_mix", mainf((tensor1, 1), (tensor2, 2)))
  # HyperMap with no bound function: the function to apply is passed as the
  # first call argument instead.
  hyper_map = C.HyperMap()


  @ms_function
  def main_noleaf(x, y):
      """Compiled entry point: apply `add` to `x` and `y` through the unbound HyperMap."""
      mapped = hyper_map(add, x, y)
      return mapped
  def test_hypermap_noleaf_scalar():
      """Unbound HyperMap on two numbers; the result is discarded, so this only checks the call does not raise."""
      lhs, rhs = 1, 2
      main_noleaf(lhs, rhs)
  @non_graph_engine
  def test_hypermap_noleaf_tensor():
      """Unbound HyperMap on two 2x2 float32 tensors; the result is discarded."""
      rows = [[1.2, 2.1], [2.2, 3.2]]
      tensor1 = Tensor(np.array(rows, dtype=np.float32))
      tensor2 = Tensor(np.array(rows, dtype=np.float32))
      main_noleaf(tensor1, tensor2)
  def test_hypermap_noleaf_tuple():
      """Unbound HyperMap on two tuples of numbers; the result is discarded."""
      pair_a = (1, 1)
      pair_b = (2, 2)
      main_noleaf(pair_a, pair_b)
  @non_graph_engine
  def test_hypermap_noleaf_tuple_tensor():
      """Unbound HyperMap on tuples of tensors whose shapes differ by position.

      Position 0 pairs two 2x2 tensors and position 1 pairs two 2x1 tensors;
      the result is discarded.
      """
      def as_tensor(rows):
          return Tensor(np.array(rows, dtype=np.float32))

      tensor1 = as_tensor([[1.1, 2.1], [2.1, 3.1]])
      tensor2 = as_tensor([[1.2, 2.2], [2.2, 3.2]])
      tensor3 = as_tensor([[2.2], [3.2]])
      tensor4 = as_tensor([[2.2], [3.2]])
      main_noleaf((tensor1, tensor3), (tensor2, tensor4))
  def test_hypermap_noleaf_tuple_mix():
      """Unbound HyperMap on (tensor, number) tuples; the result is discarded."""
      # NOTE(review): most other tensor-based tests in this module are marked
      # @non_graph_engine but this one is not -- confirm that is intentional.
      tensor1, tensor2 = (
          Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))
          for _ in range(2)
      )
      main_noleaf((tensor1, 1), (tensor2, 2))
  def add3_scalar(x, y, z):
      """Return (x + y) + z computed with the scalar addition primitive."""
      first_two = scala_add(x, y)
      return scala_add(first_two, z)
  @ms_function
  def main_add3_easy(x, y):
      """Compiled entry point: fix add3_scalar's first argument to 1, then call it on x and y."""
      bound_add3 = F.partial(add3_scalar, 1)
      return bound_add3(x, y)
  def test_hypermap_add3_easy():
      """Call main_add3_easy on two numbers; the result is discarded."""
      x_val, y_val = 1, 2
      main_add3_easy(x_val, y_val)
  # Three-argument type-dispatched graph. It was constructed with the name
  # 'add', duplicating the two-argument `add` graph's name; it now carries its
  # own name so the two graphs are distinguishable (presumably in IR dumps and
  # type-dispatch error messages -- confirm).
  add3 = C.MultitypeFuncGraph('add3')
  # Partial-application primitive; used inside compiled functions below.
  partial = P.Partial()


  @add3.register("Number", "Number", "Number")
  def add3_scala(x, y, z):
      """Overload of `add3` for three numbers: return (x + y) + z."""
      return scala_add(scala_add(x, y), z)


  @add3.register("Number", "Tensor", "Tensor")
  def add3_tensor(x, y, z):
      """Overload of `add3` for (number, tensor, tensor): return y + z.

      NOTE(review): `x` is ignored here, unlike add3_scala, which sums all three
      arguments. The module disables pylint's unused-argument check, which
      suggests this is intended -- confirm.
      """
      return tensor_add(y, z)
  @ms_function
  def main_add3_scala(x, y):
      """Compiled entry point: map add3_scala, first argument fixed to 1, over x and y."""
      add_with_one = partial(add3_scala, 1)
      return hyper_map(add_with_one, x, y)
  @ms_function
  def main_add3(x, y):
      """Compiled entry point: map the add3 graph, first argument fixed to 1, over x and y."""
      add3_with_one = partial(add3, 1)
      return hyper_map(add3_with_one, x, y)
  @non_graph_engine
  def test_hypermap_add3_tensor():
      """Map add3 (first argument fixed to 1) over two tensors.

      Each mapped call is add3(1, tensor, tensor), which matches the
      (Number, Tensor, Tensor) overload. The result is discarded.
      """
      def build():
          return Tensor(np.array([[1.2, 2.1], [2.2, 3.2]], dtype=np.float32))

      main_add3(build(), build())
  def test_hypermap_add3_tuple():
      """Map add3 (first argument fixed to 1) over (tensor, number) tuples.

      Position 0 matches the (Number, Tensor, Tensor) overload and position 1
      the (Number, Number, Number) overload. The result is discarded.
      """
      # NOTE(review): this builds tensors but, unlike test_hypermap_add3_tensor,
      # is not marked @non_graph_engine -- confirm that is intentional.
      data = [[1.2, 2.1], [2.2, 3.2]]
      tensor1 = Tensor(np.array(data, dtype=np.float32))
      tensor2 = Tensor(np.array(data, dtype=np.float32))
      main_add3((tensor1, 1), (tensor2, 1))