You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph.h 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PREDICT_SRC_GRAPH_H_
  17. #define PREDICT_SRC_GRAPH_H_
  18. #include <map>
  19. #include <deque>
  20. #include <string>
  21. #include <unordered_map>
  22. #include <unordered_set>
  23. #include <vector>
  24. #include "common/utils.h"
  25. #include "common/graph_util.h"
  26. #include "include/tensor.h"
  27. #include "src/node.h"
  28. #define MSPREDICT_API __attribute__((visibility("default")))
  29. namespace mindspore {
  30. namespace predict {
  31. class SubGraph {
  32. public:
  33. SubGraph();
  34. ~SubGraph();
  35. static SubGraph *CreateSubGraph(const SubGraphDef &subGraphDef, const Context &ctx);
  36. int Build(const SubGraphDef &subGraphDef, const Context &ctx);
  37. bool IsInputIndex(uint32_t i);
  38. bool IsOutputIndex(uint32_t i);
  39. const std::vector<uint32_t> *GetInputIndices() const;
  40. const std::vector<uint32_t> *GetOutputIndices() const;
  41. std::vector<Tensor *> GetInputs();
  42. std::vector<Tensor *> GetOutputs();
  43. std::map<NODE_ID, std::vector<Tensor *>> &GetOutputsMap();
  44. void FreeAllTensors();
  45. Node *GetNode(const NODE_ID &id);
  46. std::unordered_map<Node *, std::unordered_set<Node *>> GetDepends();
  47. private:
  48. int ConverterIndex(const flatbuffers::Vector<uint32_t> &srcIndex, std::vector<uint32_t> *dstIndex);
  49. int ConverterAllTensor(const flatbuffers::Vector<flatbuffers::Offset<TensorDef>> &srcTensors);
  50. int ConverterNodes(const flatbuffers::Vector<flatbuffers::Offset<NodeDef>> &opDefs, const Context &ctx);
  51. int ConverterEdges(const SubGraphDef &subGraphDef);
  52. int InitOutputsMap();
  53. protected:
  54. std::unordered_map<NODE_ID, Node *> nodes;
  55. std::vector<uint32_t> inputIndices;
  56. std::vector<uint32_t> outputIndices;
  57. std::vector<Tensor *> allTensors; // weight + input + output
  58. std::map<NODE_ID, std::vector<Tensor *>> outputsMap;
  59. };
  60. class MSPREDICT_API Graph {
  61. public:
  62. Graph();
  63. ~Graph();
  64. static Graph *CreateFromBuf(const char *buf, size_t size, const Context &ctx);
  65. std::vector<Tensor *> GetInputs();
  66. std::vector<Tensor *> GetOutputs();
  67. std::map<NODE_ID, std::vector<Tensor *>> &GetOutputsMap();
  68. void FreeAllTensors();
  69. int Build(const GraphDef &def, const Context &ctx);
  70. std::vector<SubGraph *> *Subgraphs();
  71. protected:
  72. friend class GraphExecution;
  73. std::vector<SubGraph *> subgraphs;
  74. std::unordered_map<Node *, std::unordered_set<Node *>> depends; // records the dependencies
  75. std::deque<Node *> readyQue; // the nodes which can execute without any dependencies
  76. };
  77. } // namespace predict
  78. } // namespace mindspore
  79. #endif // PREDICT_SRC_GRAPH_H_