You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_graph_fallback.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test graph fallback """
  16. import pytest
  17. import numpy as np
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, ms_function, context
  20. from mindspore.ops import operations as P
  21. from mindspore.ops import functional as F
  22. import mindspore.common.dtype as mstype
  23. import mindspore.common._monad as monad
  # Run every ms_function and Cell.construct in this module in GRAPH_MODE,
  # i.e. compiled as a static graph rather than executed eagerly.
  context.set_context(mode=context.GRAPH_MODE)
  # `add_func` is defined in current file; do_increment refers to it by name
  # through F.partial, so it is presumably compiled as part of that graph.
  def add_func(x, y):
      # Return the sum x + y (plain `+`, whatever the operand types support).
      total = x + y
      return total
  @ms_function
  def do_increment(i):
      # Graph-compiled case: bind the module-level Python function add_func
      # with a constant first argument via F.partial, then call the partial.
      # Result is add_func(1, i), i.e. 1 + i.
      add_1 = F.partial(add_func, 1)
      return add_1(i)
  def test_increment():
      """Check that the graph-compiled partial of add_func turns 9 into 10."""
      assert do_increment(9) == 10
  @ms_function
  def use_monad(x, y):
      # Graph-compiled case: multiply x and y, then pass the product through
      # F.depend together with the UMonad object monad.U (presumably to check
      # that user code may reference monad.U directly). Returns the product.
      res = P.Mul()(x, y)
      res = F.depend(res, monad.U)
      return res
  def test_use_monad():
      """Check that use_monad still returns x * y after F.depend with monad.U.

      The original only printed the result, so a wrong value went unnoticed;
      it also used 1.0 * 1.0, which cannot tell Mul apart from returning
      either input unchanged.
      """
      x = Tensor(2.0, mstype.float32)
      y = Tensor(3.0, mstype.float32)
      output = use_monad(x, y)
      # F.depend is expected to pass its first argument (the product) through.
      assert np.allclose(output.asnumpy(), 6.0)
  class Net(nn.Cell):
      """Cell whose graph-compiled construct applies builtin len() to a Tensor
      attribute and loops over range() of that length."""

      def __init__(self):
          super(Net, self).__init__()
          # Three-element tensor whose length construct() takes.
          self.x = Tensor([2, 3, 4])

      def construct(self):
          # Print each index 0 .. len(self.x) - 1, then return len(self.x).
          x_len = len(self.x)
          for i in range(x_len):
              print(i)
          return x_len
  def test_builtins_len():
      """Check builtin len() on a Tensor attribute inside Cell.construct.

      The original discarded the result, so only compilation was exercised.
      """
      net = Net()
      # Net.x is built from [2, 3, 4], so len(self.x) must be 3.
      assert net() == 3
  @ms_function
  def np_fallback_func():
      # Graph-fallback case: numpy calls (np.array, astype) inside a compiled
      # function. The float32 array [2, 3, 4, 5] is wrapped in a Tensor and
      # added to itself, giving [4, 6, 8, 10]. The test calling this is skipped
      # ('Not support graph fallback feature yet').
      array_x = tuple([2, 3, 4, 5])
      np_x = np.array(array_x).astype(np.float32)
      me_x = Tensor(np_x)
      me_x = me_x + me_x
      return me_x
  @pytest.mark.skip(reason='Not support graph fallback feature yet')
  def test_np_fallback_func():
      """Print the Tensor that np_fallback_func builds from numpy data."""
      result = np_fallback_func()
      print(result)
  # Test `return` interpret node.
  @ms_function
  def div_mod_func1():
      # Builtin divmod on constants; the result is wrapped in a Tensor in the
      # return statement. divmod(8, 3) == (2, 2).
      x = 8
      y = 3
      a = divmod(x, y)
      return Tensor(a)
  @pytest.mark.skip(reason='Not support graph fallback feature yet')
  def test_div_mod_func1():
      """Print divmod(8, 3) computed in the graph; expected (2, 2)."""
      result = div_mod_func1()
      print(result)  # (2, 2)
  # Test interpret node with parameters as input.
  @ms_function
  def div_mod_func2(x, y):
      # Like div_mod_func1, but divmod's operands are the function's
      # parameters rather than constants; returns Tensor(divmod(x, y)).
      a = divmod(x, y)
      return Tensor(a)
  @pytest.mark.skip(reason='Not support graph fallback feature yet')
  def test_div_mod_func2():
      """Print divmod of the arguments 8 and 3 computed in the graph; expected (2, 2)."""
      dividend, divisor = 8, 3
      print(div_mod_func2(dividend, divisor))  # (2, 2)
  # NameError: name 'Tensor' is not defined.
  # NOTE(review): the line above presumably records the error this case used
  # to raise when `Tensor` was resolved inside the graph; test_select_func is
  # not skipped, so confirm the note is still current.
  @ms_function
  def select_func(cond, x, y):
      # Dispatch on the Python type of cond:
      #   tuple/list -> y
      #   Tensor     -> element-wise F.select(cond, x, y)
      #   otherwise  -> x
      if isinstance(cond, (tuple, list)):
          output = y
      elif isinstance(cond, Tensor):
          output = F.select(cond, x, y)
      else:
          output = x
      return output
  def test_select_func():
      """Check select_func's Tensor branch: F.select picks x where cond is True.

      The original only printed the result, so a wrong branch went unnoticed.
      """
      cond = Tensor([True, False])
      x = Tensor([2, 3], mstype.float32)
      y = Tensor([1, 2], mstype.float32)
      output = select_func(cond, x, y)
      # [True, False] takes x[0] = 2 and y[1] = 2; returning x alone ([2, 3])
      # or y alone ([1, 2]) would fail this check.
      assert np.allclose(output.asnumpy(), [2, 2])
  # Not interpret 'Tensor'.
  @ms_function
  def select_func2(cond, x, y):
      # Unlike select_func, the two checks are separate `if` statements. A
      # tuple/list cond first sets output = y, but the second if/else always
      # reassigns output (F.select for a Tensor cond, x otherwise), so the y
      # assignment never reaches the return. Presumably deliberate: this case
      # differs from select_func only in this statement layout.
      if isinstance(cond, (tuple, list)):
          output = y
      if isinstance(cond, Tensor):
          output = F.select(cond, x, y)
      else:
          output = x
      return output
  def test_select_func2():
      """Check select_func2's Tensor branch: F.select picks x where cond is True.

      The original only printed the result, so a wrong branch went unnoticed.
      """
      cond = Tensor([True, False])
      x = Tensor([2, 3], mstype.float32)
      y = Tensor([1, 2], mstype.float32)
      output = select_func2(cond, x, y)
      # cond is a Tensor, so the F.select branch runs: x[0] = 2, y[1] = 2.
      assert np.allclose(output.asnumpy(), [2, 2])
  # NameError: name 'Tensor' is not defined.
  # NOTE(review): the line above presumably records the error this case used
  # to raise; test_slice_func is not skipped, so confirm the note is current.
  @ms_function
  def slice_func(a, b):
      # Slice assignment inside the graph: write b into a[1:3, ::] (indices 1
      # and 2 of the first axis, everything along the second), return a.
      a[1:3, ::] = b
      return a
  def test_slice_func():
      """Check graph-mode slice assignment against the numpy equivalent.

      The original only printed the result, so a wrong assignment went
      unnoticed.
      """
      a = Tensor(np.arange(60).reshape(3, 4, 5), dtype=mstype.float32)
      b = Tensor([1], dtype=mstype.float32)
      output = slice_func(a, b)
      # Expected: the same float32 array with a[1:3, ::] set to b's value (1),
      # broadcast over the (2, 4, 5) slice.
      expected = np.arange(60).reshape(3, 4, 5).astype(np.float32)
      expected[1:3, ::] = 1
      assert np.allclose(output.asnumpy(), expected)