You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_utils.h 4.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * Copyright 2019-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_
  17. #define MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_
  18. #include <limits>
  19. #include <memory>
  20. #include <utility>
  21. #include <stack>
  22. #include <string>
  23. #include <vector>
  24. #include "utils/hash_map.h"
  25. #include "utils/hash_set.h"
  26. #include "utils/convert_utils_base.h"
  27. #include "utils/any.h"
  28. #include "base/base_ref.h"
  29. #include "base/core_ops.h"
  30. #include "base/base.h"
  31. #include "ir/anf.h"
  32. #include "ir/func_graph.h"
  33. #include "include/common/visible.h"
  34. namespace mindspore {
  35. namespace tensor {
  36. class Tensor;
  37. } // namespace tensor
  38. COMMON_EXPORT bool BaseRefToBool(const BaseRef &in, bool *out);
  39. COMMON_EXPORT bool BaseRefToInt(const ValuePtr &v, int64_t *value);
  40. COMMON_EXPORT bool ValueToBool(const ValuePtr &in, bool *out);
  41. // Isomorphism
  42. struct PairHasher {
  43. template <class T1, class T2>
  44. std::size_t operator()(const std::pair<T1, T2> &p) const {
  45. auto h1 = std::hash<T1>{}(p.first);
  46. auto h2 = std::hash<T2>{}(p.second);
  47. return h1 ^ h2;
  48. }
  49. };
  50. enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 };
  51. using FuncGraphPairMapEquiv = mindspore::HashMap<std::pair<FuncGraphPtr, FuncGraphPtr>, EquivState, PairHasher>;
  52. using NodeMapEquiv = mindspore::HashMap<AnfNodePtr, AnfNodePtr>;
  53. COMMON_EXPORT bool Isomorphic(const FuncGraphPtr &g1, const FuncGraphPtr &g2, FuncGraphPairMapEquiv *equiv_func_graph,
  54. NodeMapEquiv *equiv_node);
  55. COMMON_EXPORT tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar);
  56. template <typename T>
  57. std::vector<T> TensorValueToVector(const tensor::TensorPtr &tensor) {
  58. MS_EXCEPTION_IF_NULL(tensor);
  59. std::vector<T> value;
  60. auto element_size = tensor->data().size();
  61. auto *data = static_cast<T *>(tensor->data_c());
  62. for (auto i = 0; i < element_size; i++) {
  63. value.push_back(data[i]);
  64. }
  65. return value;
  66. }
  67. COMMON_EXPORT void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors);
  68. COMMON_EXPORT ValuePtr ShallowCopyTensorValue(const ValuePtr &value);
  69. COMMON_EXPORT size_t CountValueNum(const ValueTuplePtr &value_tuple);
  70. // sparse_attr_map converts CNode{kPrimSparseGetAttr, SparseTensor}
  71. // to CNode{kPrimTupleGetItem, SparseTensor, int64_t(index)}, used
  72. // in backend common optimization pass: sparse_process.cc
  73. const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {
  74. {prim::kCSRTensorGetIndptr, 0}, {prim::kCSRTensorGetIndices, 1}, {prim::kCSRTensorGetValues, 2},
  75. {prim::kCSRTensorGetDenseShape, 3}, {prim::kCOOTensorGetIndices, 0}, {prim::kCOOTensorGetValues, 1},
  76. {prim::kCOOTensorGetDenseShapes, 2}};
  77. // make_sparse_set records all make_sparse primitives, and tries to replace
  78. // make_sparse to make_tuple, used in backend common optimization pass:
  79. // sparse_process.cc
  80. const mindspore::HashSet<std::string> make_sparse_set = {{prim::kMakeCSRTensor}, {prim::kMakeCOOTensor}};
  81. // sparse_op_set records all sparse_compute operators, which takes sparsetensor
  82. // and (possibly) dense tensors, used in backend common optimization pass:
  83. // sparse_process.cc
  84. const mindspore::HashSet<std::string> sparse_op_set = {{prim::kSparseTensorDenseMatmul},
  85. {prim::kCSRReduceSum},
  86. {prim::kCSRMV},
  87. {prim::kCSRMul},
  88. {prim::kCSRGather},
  89. {prim::kCSR2COO},
  90. {prim::kCSRDiv}};
  91. COMMON_EXPORT bool IsAKGSparseOP(const AnfNodePtr &cnode);
  92. } // namespace mindspore
  93. #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_UTILS_CONVERT_UTILS_H_