You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

convert_utils.h 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
  17. #define MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
  18. #include <limits>
  19. #include <memory>
  20. #include <utility>
  21. #include <stack>
  22. #include <string>
  23. #include <vector>
  24. #include "utils/hash_map.h"
  25. #include "utils/hash_set.h"
  26. #include "utils/convert_utils_base.h"
  27. #include "utils/any.h"
  28. #include "base/base_ref.h"
  29. #include "base/core_ops.h"
  30. #include "base/base.h"
  31. #include "ir/anf.h"
  32. #include "ir/func_graph.h"
  33. namespace mindspore {
namespace tensor {
// Forward declaration only: the full Tensor definition lives elsewhere
// (NOTE(review): presumably ir/tensor.h — confirm include path).
class Tensor;
using TensorPtr = std::shared_ptr<Tensor>;
}  // namespace tensor
// Convert the value held by `in` to a bool written through `out`.
// NOTE(review): the bool return presumably signals conversion success —
// confirm against the definition in convert_utils.cc.
bool BaseRefToBool(const BaseRef &in, bool *out);
// Extract an int64_t from `v` into `*value`; return value presumably
// signals success (see convert_utils.cc).
bool BaseRefToInt(const ValuePtr &v, int64_t *value);
// Convert the value held by `in` to a bool written through `out`;
// return value presumably signals success (see convert_utils.cc).
bool ValueToBool(const ValuePtr &in, bool *out);
  41. // Isomorphism
  42. struct PairHasher {
  43. template <class T1, class T2>
  44. std::size_t operator()(const std::pair<T1, T2> &p) const {
  45. auto h1 = std::hash<T1>{}(p.first);
  46. auto h2 = std::hash<T2>{}(p.second);
  47. return h1 ^ h2;
  48. }
  49. };
// Result state for a FuncGraph-pair equivalence query.
// NOTE(review): kPending presumably marks a pair whose comparison is still
// in progress (cycle guard during recursive matching) — confirm in the
// Isomorphic implementation. Kept as an unscoped enum: callers use the
// kNotEquiv/kEquiv/kPending names unqualified.
enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 };
// Memoized equivalence verdicts, keyed by graph pair (hashed via PairHasher).
using FuncGraphPairMapEquiv = mindspore::HashMap<std::pair<FuncGraphPtr, FuncGraphPtr>, EquivState, PairHasher>;
// Correspondence between matched nodes of the two graphs.
using NodeMapEquiv = mindspore::HashMap<AnfNodePtr, AnfNodePtr>;
// Test whether g1 and g2 are isomorphic, recording per-pair verdicts in
// *equiv_func_graph and the node correspondence in *equiv_node.
// (Definition lives in the matching .cc file.)
bool Isomorphic(const FuncGraphPtr &g1, const FuncGraphPtr &g2, FuncGraphPairMapEquiv *equiv_func_graph,
                NodeMapEquiv *equiv_node);
// Build a Tensor from a scalar value.
// NOTE(review): resulting shape/dtype mapping is defined in convert_utils.cc —
// confirm before depending on a 0-dim vs 1-element layout.
tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar);
  56. template <typename T>
  57. std::vector<T> TensorValueToVector(const tensor::TensorPtr &tensor) {
  58. MS_EXCEPTION_IF_NULL(tensor);
  59. std::vector<T> value;
  60. auto element_size = tensor->data().size();
  61. auto *data = static_cast<T *>(tensor->data_c());
  62. for (auto i = 0; i < element_size; i++) {
  63. value.push_back(data[i]);
  64. }
  65. return value;
  66. }
// Collect the Tensor objects contained in `value` into *tensors.
// NOTE(review): the set of accepted Value kinds (single tensor, tuples, ...)
// is defined in convert_utils.cc — confirm before relying on nesting.
void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors);
// Count the values inside a ValueTuple (see convert_utils.cc for whether
// nested tuples are counted recursively).
size_t CountValueNum(const ValueTuplePtr &value_tuple);
// sparse_attr_map converts CNode{kPrimSparseGetAttr, SparseTensor}
// to CNode{kPrimTupleGetItem, SparseTensor, int64_t(index)}, used
// in backend common optimization pass: sparse_process.cc
// The mapped integer is the attribute's tuple-element index.
// NOTE(review): a non-inline namespace-scope const in a header is duplicated
// in every including TU; consider `inline` if the build uses C++17 — confirm.
const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {{prim::kPrimCSRTensorGetIndptr->name(), 0},
                                                                  {prim::kPrimCSRTensorGetIndices->name(), 1},
                                                                  {prim::kPrimCSRTensorGetValues->name(), 2},
                                                                  {prim::kPrimCSRTensorGetDenseShape->name(), 3}};
// make_sparse_set records all make_sparse primitives, and tries to replace
// make_sparse to make_tuple, used in backend common optimization pass:
// sparse_process.cc
// Membership is tested by primitive name string.
const mindspore::HashSet<std::string> make_sparse_set = {{prim::kPrimMakeCSRTensor->name()}};
// sparse_op_set records all sparse_compute operators, which takes sparsetensor
// and (possibly) dense tensors, used in backend common optimization pass:
// sparse_process.cc
// Membership is tested by primitive name string.
const mindspore::HashSet<std::string> sparse_op_set = {{prim::kPrimSparseTensorDenseMatmul->name()},
                                                       {prim::kPrimCSRDenseMul->name()},
                                                       {prim::kPrimCSRReduceSum->name()},
                                                       {prim::kPrimCSRMV->name()},
                                                       {prim::kPrimCSRMul->name()}};
// Return true when the node is a custom CSR (compressed sparse row) operator.
// NOTE(review): parameter type is AnfNodePtr despite the `cnode` name —
// definition in the matching .cc file determines whether non-CNodes are valid.
bool IsCustomCSROP(const AnfNodePtr &cnode);
  89. } // namespace mindspore
  90. #endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_