You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_print_op.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. from mindspore import Tensor
  18. import mindspore.nn as nn
  19. from mindspore.ops import operations as P
  20. import mindspore.context as context
  class PrintNetOneInput(nn.Cell):
      """Cell that hands its single input to the Print operator and returns it."""

      def __init__(self):
          super().__init__()
          self.op = P.Print()

      def construct(self, x):
          """Pass ``x`` to the Print operator, then return ``x`` unchanged."""
          self.op(x)
          return x
  class PrintNetTwoInputs(nn.Cell):
      """Cell that hands two inputs to one Print call and returns the first."""

      def __init__(self):
          super().__init__()
          self.op = P.Print()

      def construct(self, x, y):
          """Pass ``x`` and ``y`` together to the Print operator; return ``x``."""
          self.op(x, y)
          return x
  class PrintNetIndex(nn.Cell):
      """Cell that hands one indexed element of its input to the Print operator."""

      def __init__(self):
          super().__init__()
          self.op = P.Print()

      def construct(self, x):
          """Pass ``x[0][0][6][3]`` to the Print operator; return ``x`` itself."""
          element = x[0][0][6][3]
          self.op(element)
          return x
  def print_testcase(nptype):
      """Run the three tensor-only Print cells in graph mode on GPU.

      Args:
          nptype: NumPy dtype that both test arrays are cast to before they are
              wrapped in Tensors.
      """
      # Large input: 6 * 3 * 34 * 34 == 20808 consecutive values.
      large_array = np.arange(6 * 3 * 34 * 34).reshape((6, 3, 34, 34)).astype(nptype)
      # Marker value small enough to be stored as int8; this is the element
      # that PrintNetIndex reads back.
      large_array[0, 0, 6, 3] = 125
      # Small input: 3x3 matrix holding 0..8.
      small_array = np.arange(9).reshape((3, 3)).astype(nptype)

      large_tensor = Tensor(large_array)
      small_tensor = Tensor(small_array)

      # Graph mode on the GPU backend.
      context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
      one_input_net = PrintNetOneInput()
      two_inputs_net = PrintNetTwoInputs()
      index_net = PrintNetIndex()

      one_input_net(large_tensor)
      two_inputs_net(large_tensor, small_tensor)
      index_net(large_tensor)
  class PrintNetString(nn.Cell):
      """Cell that calls the Print operator with mixed string and tensor arguments."""

      def __init__(self):
          super().__init__()
          self.op = P.Print()

      def construct(self, x, y):
          """Issue four Print calls mixing strings and tensors; return ``x``."""
          # String followed by a tensor.
          self.op("The first Tensor is", x)
          self.op("The second Tensor is", y)
          # Strings only.
          self.op("This line only prints string", "Another line")
          # Tensors surrounded by strings.
          self.op("The first Tensor is", x, y, "is the second Tensor")
          return x
  def print_testcase_string(nptype):
      """Run the string-and-tensor Print cell in graph mode on GPU.

      Args:
          nptype: NumPy dtype that both test arrays are cast to before they are
              wrapped in Tensors.
      """
      ones_vector = np.ones(18).astype(nptype)
      small_matrix = np.arange(9).reshape((3, 3)).astype(nptype)
      first_tensor = Tensor(ones_vector)
      second_tensor = Tensor(small_matrix)

      # Graph mode on the GPU backend.
      context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
      string_net = PrintNetString()
      string_net(first_tensor, second_tensor)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_bool():
      """Run the shared Print test case with boolean tensors."""
      # np.bool_ is used because the bare ``np.bool`` alias was deprecated in
      # NumPy 1.20 and removed in 1.24, where accessing it raises
      # AttributeError. np.bool_ is available in every NumPy release and
      # yields the same boolean dtype.
      print_testcase(np.bool_)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_int8():
      """Run the shared Print test case with int8 tensors."""
      dtype = np.int8
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_int16():
      """Run the shared Print test case with int16 tensors."""
      dtype = np.int16
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_int32():
      """Run the shared Print test case with int32 tensors."""
      dtype = np.int32
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_int64():
      """Run the shared Print test case with int64 tensors."""
      dtype = np.int64
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_uint8():
      """Run the shared Print test case with uint8 tensors."""
      dtype = np.uint8
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_uint16():
      """Run the shared Print test case with uint16 tensors."""
      dtype = np.uint16
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_uint32():
      """Run the shared Print test case with uint32 tensors."""
      dtype = np.uint32
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_uint64():
      """Run the shared Print test case with uint64 tensors."""
      dtype = np.uint64
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_float16():
      """Run the shared Print test case with float16 tensors."""
      dtype = np.float16
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_float32():
      """Run the shared Print test case with float32 tensors."""
      dtype = np.float32
      print_testcase(dtype)
  @pytest.mark.level0
  @pytest.mark.platform_x86_gpu_training
  @pytest.mark.env_onecard
  def test_print_string():
      """Run the string-and-tensor Print test case with float32 tensors."""
      dtype = np.float32
      print_testcase_string(dtype)