You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_csr.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """smoke tests for CSR operations"""
  16. import pytest
  17. import numpy as np
  18. from mindspore import Tensor, CSRTensor, ms_function
  19. from mindspore.common import dtype as mstype
  20. from mindspore import nn, context
# Module-level side effect: set MindSpore's global execution mode to GRAPH_MODE
# at import time, so the tests below run with graph mode as the default.
context.set_context(mode=context.GRAPH_MODE)
  22. def compare_csr(csr1, csr2):
  23. assert isinstance(csr1, CSRTensor)
  24. assert isinstance(csr2, CSRTensor)
  25. assert (csr1.indptr.asnumpy() == csr2.indptr.asnumpy()).all()
  26. assert (csr1.indices.asnumpy() == csr2.indices.asnumpy()).all()
  27. assert (csr1.values.asnumpy() == csr2.values.asnumpy()).all()
  28. assert csr1.shape == csr2.shape
  29. @pytest.mark.level0
  30. @pytest.mark.platform_arm_ascend_training
  31. @pytest.mark.platform_x86_ascend_training
  32. @pytest.mark.platform_x86_gpu_training
  33. @pytest.mark.platform_x86_cpu
  34. @pytest.mark.env_onecard
  35. def test_make_csr():
  36. """
  37. Feature: Test CSRTensor Constructor in Graph and PyNative.
  38. Description: Test CSRTensor(indptr, indices, values, shape) and CSRTensor(CSRTensor)
  39. Expectation: Success.
  40. """
  41. indptr = Tensor([0, 1, 2])
  42. indices = Tensor([0, 1])
  43. values = Tensor([1, 2], dtype=mstype.float32)
  44. shape = (2, 6)
  45. def test_pynative():
  46. return CSRTensor(indptr, indices, values, shape)
  47. test_graph = ms_function(test_pynative)
  48. csr1 = test_pynative()
  49. csr2 = test_graph()
  50. compare_csr(csr1, csr2)
  51. csr3 = CSRTensor(csr_tensor=csr2)
  52. compare_csr(csr3, csr2)
  53. @pytest.mark.level0
  54. @pytest.mark.platform_arm_ascend_training
  55. @pytest.mark.platform_x86_ascend_training
  56. @pytest.mark.platform_x86_gpu_training
  57. @pytest.mark.platform_x86_cpu
  58. @pytest.mark.env_onecard
  59. def test_csr_attr():
  60. """
  61. Feature: Test CSRTensor GetAttr in Graph and PyNative.
  62. Description: Test CSRTensor.indptr, CSRTensor.indices, CSRTensor.values, CSRTensor.shape.
  63. Expectation: Success.
  64. """
  65. indptr = Tensor([0, 1, 2])
  66. indices = Tensor([0, 1])
  67. values = Tensor([1, 2], dtype=mstype.float32)
  68. shape = (2, 6)
  69. def test_pynative():
  70. csr = CSRTensor(indptr, indices, values, shape)
  71. return csr.indptr, csr.indices, csr.values, csr.shape
  72. test_graph = ms_function(test_pynative)
  73. csr1_tuple = test_pynative()
  74. csr2_tuple = test_graph()
  75. csr1 = CSRTensor(*csr1_tuple)
  76. csr2 = CSRTensor(*csr2_tuple)
  77. compare_csr(csr1, csr2)
  78. @pytest.mark.level0
  79. @pytest.mark.platform_arm_ascend_training
  80. @pytest.mark.platform_x86_ascend_training
  81. @pytest.mark.platform_x86_gpu_training
  82. @pytest.mark.platform_x86_cpu
  83. @pytest.mark.env_onecard
  84. def test_csr_tensor_in_while():
  85. """
  86. Feature: Test CSRTensor in while loop.
  87. Description: Test CSRTensor computation in while loop.
  88. Expectation: Success.
  89. """
  90. class CSRTensorValuesDouble(nn.Cell):
  91. def construct(self, x):
  92. indptr = x.indptr
  93. indices = x.indices
  94. values = x.values * 2
  95. shape = x.shape
  96. return CSRTensor(indptr, indices, values, shape)
  97. class CSRTensorValuesAdd2(nn.Cell):
  98. def construct(self, x):
  99. indptr = x.indptr
  100. indices = x.indices
  101. values = x.values + 2
  102. shape = x.shape
  103. return CSRTensor(indptr, indices, values, shape)
  104. class CSRTensorWithControlWhile(nn.Cell):
  105. def __init__(self, shape):
  106. super().__init__()
  107. self.op1 = CSRTensorValuesDouble()
  108. self.op2 = CSRTensorValuesAdd2()
  109. self.shape = shape
  110. @ms_function
  111. def construct(self, a, b, indptr, indices, values):
  112. x = CSRTensor(indptr, indices, values, self.shape)
  113. x = self.op2(x)
  114. while a > b:
  115. x = self.op1(x)
  116. b = b + 1
  117. return x
  118. a = Tensor(3, mstype.int32)
  119. b = Tensor(0, mstype.int32)
  120. indptr = Tensor([0, 1, 2])
  121. indices = Tensor([0, 1])
  122. values = Tensor([1, 2], dtype=mstype.float32)
  123. shape = (2, 6)
  124. net = CSRTensorWithControlWhile(shape)
  125. out = net(a, b, indptr, indices, values)
  126. assert np.allclose(out.indptr.asnumpy(), indptr.asnumpy(), .0, .0)
  127. assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
  128. assert np.allclose((values.asnumpy() + 2) * 8, out.values.asnumpy(), .0, .0)
  129. assert shape == out.shape