You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_util.h 3.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_PREDICT_TENSOR_UTIL_H
  17. #define MINDSPORE_PREDICT_TENSOR_UTIL_H
  18. #include <cmath>
  19. #include <unordered_map>
  20. #include <memory>
  21. #include <utility>
  22. #include <string>
  23. #include <vector>
  24. #include "schema/inner/model_generated.h"
  25. #include "src/common/log_adapter.h"
  26. #include "ir/dtype/type_id.h"
  27. namespace mindspore {
  28. namespace lite {
  29. using schema::CNodeT;
  30. using schema::Format;
  31. using schema::FusedBatchNormT;
  32. using schema::MetaGraphT;
  33. using schema::QuantParamT;
  34. using schema::TensorT;
  35. using schema::Format::Format_NCHW;
  36. using schema::Format::Format_NHWC;
  37. using STATUS = int;
  38. std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor);
  39. size_t GetElementSize(const TensorT &tensor);
  40. size_t GetElementSize(const TypeId &dataType);
  41. size_t GetShapeSize(const TensorT &tensor);
  42. size_t GetShapeSize(const std::vector<int32_t> &shape);
  43. std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &);
  44. size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx);
  45. std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam);
  46. std::unique_ptr<schema::QuantParamT> CopyQuantParamArrayT(
  47. const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray);
  48. using MSGraphDefTPtr = std::shared_ptr<schema::MetaGraphT>;
  49. enum Category { CONST = 0, GRAPH_INPUT = 1, OP_OUTPUT = 2, TF_CONST = 3 };
  50. class TensorCache {
  51. public:
  52. TensorCache() {}
  53. ~TensorCache() { tensors.clear(); }
  54. int AddTensor(const std::string &name, TensorT *tensor, int Category) {
  55. index++;
  56. if (Category == CONST || Category == TF_CONST || Category == GRAPH_INPUT) {
  57. tensor->refCount = 1;
  58. tensor->nodeType = schema::NodeType_ValueNode;
  59. } else {
  60. tensor->nodeType = schema::NodeType_Parameter;
  61. }
  62. tensors.push_back(tensor);
  63. if (Category == GRAPH_INPUT) {
  64. graphInputs.push_back(index);
  65. }
  66. if (Category == GRAPH_INPUT || Category == OP_OUTPUT || Category == TF_CONST) {
  67. UpdateTensorIndex(name, index);
  68. }
  69. return index;
  70. }
  71. // find the name index
  72. int FindTensor(const std::string &name) {
  73. auto iter = tensorIndex.find(name);
  74. if (iter != tensorIndex.end()) {
  75. return iter->second;
  76. }
  77. return -1;
  78. }
  79. void UpdateTensorIndex(const std::string &name, int index) {
  80. auto iter = tensorIndex.find(name);
  81. if (iter != tensorIndex.end()) {
  82. tensorIndex[name] = index;
  83. } else {
  84. tensorIndex.insert(make_pair(name, index));
  85. }
  86. }
  87. // return allTensors
  88. const std::vector<TensorT *> &GetCachedTensor() const { return tensors; }
  89. const std::vector<int> &GetGraphInputs() const { return graphInputs; }
  90. private:
  91. std::vector<TensorT *> tensors;
  92. std::unordered_map<std::string, int> tensorIndex;
  93. std::vector<int> graphInputs;
  94. int index = -1;
  95. };
  96. } // namespace lite
  97. } // namespace mindspore
  98. #endif // MINDSPORE_PREDICT_TENSOR_UTIL_H