You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

sparse_ops.py 5.7 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # coding: utf-8
  2. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. """Operators for sparse operators."""
  17. from ..._checkparam import Validator as validator
  18. from ...common import dtype as mstype
  19. from ..primitive import PrimitiveWithInfer, prim_attr_register
  class SparseToDense(PrimitiveWithInfer):
      """
      Converts a sparse representation into a dense tensor.

      Inputs:
          - **indices** (Tensor) - A 2-D tensor of shape :math:`(n, r)` holding the indices of the sparse
            representation, where :math:`n` is the number of stored values and :math:`r` is the length of
            `dense_shape`.
          - **values** (Tensor) - A 1-D tensor of shape :math:`(n,)` holding the value for each row of
            `indices`.
          - **dense_shape** (tuple) - A constant tuple of positive ints which specifies the shape of the
            dense tensor.

      Outputs:
          Tensor, the shape of tensor is `dense_shape` and its data type is the same as `values`.

      Raises:
          TypeError: If `indices` or `values` is not a Tensor.
          TypeError: If `dense_shape` is not a constant tuple or list.
          ValueError: If any element of `dense_shape` is not a positive int.
          ValueError: If `indices` is not 2-D, or its second dimension differs from the length of
              `dense_shape`.
          ValueError: If `values` is not 1-D, or its length differs from the first dimension of `indices`.

      Supported Platforms:
          ``CPU``

      Examples:
          >>> indices = Tensor([[0, 1], [1, 2]])
          >>> values = Tensor([1, 2], dtype=ms.float32)
          >>> dense_shape = (3, 4)
          >>> out = ops.SparseToDense()(indices, values, dense_shape)
          >>> print(out)
          [[0. 1. 0. 0.]
           [0. 0. 2. 0.]
           [0. 0. 0. 0.]]
      """

      @prim_attr_register
      def __init__(self):
          """Initialize SparseToDense"""
          self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])

      def __infer__(self, indices, values, dense_shape):
          """
          Infer the output of SparseToDense.

          Args:
              indices (dict): Abstract info ('dtype', 'shape', ...) of the `indices` input.
              values (dict): Abstract info ('dtype', 'shape', ...) of the `values` input.
              dense_shape (dict): Abstract info of `dense_shape`; its 'value' becomes the output shape.

          Returns:
              dict, with the output 'shape' (= `dense_shape`), the 'dtype' of `values`, and 'value' None.
          """
          validator.check_subclass("indices", indices['dtype'], mstype.tensor, self.name)
          validator.check_subclass("values", values['dtype'], mstype.tensor, self.name)

          # The output shape is taken verbatim from `dense_shape`, so it must be a compile-time constant made
          # of positive ints. Without this check a non-constant input (value None) or a malformed tuple was
          # accepted and yielded an invalid output shape.
          shape = dense_shape['value']
          validator.check_value_type("dense_shape", shape, [tuple, list], self.name)
          for dim in shape:
              # bool is a subclass of int, so it has to be rejected explicitly.
              if isinstance(dim, bool) or not isinstance(dim, int) or dim <= 0:
                  raise ValueError(f"For '{self.name}', all elements of 'dense_shape' must be positive ints, "
                                   f"but got {shape}.")

          indices_shape = indices['shape']
          values_shape = values['shape']
          if len(indices_shape) != 2:
              raise ValueError(f"For '{self.name}', 'indices' must be a 2-D tensor, "
                               f"but got shape {indices_shape}.")
          if len(values_shape) != 1:
              raise ValueError(f"For '{self.name}', 'values' must be a 1-D tensor, "
                               f"but got shape {values_shape}.")
          # Dimensions are compared only when known; a negative entry is assumed to mark a dimension that is
          # not known at compile time (NOTE(review): confirm the dynamic-shape placeholder convention).
          if indices_shape[1] >= 0 and indices_shape[1] != len(shape):
              raise ValueError(f"For '{self.name}', the second dimension of 'indices' must equal the length of "
                               f"'dense_shape', but got 'indices' shape {indices_shape} and 'dense_shape' {shape}.")
          if indices_shape[0] >= 0 and values_shape[0] >= 0 and indices_shape[0] != values_shape[0]:
              raise ValueError(f"For '{self.name}', the length of 'values' must equal the first dimension of "
                               f"'indices', but got 'indices' shape {indices_shape} and 'values' shape "
                               f"{values_shape}.")

          out = {'shape': tuple(shape),
                 'dtype': values['dtype'],
                 'value': None}
          return out
  class SparseTensorDenseMatmul(PrimitiveWithInfer):
      """
      Multiplies a sparse tensor (of rank 2) "A" by a dense tensor "B".

      The shape of the sparse tensor is :math:`(N, C)` and the shape of the dense tensor is :math:`(C, M)`,
      so the shape of the output tensor is :math:`(N, M)`. The output data type is the same as "values".

      Args:
          adjoint_st (bool): If True, the sparse tensor is transposed before multiplication. Default: False.
          adjoint_dt (bool): If True, the dense tensor is transposed before multiplication. Default: False.

      Inputs:
          - **indices** (Tensor) - A 2-D tensor of shape :math:`(n, 2)` holding the indices of the sparse
            representation. Supports int32 and int64.
          - **values** (Tensor) - A 1-D tensor of shape :math:`(n,)` holding the value for each row of
            `indices`.
          - **dense_shape** (tuple) - A constant tuple of 2 positive ints which specifies the shape of the
            sparse tensor, :math:`(N, C)`. If `adjoint_st` is True, it must be :math:`(N, C)` after
            transposition, i.e. :math:`(C, N)` as given.
          - **dense** (Tensor) - The dense matrix, of shape :math:`(C, M)`. If `adjoint_dt` is True, its shape
            must be :math:`(C, M)` after transposition, i.e. :math:`(M, C)` as given.

      Outputs:
          Tensor, the shape of tensor is :math:`(N, M)`. The output data type is the same as "values".

      Raises:
          TypeError: If `adjoint_st` or `adjoint_dt` is not a bool.
          TypeError: If the dtype of `indices` is neither int32 nor int64.
          TypeError: If the dtype of `values` or `dense` is not bool, uint8-64, int8-64 or float16-64, or the
              two dtypes differ.
          TypeError: If `dense_shape` is not a constant tuple or list.
          ValueError: If `indices` is not of shape :math:`(n, 2)` or `values` is not of shape :math:`(n,)`.
          ValueError: If any element of `dense_shape` is not a positive int.
          ValueError: If the length of `dense_shape` or the rank of `dense` is not equal to 2.
          ValueError: If the inner dimensions of the (possibly transposed) operands do not match.

      Supported Platforms:
          ``CPU``

      Examples:
          >>> indices = Tensor([[0, 1], [1, 2]], dtype=ms.int32)
          >>> values = Tensor([1, 2], dtype=ms.float32)
          >>> dense_shape = (3, 4)
          >>> dense = Tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype=ms.float32)
          >>> out = ops.SparseTensorDenseMatmul()(indices, values, dense_shape, dense)
          >>> print(out)
          [[2. 2.]
           [6. 6.]
           [0. 0.]]
      """

      @prim_attr_register
      def __init__(self, adjoint_st=False, adjoint_dt=False):
          """Initialize SparseTensorDenseMatmul"""
          # Validate the flags before they are stored or used.
          validator.check_value_type("adjoint_st", adjoint_st, [bool], self.name)
          validator.check_value_type("adjoint_dt", adjoint_dt, [bool], self.name)
          self.adjoint_st = adjoint_st
          self.adjoint_dt = adjoint_dt
          self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape', 'dense'],
                                  outputs=['output'])
          self.add_prim_attr('adjoint_st', self.adjoint_st)
          self.add_prim_attr('adjoint_dt', self.adjoint_dt)

      def __infer__(self, indices, values, dense_shape, dense):
          """
          Infer the output of SparseTensorDenseMatmul.

          Args:
              indices (dict): Abstract info ('dtype', 'shape', ...) of the `indices` input.
              values (dict): Abstract info ('dtype', 'shape', ...) of the `values` input.
              dense_shape (dict): Abstract info of `dense_shape`; its 'value' is the sparse tensor's shape.
              dense (dict): Abstract info ('dtype', 'shape', ...) of the dense matrix.

          Returns:
              dict, with output 'shape' :math:`(N, M)`, the 'dtype' of `values`, and 'value' None.
          """
          validator.check_tensor_dtype_valid('indices', indices['dtype'], [mstype.int32, mstype.int64], self.name)
          valid_types = mstype.number_type + (mstype.bool_,)
          args = {'values': values['dtype'], 'dense': dense['dtype']}
          validator.check_tensors_dtypes_same_and_valid(args, valid_types, self.name)

          # A rank-2 sparse tensor stores one (row, col) index pair per value. Dimensions are compared only
          # when known; a negative entry is assumed to mark a dimension that is not known at compile time
          # (NOTE(review): confirm the dynamic-shape placeholder convention).
          indices_shape = indices['shape']
          values_shape = values['shape']
          if len(indices_shape) != 2 or (indices_shape[1] >= 0 and indices_shape[1] != 2):
              raise ValueError(f"For '{self.name}', 'indices' must be a 2-D tensor of shape (n, 2), "
                               f"but got shape {indices_shape}.")
          if len(values_shape) != 1 or (values_shape[0] >= 0 and indices_shape[0] >= 0
                                        and values_shape[0] != indices_shape[0]):
              raise ValueError(f"For '{self.name}', 'values' must be a 1-D tensor whose length equals the first "
                               f"dimension of 'indices', but got 'indices' shape {indices_shape} and 'values' "
                               f"shape {values_shape}.")

          # `dense_shape` must be a compile-time constant; previously a non-constant input (value None)
          # crashed on len(None) instead of producing a clear error.
          sparse_shape = dense_shape['value']
          validator.check_value_type("dense_shape", sparse_shape, [tuple, list], self.name)
          for dim in sparse_shape:
              # bool is a subclass of int, so it has to be rejected explicitly.
              if isinstance(dim, bool) or not isinstance(dim, int) or dim <= 0:
                  raise ValueError(f"For '{self.name}', all elements of 'dense_shape' must be positive ints, "
                                   f"but got {sparse_shape}.")
          dense_shape_value = dense['shape']
          if len(sparse_shape) != 2 or len(dense_shape_value) != 2:
              raise ValueError('SparseTensorDenseMatmul SparseTensor, DenseTensor should have the same dimension size '
                               + f'and equal to 2, while SparseTensor size is ({len(sparse_shape)}) and DenseTensor size is '
                               + f'({len(dense_shape_value)}).')

          # The adjoint flags transpose an operand before the product, so the logical (N, C) and (C, M)
          # shapes are the given shapes reversed when the corresponding flag is set. The previous code
          # ignored both flags and reported a wrong output shape for transposed operands.
          a_shape = tuple(sparse_shape[::-1]) if self.adjoint_st else tuple(sparse_shape)
          b_shape = tuple(dense_shape_value[::-1]) if self.adjoint_dt else tuple(dense_shape_value)
          if b_shape[0] >= 0 and a_shape[1] != b_shape[0]:
              raise ValueError(f"For '{self.name}', the inner dimensions of the sparse tensor and the dense tensor "
                               f"must be equal after applying 'adjoint_st' and 'adjoint_dt', but got sparse "
                               f"tensor shape {a_shape} and dense tensor shape {b_shape}.")

          out = {'shape': (a_shape[0], b_shape[1]),
                 'dtype': values['dtype'],
                 'value': None}
          return out