You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

convert_utils.h 3.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
  17. #define MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_
  18. #include <limits>
  19. #include <memory>
  20. #include <utility>
  21. #include <stack>
  22. #include <string>
  23. #include <vector>
  24. #include "utils/hash_map.h"
  25. #include "utils/hash_set.h"
  26. #include "utils/convert_utils_base.h"
  27. #include "utils/any.h"
  28. #include "base/base_ref.h"
  29. #include "base/core_ops.h"
  30. #include "base/base.h"
  31. #include "ir/anf.h"
  32. #include "ir/func_graph.h"
  33. namespace mindspore {
  34. namespace tensor {
  35. class Tensor;
  36. using TensorPtr = std::shared_ptr<Tensor>;
  37. } // namespace tensor
  38. bool BaseRefToBool(const BaseRef &in, bool *out);
  39. bool BaseRefToInt(const ValuePtr &v, int64_t *value);
  40. bool ValueToBool(const ValuePtr &in, bool *out);
  // Isomorphism helpers.

  // Hash functor for std::pair, used as the hasher of FuncGraphPairMapEquiv.
  //
  // The two element hashes are mixed with the boost::hash_combine formula rather
  // than a plain XOR. XOR is symmetric, so (a, b) and (b, a) would always collide,
  // and every (x, x) pair would hash to 0. Both are common when comparing graph pairs.
  struct PairHasher {
    template <class T1, class T2>
    std::size_t operator()(const std::pair<T1, T2> &p) const {
      std::size_t seed = std::hash<T1>{}(p.first);
      // 0x9e3779b9 is the 32-bit golden-ratio constant used by boost::hash_combine;
      // the shifts spread the bits of `seed` so the result depends on element order.
      // Unsigned wrap-around here is well-defined and intended.
      seed ^= std::hash<T2>{}(p.second) + static_cast<std::size_t>(0x9e3779b9U) + (seed << 6) + (seed >> 2);
      return seed;
    }
  };
  50. enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 };
  51. using FuncGraphPairMapEquiv = mindspore::HashMap<std::pair<FuncGraphPtr, FuncGraphPtr>, EquivState, PairHasher>;
  52. using NodeMapEquiv = mindspore::HashMap<AnfNodePtr, AnfNodePtr>;
  53. bool Isomorphic(const FuncGraphPtr &g1, const FuncGraphPtr &g2, FuncGraphPairMapEquiv *equiv_func_graph,
  54. NodeMapEquiv *equiv_node);
  55. tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar);
  56. template <typename T>
  57. std::vector<T> TensorValueToVector(const tensor::TensorPtr &tensor) {
  58. MS_EXCEPTION_IF_NULL(tensor);
  59. std::vector<T> value;
  60. auto element_size = tensor->data().size();
  61. auto *data = static_cast<T *>(tensor->data_c());
  62. for (auto i = 0; i < element_size; i++) {
  63. value.push_back(data[i]);
  64. }
  65. return value;
  66. }
  67. void TensorValueToTensor(const ValuePtr &value, std::vector<tensor::TensorPtr> *tensors);
  68. size_t CountValueNum(const ValueTuplePtr &value_tuple);
  69. // sparse_attr_map converts CNode{kPrimSparseGetAttr, SparseTensor}
  70. // to CNode{kPrimTupleGetItem, SparseTensor, int64_t(index)}, used
  71. // in backend common optimization pass: sparse_process.cc
  72. const mindspore::HashMap<std::string, int64_t> sparse_attr_map = {{prim::kPrimCSRTensorGetIndptr->name(), 0},
  73. {prim::kPrimCSRTensorGetIndices->name(), 1},
  74. {prim::kPrimCSRTensorGetValues->name(), 2},
  75. {prim::kPrimCSRTensorGetDenseShape->name(), 3}};
  76. // make_sparse_set records all make_sparse primitives, and tries to replace
  77. // make_sparse to make_tuple, used in backend common optimization pass:
  78. // sparse_process.cc
  79. const mindspore::HashSet<std::string> make_sparse_set = {
  80. {prim::kPrimMakeCSRTensor->name()}, {prim::kPrimMakeSparseTensor->name()}, {prim::kPrimMakeRowTensor->name()}};
  81. } // namespace mindspore
  82. #endif // MINDSPORE_CCSRC_UTILS_CONVERT_UTILS_H_