You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_hypermap.py 4.6 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_hypermap """
  16. import numpy as np
  17. from mindspore import Tensor
  18. from mindspore.common.api import ms_function
  19. from mindspore.ops import Primitive
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import functional as F
  22. from mindspore.ops import operations as P
  23. from ...ut_filter import non_graph_engine
  24. # pylint: disable=W0613
  25. # W0613: unused-argument
  26. tensor_add = P.TensorAdd()
  27. scala_add = Primitive('scalar_add')
  28. add = C.MultitypeFuncGraph('add')
  29. @add.register("Number", "Number")
  30. def add_scala(x, y):
  31. return scala_add(x, y)
  32. @add.register("Tensor", "Tensor")
  33. def add_tensor(x, y):
  34. return tensor_add(x, y)
  35. hyper_add = C.HyperMap(add)
  36. @ms_function
  37. def mainf(x, y):
  38. return hyper_add(x, y)
  39. @non_graph_engine
  40. def test_hypermap_tensor():
  41. tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  42. tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  43. print("test_hypermap_tensor:", mainf(tensor1, tensor2))
  44. def test_hypermap_scalar():
  45. print("test_hypermap_scalar", mainf(1, 2))
  46. def test_hypermap_tuple():
  47. print("test_hypermap_tuple", mainf((1, 1), (2, 2)))
  48. @non_graph_engine
  49. def test_hypermap_tuple_tensor():
  50. tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  51. tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  52. print("test_hypermap_tuple_tensor", mainf((tensor1, tensor1), (tensor2, tensor2)))
  53. @non_graph_engine
  54. def test_hypermap_tuple_mix():
  55. tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  56. tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  57. print("test_hypermap_tuple_mix", mainf((tensor1, 1), (tensor2, 2)))
  58. hyper_map = C.HyperMap()
  59. @ms_function
  60. def main_noleaf(x, y):
  61. return hyper_map(add, x, y)
  62. def test_hypermap_noleaf_scalar():
  63. main_noleaf(1, 2)
  64. @non_graph_engine
  65. def test_hypermap_noleaf_tensor():
  66. tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  67. tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  68. main_noleaf(tensor1, tensor2)
  69. def test_hypermap_noleaf_tuple():
  70. main_noleaf((1, 1), (2, 2))
  71. @non_graph_engine
  72. def test_hypermap_noleaf_tuple_tensor():
  73. tensor1 = Tensor(np.array([[1.1, 2.1], [2.1, 3.1]]).astype('float32'))
  74. tensor2 = Tensor(np.array([[1.2, 2.2], [2.2, 3.2]]).astype('float32'))
  75. tensor3 = Tensor(np.array([[2.2], [3.2]]).astype('float32'))
  76. tensor4 = Tensor(np.array([[2.2], [3.2]]).astype('float32'))
  77. main_noleaf((tensor1, tensor3), (tensor2, tensor4))
  78. def test_hypermap_noleaf_tuple_mix():
  79. tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  80. tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  81. main_noleaf((tensor1, 1), (tensor2, 2))
  82. def add3_scalar(x, y, z):
  83. return scala_add(scala_add(x, y), z)
  84. @ms_function
  85. def main_add3_easy(x, y):
  86. add2 = F.partial(add3_scalar, 1)
  87. return add2(x, y)
  88. def test_hypermap_add3_easy():
  89. main_add3_easy(1, 2)
  90. add3 = C.MultitypeFuncGraph('add')
  91. partial = P.Partial()
  92. @add3.register("Number", "Number", "Number")
  93. def add3_scala(x, y, z):
  94. return scala_add(scala_add(x, y), z)
  95. @add3.register("Number", "Tensor", "Tensor")
  96. def add3_tensor(x, y, z):
  97. return tensor_add(y, z)
  98. @ms_function
  99. def main_add3_scala(x, y):
  100. add2 = partial(add3_scala, 1)
  101. return hyper_map(add2, x, y)
  102. @ms_function
  103. def main_add3(x, y):
  104. add2 = partial(add3, 1)
  105. return hyper_map(add2, x, y)
  106. @non_graph_engine
  107. def test_hypermap_add3_tensor():
  108. tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  109. tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  110. main_add3(tensor1, tensor2)
  111. def test_hypermap_add3_tuple():
  112. tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  113. tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
  114. main_add3((tensor1, 1), (tensor2, 1))