You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, may include dashes ('-'), and can be up to 35 characters long.

test_coo.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """smoke tests for COO operations"""
  16. import pytest
  17. import numpy as np
  18. from mindspore import Tensor, COOTensor, ms_function, nn, context
  19. from mindspore.common import dtype as mstype
# Set MindSpore's global execution mode to GRAPH_MODE for every test in this module.
context.set_context(mode=context.GRAPH_MODE)
  21. def compare_coo(coo1, coo2):
  22. assert isinstance(coo1, COOTensor)
  23. assert isinstance(coo2, COOTensor)
  24. assert (coo1.indices.asnumpy() == coo2.indices.asnumpy()).all()
  25. assert (coo1.values.asnumpy() == coo2.values.asnumpy()).all()
  26. assert coo1.shape == coo2.shape
  27. @pytest.mark.level0
  28. @pytest.mark.platform_arm_ascend_training
  29. @pytest.mark.platform_x86_ascend_training
  30. @pytest.mark.platform_x86_gpu_training
  31. @pytest.mark.env_onecard
  32. def test_make_coo():
  33. """
  34. Feature: Test COOTensor Constructor in Graph and PyNative.
  35. Description: Test COOTensor(indices, values, shape) and COOTensor(COOTensor)
  36. Expectation: Success.
  37. """
  38. indices = Tensor([[0, 1], [1, 2]])
  39. values = Tensor([1, 2], dtype=mstype.float32)
  40. dense_shape = (3, 4)
  41. def test_pynative():
  42. return COOTensor(indices, values, dense_shape)
  43. test_graph = ms_function(test_pynative)
  44. coo1 = test_pynative()
  45. coo2 = test_graph()
  46. compare_coo(coo1, coo2)
  47. coo3 = COOTensor(coo_tensor=coo2)
  48. compare_coo(coo3, coo2)
  49. @pytest.mark.level0
  50. @pytest.mark.platform_arm_ascend_training
  51. @pytest.mark.platform_x86_ascend_training
  52. @pytest.mark.platform_x86_gpu_training
  53. @pytest.mark.env_onecard
  54. def test_coo_tensor_in_while():
  55. """
  56. Feature: Test COOTensor in while loop.
  57. Description: Test COOTensor computation in while loop.
  58. Expectation: Success.
  59. """
  60. class COOTensorWithControlWhile(nn.Cell):
  61. def __init__(self, shape):
  62. super().__init__()
  63. self.shape = shape
  64. @ms_function
  65. def construct(self, a, b, indices, values):
  66. x = COOTensor(indices, values, self.shape)
  67. while a > b:
  68. x = COOTensor(indices, values, self.shape)
  69. b = b + 1
  70. return x
  71. a = Tensor(3, mstype.int32)
  72. b = Tensor(0, mstype.int32)
  73. indices = Tensor([[0, 1], [1, 2]])
  74. values = Tensor([1, 2], dtype=mstype.float32)
  75. shape = (3, 4)
  76. net = COOTensorWithControlWhile(shape)
  77. out = net(a, b, indices, values)
  78. assert np.allclose(out.indices.asnumpy(), indices.asnumpy(), .0, .0)
  79. assert np.allclose(out.values.asnumpy(), values.asnumpy(), .0, .0)
  80. assert out.shape == shape