You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sparse_ops.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. # coding: utf-8
  2. # Copyright 2020 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. """Operators for sparse operators."""
  17. from ..._checkparam import Validator as validator
  18. from ...common import dtype as mstype
  19. from ..primitive import PrimitiveWithInfer, prim_attr_register
  20. class SparseToDense(PrimitiveWithInfer):
  21. """
  22. Converts a sparse representation into a dense tensor.
  23. Inputs:
  24. - **indices** (Tensor) - The indices of sparse representation.
  25. - **values** (Tensor) - Values corresponding to each row of indices.
  26. - **dense_shape** (tuple) - An int tuple which specifies the shape of dense tensor.
  27. Returns:
  28. Tensor, the shape of tensor is `dense_shape`.
  29. Examples:
  30. >>> indices = Tensor([[0, 1], [1, 2]])
  31. >>> values = Tensor([1, 2], dtype=ms.float32)
  32. >>> dense_shape = (3, 4)
  33. >>> out = P.SparseToDense()(indices, values, dense_shape)
  34. """
  35. @prim_attr_register
  36. def __init__(self):
  37. """Initialize index_select"""
  38. self.init_prim_io_names(inputs=['indices', 'values', 'dense_shape'], outputs=['output'])
  39. def __infer__(self, indices, values, dense_shape):
  40. validator.check_subclass("indices", indices['dtype'], mstype.tensor, self.name)
  41. validator.check_subclass("values", values['dtype'], mstype.tensor, self.name)
  42. out = {'shape': dense_shape['value'],
  43. 'dtype': values['dtype'],
  44. 'value': None}
  45. return out