You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or number, can include dashes ('-') and can be up to 35 characters long.

graph.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "src/graph.h"
  17. #include <map>
  18. #include <algorithm>
  19. #include <memory>
  20. #include <utility>
  21. #include "schema/ms_generated.h"
  22. #include "common/graph_util.h"
  23. #include "common/mslog.h"
  24. #include "include/errorcode.h"
  25. #include "src/graph_execution.h"
  26. namespace mindspore {
  27. namespace predict {
  28. static const uint32_t G_MAX_OP_COUNT = 10000;
  29. Graph *Graph::CreateFromBuf(const char *buf, size_t size, const Context &ctx) {
  30. if (buf == nullptr) {
  31. MS_LOGE("the input buffer is nullptr");
  32. return nullptr;
  33. }
  34. flatbuffers::Verifier verify((const uint8_t *)buf, size);
  35. if (!VerifyGraphDefBuffer(verify)) {
  36. MS_LOGE("the buffer is invalid and fail to create graph");
  37. return nullptr;
  38. }
  39. auto graphDef = GetGraphDef(buf);
  40. std::unique_ptr<Graph> graph(new (std::nothrow) Graph());
  41. if (graph == nullptr) {
  42. MS_LOGE("graph malloc fail");
  43. return nullptr;
  44. }
  45. auto ret = graph->Build(*graphDef, ctx);
  46. if (ret != RET_OK) {
  47. MS_LOGE("build graph fail");
  48. return nullptr;
  49. }
  50. return graph.release();
  51. }
  52. Graph::Graph() = default;
  53. Graph::~Graph() {
  54. for (auto &subgraph : subgraphs) {
  55. delete subgraph;
  56. }
  57. subgraphs.clear();
  58. }
  59. int Graph::Build(const GraphDef &graphDef, const Context &ctx) {
  60. MS_ASSERT(graphDef.subgraphs() != nullptr);
  61. for (size_t i = 0; i < graphDef.subgraphs()->size(); i++) {
  62. MS_ASSERT(graphDef.subgraphs()->GetAs<SubGraphDef>(i) != nullptr);
  63. SubGraph *subGraph = SubGraph::CreateSubGraph(*(graphDef.subgraphs()->GetAs<SubGraphDef>(i)), ctx);
  64. if (subGraph == nullptr) {
  65. MS_LOGE("converter subgraph failed");
  66. return RET_ERROR;
  67. }
  68. subgraphs.push_back(subGraph);
  69. auto subDepends = subGraph->GetDepends();
  70. depends.insert(subDepends.begin(), subDepends.end());
  71. }
  72. auto iter = depends.begin();
  73. while (iter != depends.end()) {
  74. if (iter->second.empty()) {
  75. readyQue.push_back(iter->first);
  76. iter = depends.erase(iter);
  77. } else {
  78. iter++;
  79. }
  80. }
  81. return RET_OK;
  82. }
  83. std::vector<Tensor *> Graph::GetInputs() {
  84. MS_ASSERT(subgraphs.front() != nullptr);
  85. return subgraphs.front()->GetInputs();
  86. }
  87. std::vector<Tensor *> Graph::GetOutputs() {
  88. MS_ASSERT(subgraphs.back() != nullptr);
  89. return subgraphs.back()->GetOutputs();
  90. }
  91. std::map<NODE_ID, std::vector<Tensor *>> &Graph::GetOutputsMap() {
  92. MS_ASSERT(subgraphs.back() != nullptr);
  93. return subgraphs.back()->GetOutputsMap();
  94. }
  95. void Graph::FreeAllTensors() {
  96. for (auto iter : subgraphs) {
  97. iter->FreeAllTensors();
  98. }
  99. }
  100. std::vector<SubGraph *> *Graph::Subgraphs() { return &subgraphs; }
  101. SubGraph::SubGraph() = default;
  102. SubGraph::~SubGraph() {
  103. for (auto iter = nodes.begin(); iter != nodes.end();) {
  104. if (iter->second != nullptr) {
  105. delete iter->second;
  106. }
  107. iter = nodes.erase(iter);
  108. }
  109. nodes.clear();
  110. for (auto &allTensor : allTensors) {
  111. if (allTensor != nullptr) {
  112. delete allTensor;
  113. }
  114. }
  115. allTensors.clear();
  116. }
  117. SubGraph *SubGraph::CreateSubGraph(const SubGraphDef &subGraphDef, const Context &ctx) {
  118. std::unique_ptr<SubGraph> subGraph(new (std::nothrow) SubGraph());
  119. if (subGraph == nullptr) {
  120. MS_LOGE("subGraph malloc fail");
  121. return nullptr;
  122. }
  123. auto ret = subGraph->Build(subGraphDef, ctx);
  124. if (ret != RET_OK) {
  125. MS_LOGE("subGraph Build fail");
  126. return nullptr;
  127. }
  128. return subGraph.release();
  129. }
  130. int SubGraph::Build(const SubGraphDef &subGraphDef, const Context &ctx) {
  131. int ret;
  132. MS_ASSERT(subGraphDef.inputIndex() != nullptr);
  133. ret = ConverterIndex(*(subGraphDef.inputIndex()), &inputIndices);
  134. if (ret != RET_OK) {
  135. MS_LOGE("ConverterIndex failed: %d", ret);
  136. return ret;
  137. }
  138. MS_LOGD("converter inputIndex succ");
  139. MS_ASSERT(subGraphDef.outputIndex() != nullptr);
  140. ret = ConverterIndex(*(subGraphDef.outputIndex()), &outputIndices);
  141. if (ret != RET_OK) {
  142. MS_LOGE("ConverterIndex failed: %d", ret);
  143. return ret;
  144. }
  145. MS_LOGD("converter outputIndex succ");
  146. MS_ASSERT(subGraphDef.allTensors() != nullptr);
  147. ret = ConverterAllTensor(*(subGraphDef.allTensors()));
  148. if (ret != RET_OK) {
  149. MS_LOGE("ConverterAllTensor failed: %d", ret);
  150. return ret;
  151. }
  152. MS_LOGD("converter AllTensor succ");
  153. MS_ASSERT(subGraphDef.nodes() != nullptr);
  154. ret = ConverterNodes(*(subGraphDef.nodes()), ctx);
  155. if (ret != RET_OK) {
  156. MS_LOGE("ConverterNodes failed: %d", ret);
  157. return ret;
  158. }
  159. MS_LOGD("converter nodes succ");
  160. ret = ConverterEdges(subGraphDef);
  161. if (ret != RET_OK) {
  162. MS_LOGE("ConverterEdges failed: %d", ret);
  163. return ret;
  164. }
  165. MS_LOGD("converter edges succ");
  166. ret = InitOutputsMap();
  167. if (ret != RET_OK) {
  168. MS_LOGE("InitOutputsMap failed: %d", ret);
  169. return ret;
  170. }
  171. MS_LOGD("init outputs map succ");
  172. MS_LOGD("build graph succ");
  173. return RET_OK;
  174. }
  175. int SubGraph::ConverterIndex(const flatbuffers::Vector<uint32_t> &srcIndex, std::vector<uint32_t> *dstIndex) {
  176. if (dstIndex == nullptr) {
  177. MS_LOGE("input dstIndex is nullptr");
  178. return RET_PARAM_INVALID;
  179. }
  180. dstIndex->resize(srcIndex.size());
  181. std::copy(srcIndex.begin(), srcIndex.end(), dstIndex->begin());
  182. return RET_OK;
  183. }
  184. int SubGraph::ConverterAllTensor(const flatbuffers::Vector<flatbuffers::Offset<TensorDef>> &srcTensors) {
  185. uint32_t tensorsSize = srcTensors.size();
  186. allTensors.clear();
  187. allTensors.reserve(tensorsSize);
  188. for (uint32_t i = 0; i < tensorsSize; i++) {
  189. auto tensorDef = srcTensors.GetAs<TensorDef>(i);
  190. if (tensorDef == nullptr) {
  191. MS_LOGE("%ud th tensordef is null", i);
  192. return RET_ERROR;
  193. }
  194. auto tensor = Tensor::CopyFromTensorDef(*tensorDef);
  195. if (tensor == nullptr) {
  196. return RET_ERROR;
  197. }
  198. allTensors.push_back(tensor);
  199. }
  200. return RET_OK;
  201. }
  202. int SubGraph::ConverterNodes(const flatbuffers::Vector<flatbuffers::Offset<NodeDef>> &nodeDefs, const Context &ctx) {
  203. uint32_t opCount = nodeDefs.size();
  204. // for dfx
  205. if (opCount > G_MAX_OP_COUNT) {
  206. MS_LOGE("opCount(%u) bigger than maxOpCount(%u)", opCount, G_MAX_OP_COUNT);
  207. return RET_ERROR;
  208. }
  209. nodes.clear();
  210. for (uint32_t i = 0; i < opCount; i++) {
  211. auto nodeDef = nodeDefs.GetAs<NodeDef>(i);
  212. MS_ASSERT(nodeDef != nullptr);
  213. auto node = std::unique_ptr<Node>(new (std::nothrow) Node(nodeDef));
  214. if (node == nullptr) {
  215. MS_LOGE("new node failed");
  216. return RET_NULL_PTR;
  217. }
  218. node->SetTensors(*nodeDef, allTensors);
  219. auto ret = node->InitOp(*(nodeDef->opDef()), ctx);
  220. if (ret != RET_OK) {
  221. MS_LOGE("node (%s) InitOP failed. ret:%d", node->ID().c_str(), ret);
  222. return ret;
  223. }
  224. auto nodeId = node->ID();
  225. nodes[nodeId] = node.release();
  226. MS_LOGD("add node succ, id:%s", nodeId.c_str());
  227. }
  228. return RET_OK;
  229. }
  230. int SubGraph::ConverterEdges(const SubGraphDef &subGraphDef) {
  231. auto opGraph = OpGraph::Build(subGraphDef);
  232. if (opGraph == nullptr) {
  233. MS_LOGE("opGraph Build fail");
  234. return RET_ERROR;
  235. }
  236. for (auto nodeIter : nodes) {
  237. auto node = opGraph->GetNode(nodeIter.first);
  238. if (node == nullptr) {
  239. MS_LOGI("node %s not found", nodeIter.first.c_str());
  240. continue;
  241. }
  242. for (const auto &edge : node->GetAllInEdge()) {
  243. MS_ASSERT(nodeIter.second != nullptr);
  244. nodeIter.second->AddInEdge(GetNode(edge));
  245. }
  246. for (const auto &edge : node->GetAllOutEdge()) {
  247. MS_ASSERT(nodeIter.second != nullptr);
  248. nodeIter.second->AddOutEdge(GetNode(edge));
  249. }
  250. }
  251. delete opGraph;
  252. return RET_OK;
  253. }
  254. int SubGraph::InitOutputsMap() {
  255. if (nodes.empty()) {
  256. MS_LOGE("nodes are empty");
  257. return RET_ERROR;
  258. }
  259. for (auto node : nodes) {
  260. NODE_ID realNodeName = node.second->ID();
  261. MS_ASSERT(node.second != nullptr);
  262. if (node.second->GetAllOutEdges().empty()) {
  263. auto nodeType = node.second->Type();
  264. if (nodeType == "Nhwc2Nchw" || nodeType == "Nchw2Nhwc") {
  265. auto dependNode = *(this->GetDepends().at(this->GetNode(realNodeName)).begin());
  266. realNodeName = dependNode->ID();
  267. }
  268. this->outputsMap.emplace(
  269. std::pair<NODE_ID, std::vector<Tensor *>>(realNodeName, node.second->GetOutputTensors()));
  270. }
  271. }
  272. return RET_OK;
  273. }
  274. std::unordered_map<Node *, std::unordered_set<Node *>> SubGraph::GetDepends() {
  275. std::unordered_map<Node *, std::unordered_set<Node *>> depends;
  276. for (auto nodeIter : nodes) {
  277. MS_ASSERT(nodeIter.second != nullptr);
  278. depends[nodeIter.second] = nodeIter.second->GetAllInEdges();
  279. }
  280. return depends;
  281. }
  282. Node *SubGraph::GetNode(const NODE_ID &id) {
  283. auto node = nodes.find(id);
  284. if (node == nodes.end()) {
  285. return nullptr;
  286. }
  287. return node->second;
  288. }
  289. std::vector<Tensor *> SubGraph::GetInputs() {
  290. std::vector<Tensor *> inputTensor;
  291. inputTensor.resize(inputIndices.size());
  292. std::transform(inputIndices.begin(), inputIndices.end(), inputTensor.begin(),
  293. [this](int i) { return this->allTensors[i]; });
  294. return inputTensor;
  295. }
  296. std::vector<Tensor *> SubGraph::GetOutputs() {
  297. std::vector<Tensor *> outputTensor;
  298. outputTensor.resize(outputIndices.size());
  299. std::transform(outputIndices.begin(), outputIndices.end(), outputTensor.begin(),
  300. [this](int i) { return this->allTensors[i]; });
  301. return outputTensor;
  302. }
  303. std::map<NODE_ID, std::vector<Tensor *>> &SubGraph::GetOutputsMap() { return outputsMap; }
  304. void SubGraph::FreeAllTensors() {
  305. for (auto &allTensor : allTensors) {
  306. if (allTensor != nullptr) {
  307. auto refcount = allTensor->RefCount();
  308. if (refcount != MSConst_WEIGHT_REFCOUNT) {
  309. allTensor->DefRef(refcount);
  310. allTensor->FreeData();
  311. }
  312. }
  313. }
  314. }
  315. const std::vector<uint32_t> *SubGraph::GetInputIndices() const { return &inputIndices; }
  316. const std::vector<uint32_t> *SubGraph::GetOutputIndices() const { return &outputIndices; }
  317. bool SubGraph::IsInputIndex(uint32_t i) {
  318. auto iter = std::find(inputIndices.begin(), inputIndices.end(), i);
  319. return !(iter == inputIndices.end());
  320. }
  321. bool SubGraph::IsOutputIndex(uint32_t i) {
  322. auto iter = std::find(outputIndices.begin(), outputIndices.end(), i);
  323. return !(iter == outputIndices.end());
  324. }
  325. } // namespace predict
  326. } // namespace mindspore