You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_execution.cc 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "src/graph_execution.h"
  17. #include <utility>
  18. #include <vector>
  19. #include <memory>
  20. namespace mindspore {
  21. namespace predict {
  22. GraphExecution::GraphExecution(const Context &ctx) : graph(nullptr), _ctx(ctx) {}
  23. GraphExecution::GraphExecution(const Context &ctx, Graph *staticGraph) : _ctx(ctx) {
  24. graph = staticGraph;
  25. if (graph != nullptr) {
  26. depends = graph->depends;
  27. readyQue = graph->readyQue;
  28. outputTensors = graph->GetOutputs();
  29. inputTensors = graph->GetInputs();
  30. }
  31. }
// Defaulted: the executor never deletes `graph` or the cached tensor pointers,
// so there is no cleanup to perform here.
GraphExecution::~GraphExecution() = default;
  33. int GraphExecution::TransInputDataToNc4hw4(const Tensor &src, Tensor *dst) {
  34. MS_ASSERT(dst != nullptr);
  35. if (dst->GetData() == nullptr) {
  36. auto ret = dst->MallocData(nullptr, MSConst_WEIGHT_REFCOUNT);
  37. if (ret != RET_OK) {
  38. MS_LOGE("Malloc inputTensors failed: %d", ret);
  39. return ret;
  40. }
  41. }
  42. auto ret = NchwToNc4hw4(&src, dst);
  43. if (ret != RET_OK) {
  44. MS_LOGE("NchwToNc4hw4 failed");
  45. return ret;
  46. }
  47. return RET_OK;
  48. }
  49. int GraphExecution::SetInputTensors(const std::vector<Tensor *> &inputs) {
  50. size_t num = inputs.size();
  51. if (num != inputTensors.size()) {
  52. MS_LOGE("input num %zu != model input num %zu", num, inputTensors.size());
  53. return RET_INPUT_TENSOR_ERROR;
  54. }
  55. for (size_t i = 0; i < num; i++) {
  56. MS_ASSERT(inputs[i] != nullptr);
  57. // The input Tensor desc must be equivalent with the model tensor
  58. if (inputs[i]->GetData() == nullptr) {
  59. MS_LOGE("input tensor data is null!");
  60. return RET_INPUT_TENSOR_ERROR;
  61. }
  62. if (inputTensors[i] == nullptr) {
  63. MS_LOGE("inputTensors[%zu] is nullptr", i);
  64. return RET_ERROR;
  65. }
  66. if (!inputs[i]->CompareShape(*inputTensors[i])) {
  67. MS_LOGE("tensor shape in graph and executor are different!");
  68. return RET_INPUT_TENSOR_ERROR;
  69. }
  70. if (inputs[i]->GetDataType() != inputTensors[i]->GetDataType()) {
  71. MS_LOGE("tensor datatype in graph and executor are different!");
  72. return RET_INPUT_TENSOR_ERROR;
  73. }
  74. if (inputs[i]->GetFormat() != Format_NCHW) {
  75. MS_LOGE("input format not support. only nchw is supported now");
  76. return RET_INPUT_TENSOR_ERROR;
  77. }
  78. if (inputs[i]->GetFormat() == inputTensors[i]->GetFormat()) {
  79. auto data = inputs[i]->GetData();
  80. if (data == nullptr) {
  81. MS_LOGE("data of input tensor is null!");
  82. return RET_INPUT_TENSOR_ERROR;
  83. }
  84. inputTensors[i]->SetData(data);
  85. } else if (inputTensors[i]->GetFormat() == Format_NC4HW4) {
  86. auto ret = TransInputDataToNc4hw4(*inputs[i], inputTensors[i]);
  87. if (ret != RET_OK) {
  88. MS_LOGE("TransInputDataToNc4hw4 failed");
  89. return ret;
  90. }
  91. } else {
  92. MS_LOGE("graphDef inputTensors format is invalid: %d", inputTensors[i]->GetFormat());
  93. return RET_ERROR;
  94. }
  95. }
  96. return RET_OK;
  97. }
  98. int GraphExecution::MallocOutput() {
  99. for (auto tensor : outputTensors) {
  100. auto ret = tensor->MallocData();
  101. if (ret != RET_OK) {
  102. MS_LOGE("malloc output data failed");
  103. return RET_ERROR;
  104. }
  105. }
  106. return RET_OK;
  107. }
  108. void GraphExecution::FreeTensors(std::vector<Tensor *> *tensors) {
  109. for (auto &tensor : (*tensors)) {
  110. delete tensor;
  111. }
  112. tensors->clear();
  113. }
  114. void GraphExecution::FreeOutputMap(std::map<NODE_ID, std::vector<Tensor *>> *map) {
  115. MS_ASSERT(map != nullptr);
  116. for (auto &m : *map) {
  117. FreeTensors(&(m.second));
  118. }
  119. map->clear();
  120. }
  121. int GraphExecution::CopyOutputTensors(const std::vector<Tensor *> &refOutputs, std::vector<Tensor *> *outputs) {
  122. for (auto tensor : refOutputs) {
  123. if (tensor == nullptr) {
  124. MS_LOGE("tensor in refOutputs is nullptr");
  125. return RET_INPUT_TENSOR_ERROR;
  126. }
  127. std::unique_ptr<Tensor> t(new Tensor(*tensor));
  128. if (t == nullptr) {
  129. MS_LOGE("new Tensor failed.");
  130. if (outputs != nullptr) {
  131. FreeTensors(outputs);
  132. }
  133. return RET_ERROR;
  134. }
  135. if (tensor->GetFormat() == Format_NC4HW4) {
  136. t->SetFormat(Format_NCHW);
  137. auto ret = t->MallocData();
  138. if (ret != RET_OK) {
  139. MS_LOGE("malloc data failed.")
  140. FreeTensors(outputs);
  141. return ret;
  142. }
  143. ret = Nc4hw4ToNchw(tensor, t.get());
  144. if (ret != RET_OK) {
  145. MS_LOGE("Nc4hw4ToNchw failed");
  146. return ret;
  147. }
  148. tensor->FreeData();
  149. } else {
  150. t->SetData(tensor->GetData());
  151. tensor->SetData(nullptr);
  152. }
  153. outputs->push_back(t.release());
  154. }
  155. return RET_OK;
  156. }
  157. std::map<NODE_ID, std::vector<Tensor *>> GraphExecution::GetAllOutput() {
  158. std::map<NODE_ID, std::vector<Tensor *>> outputs{};
  159. for (auto &outputNode : graph->GetOutputsMap()) {
  160. std::vector<Tensor *> outputNodeTensors{};
  161. auto ret = this->CopyOutputTensors(outputNode.second, &outputNodeTensors);
  162. if (ret != RET_OK) {
  163. MS_LOGE("copy output failed.");
  164. FreeOutputMap(&outputs);
  165. return outputs;
  166. }
  167. outputs.emplace(std::pair<NODE_ID, std::vector<Tensor *>>(outputNode.first, outputNodeTensors));
  168. }
  169. return outputs;
  170. }
  171. std::vector<Tensor *> GraphExecution::GetOutput(const NODE_ID &nodeName) {
  172. std::vector<Tensor *> outputNodeTensors{};
  173. auto iter = graph->GetOutputsMap().find(nodeName);
  174. if (iter == graph->GetOutputsMap().end()) {
  175. MS_LOGE("node name is not in output.");
  176. return outputNodeTensors;
  177. }
  178. auto ret = this->CopyOutputTensors(iter->second, &outputNodeTensors);
  179. if (ret != RET_OK) {
  180. MS_LOGE("copy output failed.");
  181. }
  182. return outputNodeTensors;
  183. }
  184. std::vector<Tensor *> GraphExecution::GetInput() {
  185. std::vector<Tensor *> inputs{};
  186. for (auto refInput : graph->GetInputs()) {
  187. if (refInput == nullptr) {
  188. MS_LOGE("tensor from graph->GetInputs() is nullptr");
  189. return inputs;
  190. }
  191. std::unique_ptr<Tensor> t(new Tensor(refInput->GetDataType(), refInput->GetDims(), Format_NCHW, nullptr));
  192. if (t == nullptr) {
  193. MS_LOGE("new Tensor failed.")
  194. FreeTensors(&inputs);
  195. return inputs;
  196. }
  197. inputs.push_back(t.release());
  198. }
  199. return inputs;
  200. }
  201. void GraphExecution::ResetInputData() {
  202. for (auto tensor : inputTensors) {
  203. if (tensor == nullptr) {
  204. MS_LOGW("tensor in inputTensors is nullptr");
  205. continue;
  206. }
  207. if (tensor->GetFormat() == Format_NC4HW4) {
  208. if (tensor->GetData() != nullptr) {
  209. free(tensor->GetData());
  210. tensor->SetData(nullptr);
  211. }
  212. continue;
  213. }
  214. tensor->SetData(nullptr);
  215. }
  216. }
  217. void GraphExecution::FreeAllTensors() { graph->FreeAllTensors(); }
  218. int GraphExecution::Run(const std::vector<Tensor *> &inputs) {
  219. if (inputs.empty()) {
  220. MS_LOGE("input is empty");
  221. return RET_ERROR;
  222. }
  223. int ret;
  224. if (readyQue.empty()) {
  225. MS_LOGE("readyQue is empty");
  226. return RET_ERROR;
  227. }
  228. ret = SetInputTensors(inputs);
  229. if (ret != RET_OK) {
  230. MS_LOGE("SetInputTensors failed: %d", ret);
  231. ResetInputData();
  232. return ret;
  233. }
  234. ret = MallocOutput();
  235. if (ret != RET_OK) {
  236. MS_LOGE("MallocOutput failed: %d", ret);
  237. ResetInputData();
  238. return ret;
  239. }
  240. while (!readyQue.empty()) {
  241. auto *node = readyQue.front();
  242. readyQue.pop_front();
  243. ret = node->Run(_ctx);
  244. if (ret != RET_OK) {
  245. MS_LOGE("node (%s) failed to run op (%s). error code:%d", node->ID().c_str(), node->Type().c_str(), ret);
  246. ResetInputData();
  247. FreeAllTensors();
  248. return ret;
  249. }
  250. for (auto outNode : node->GetAllOutEdges()) {
  251. auto nodeDepend = depends.find(outNode);
  252. nodeDepend->second.erase(node);
  253. if (nodeDepend->second.empty()) {
  254. depends.erase(nodeDepend);
  255. readyQue.push_back(outNode);
  256. }
  257. }
  258. }
  259. ResetInputData();
  260. return RET_OK;
  261. }
  262. } // namespace predict
  263. } // namespace mindspore