You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sparse_ops.py 10 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # coding: utf-8
  2. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. """Operators for sparse operators."""
  17. from ..._checkparam import Validator as validator
  18. from ...common import dtype as mstype
  19. from ..primitive import PrimitiveWithInfer, prim_attr_register
  20. class SparseToDense(PrimitiveWithInfer):
  21. """
  22. Converts a sparse representation into a dense tensor.
  23. Inputs:
  24. - **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
  25. Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
  26. - **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
  27. The shape should be :math:`(n,)`.
  28. - **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor,
  29. should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
  30. Returns:
  31. Tensor, converted from sparse tensor. The dtype is same as `values`, and the shape is `sparse_shape`.
  32. Raises:
  33. TypeError: If the dtype of `indices` is neither int32 nor int64.
  34. ValueError: If `sparse_shape`, shape of `indices` and shape of `values` don't meet the parameter description.
  35. Supported Platforms:
  36. ``CPU``
  37. Examples:
  38. >>> indices = Tensor([[0, 1], [1, 2]])
  39. >>> values = Tensor([1, 2], dtype=ms.float32)
  40. >>> sparse_shape = (3, 4)
  41. >>> sparse_to_dense = ops.SparseToDense()
  42. >>> out = sparse_to_dense(indices, values, sparse_shape)
  43. >>> print(out)
  44. [[0 1 0 0]
  45. [0 0 2 0]
  46. [0 0 0 0]]
  47. """
  48. @prim_attr_register
  49. def __init__(self):
  50. """Initialize SparseToDense."""
  51. self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
  52. def __infer__(self, indices, values, sparse_shape):
  53. validator.check_tensor_dtype_valid('indices', indices['dtype'], [mstype.int32, mstype.int64], self.name)
  54. validator.check_tensor_dtype_valid('values', values['dtype'], mstype.number_type + (mstype.bool_,), self.name)
  55. indices_shape = indices['shape']
  56. if len(indices_shape) != 2:
  57. raise ValueError(f"For '{self.name}', the 'indices' must be a 2-D tensor, "
  58. f"but got 'indices' shape: {indices_shape}.")
  59. values_shape = values['shape']
  60. if len(values_shape) != 1 or values_shape[0] != indices_shape[0]:
  61. raise ValueError(f"For '{self.name}', the 'values' must be a 1-D tensor and the first dimension length "
  62. f"must be equal to the first dimension length of 'indices', "
  63. f"but got 'indices' shape: {indices_shape}, 'values' shape: {values_shape}.")
  64. sparse_shape_v = sparse_shape['value']
  65. for i in sparse_shape_v:
  66. if isinstance(i, bool) or not isinstance(i, int) or i <= 0:
  67. raise ValueError(f"For '{self.name}', all elements in 'sparse_shape' must be "
  68. f"positive int number, but got 'sparse_shape': {sparse_shape_v}.")
  69. if len(sparse_shape_v) != indices_shape[1]:
  70. raise ValueError(f"For '{self.name}', the length of 'sparse_shape' should be equal to the second dimension "
  71. f"length of 'indices', but got the second dimension length of 'indices': "
  72. f"{indices_shape[1]}, length of 'sparse_shape': {len(sparse_shape_v)}.")
  73. out = {'shape': sparse_shape['value'],
  74. 'dtype': values['dtype'],
  75. 'value': None}
  76. return out
  77. class SparseTensorDenseMatmul(PrimitiveWithInfer):
  78. """
  79. Multiplies sparse matrix `A` by dense matrix `B`.
  80. The rank of sparse matrix and dense matrix must be equal to `2`.
  81. Args:
  82. adjoint_st (bool): If true, sparse tensor is transposed before multiplication. Default: False.
  83. adjoint_dt (bool): If true, dense tensor is transposed before multiplication. Default: False.
  84. Inputs:
  85. - **indices** (Tensor) - A 2-D Tensor, represents the position of the element in the sparse tensor.
  86. Support int32, int64, each element value should be a non-negative int number. The shape is :math:`(n, 2)`.
  87. - **values** (Tensor) - A 1-D Tensor, represents the value corresponding to the position in the `indices`.
  88. Support float16, float32, float64, int32, int64. The shape should be :math:`(n,)`.
  89. - **sparse_shape** (tuple(int)) - A positive int tuple which specifies the shape of sparse tensor,
  90. should have 2 elements, represent sparse tensor shape is :math:`(N, C)`.
  91. - **dense** (Tensor) - A 2-D Tensor, the dtype is same as `values`.
  92. If `adjoint_st` is False and `adjoint_dt` is False, the shape must be :math:`(C, M)`.
  93. If `adjoint_st` is False and `adjoint_dt` is True, the shape must be :math:`(M, C)`.
  94. If `adjoint_st` is True and `adjoint_dt` is False, the shape must be :math:`(N, M)`.
  95. If `adjoint_st` is True and `adjoint_dt` is True, the shape must be :math:`(M, N)`.
  96. Outputs:
  97. Tensor, the dtype is the same as `values`.
  98. If `adjoint_st` is False, the shape is :math:`(N, M)`.
  99. If `adjoint_st` is True, the shape is :math:`(C, M)`.
  100. Raises:
  101. TypeError: If the type of `adjoint_st` or `adjoint_dt` is not bool, or the dtype of `indices`,
  102. dtype of `values` and dtype of `dense` don't meet the parameter description.
  103. ValueError: If `sparse_shape`, shape of `indices`, shape of `values`,
  104. and shape of `dense` don't meet the parameter description.
  105. Supported Platforms:
  106. ``CPU``
  107. Examples:
  108. >>> indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
  109. >>> values = Tensor([1, 2], dtype=ms.float32)
  110. >>> sparse_shape = (3, 4)
  111. >>> dense = Tensor([[1,1], [2,2], [3,3 ], [4, 4]], dtype=ms.float32)
  112. >>> sparse_dense_matmul = ops.SparseTensorDenseMatmul()
  113. >>> out = sparse_dense_matmul(indices, values, sparse_shape, dense)
  114. >>> print(out)
  115. [[2 2]
  116. [6 6]
  117. [0 0]]
  118. """
  119. @prim_attr_register
  120. def __init__(self, adjoint_st=False, adjoint_dt=False):
  121. """Initialize SparseTensorDenseMatmul"""
  122. self.adjoint_st = adjoint_st
  123. self.adjoint_dt = adjoint_dt
  124. self.init_prim_io_names(inputs=['indices', 'values', 'sparse_shape', 'dense'],
  125. outputs=['output'])
  126. self.add_prim_attr('adjoint_st', self.adjoint_st)
  127. self.add_prim_attr('adjoint_dt', self.adjoint_dt)
  128. validator.check_value_type("adjoint_st", adjoint_st, [bool], self.name)
  129. validator.check_value_type("adjoint_dt", adjoint_dt, [bool], self.name)
  130. def __infer__(self, indices, values, sparse_shape, dense):
  131. validator.check_tensor_dtype_valid('indices', indices['dtype'], [mstype.int32, mstype.int64], self.name)
  132. valid_types = (mstype.float16, mstype.float32, mstype.float64, mstype.int32, mstype.int64)
  133. args = {'values': values['dtype'], 'dense': dense['dtype']}
  134. validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)
  135. indices_shape = indices['shape']
  136. if len(indices_shape) != 2 or indices_shape[1] != 2:
  137. raise ValueError(f"For '{self.name}', the 'indices' must be a 2-D tensor and "
  138. f"the second dimension length must be 2, but got 'indices' shape: {indices_shape}.")
  139. values_shape = values['shape']
  140. if len(values_shape) != 1 or values_shape[0] != indices_shape[0]:
  141. raise ValueError(f"For '{self.name}', the 'values' must be a 1-D tensor and "
  142. f"the first dimension length must be equal to the first dimension length of 'indices', "
  143. f"but got 'indices' shape: {indices_shape}, 'values' shape: {values_shape}.")
  144. a_shape = sparse_shape['value'][::-1] if self.adjoint_st else sparse_shape['value']
  145. b_shape = dense['shape'][::-1] if self.adjoint_dt else dense['shape']
  146. for i in a_shape:
  147. if isinstance(i, bool) or not isinstance(i, int) or i <= 0:
  148. raise ValueError(f"For '{self.name}', all elements in 'sparse_shape' must be "
  149. f"positive int number, but got 'sparse_shape': {a_shape}.")
  150. if len(a_shape) != 2 or len(b_shape) != 2:
  151. raise ValueError(f"For '{self.name}', both the length of 'sparse_shape' and the tensor "
  152. f"rank of 'dense' should be equal to 2, but got the length of "
  153. f"'sparse_shape': {len(a_shape)}, "
  154. f"the tensor rank of 'dense': {len(b_shape)}.")
  155. if a_shape[1] != b_shape[0]:
  156. raise ValueError(f"For '{self.name}', the second dimension length of 'sparse_shape' must be equal to the "
  157. f"first dimension length of 'dense', but got "
  158. f"the tensor shape of 'sparse': {a_shape} and the tensor shape of 'dense': {b_shape}. "
  159. f"Don't meet the condition for matmul")
  160. out_shape = [a_shape[0], b_shape[1]]
  161. out = {'shape': tuple(out_shape),
  162. 'dtype': values['dtype'],
  163. 'value': None}
  164. return out