You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_util.cc 23 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/common/graph_util.h"
  17. #include <stdlib.h>
  18. #include <time.h>
  19. #include <utility>
  20. #include <set>
  21. #include "schema/inner/model_generated.h"
  22. #include "tools/common/tensor_util.h"
  23. #include "tools/common/node_util.h"
  24. #include "src/common/log_adapter.h"
  25. #include "src/common/utils.h"
  26. namespace mindspore {
  27. namespace lite {
  28. OpDefCopyer GetSimpleOpCopyer() {
  29. return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
  30. std::unique_ptr<CNodeT> newCNode(new CNodeT);
  31. newCNode->name = inCNode->name;
  32. newCNode->quantType = inCNode->quantType;
  33. newCNode->primitive = std::make_unique<schema::PrimitiveT>();
  34. newCNode->primitive->value.type = inCNode->primitive->value.type;
  35. return newCNode;
  36. };
  37. }
  38. std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
  39. return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
  40. }
  41. std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
  42. std::vector<uint32_t> inputIndexes;
  43. if (inputIndexIdx == -1) {
  44. inputIndexes = node.inputIndex;
  45. } else {
  46. MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
  47. inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
  48. }
  49. std::set<size_t> inputNodeIdx;
  50. for (uint32_t inputIdx : inputIndexes) {
  51. auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
  52. inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
  53. }
  54. std::vector<size_t> ret;
  55. ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
  56. return ret;
  57. }
  58. std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
  59. const int outputIndexIdx) {
  60. return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
  61. }
  62. std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
  63. std::vector<uint32_t> outputIndexes;
  64. if (outputIndexIdx == -1) {
  65. outputIndexes = node.outputIndex;
  66. } else {
  67. MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
  68. outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
  69. }
  70. std::set<size_t> outputNodeIdx;
  71. for (uint32_t outputIdx : outputIndexes) {
  72. auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
  73. outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
  74. }
  75. std::vector<size_t> ret;
  76. ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
  77. return ret;
  78. }
  79. std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  80. std::vector<size_t> preNodeIdx;
  81. for (size_t i = 0; i < graphT.nodes.size(); i++) {
  82. auto &oldNode = graphT.nodes.at(i);
  83. if (oldNode == nullptr) {
  84. continue;
  85. }
  86. auto outputIndexes = oldNode->outputIndex;
  87. if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
  88. preNodeIdx.emplace_back(i);
  89. }
  90. }
  91. return preNodeIdx;
  92. }
  93. std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  94. std::vector<size_t> postNodeIdx;
  95. for (size_t i = 0; i < graphT.nodes.size(); i++) {
  96. auto &oldNode = graphT.nodes.at(i);
  97. if (oldNode == nullptr) {
  98. continue;
  99. }
  100. auto inputIndexes = oldNode->inputIndex;
  101. if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
  102. postNodeIdx.emplace_back(i);
  103. }
  104. }
  105. return postNodeIdx;
  106. }
  107. STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
  108. MS_ASSERT(graphT != nullptr);
  109. MS_ASSERT(node != nullptr);
  110. size_t nodeIdx = 0;
  111. for (size_t i = 0; i < graphT->nodes.size(); i++) {
  112. auto &inNode = graphT->nodes.at(i);
  113. MS_ASSERT(inNode != nullptr);
  114. if (inNode->name == node->name) {
  115. nodeIdx = i;
  116. break;
  117. }
  118. }
  119. auto inputTensorIdxes = node->inputIndex;
  120. auto outputTensorIdxes = node->outputIndex;
  121. if (inputTensorIdxes.empty()) {
  122. MS_LOG(ERROR) << "Node " << node->name.c_str() << "should has no inputs";
  123. return RET_ERROR;
  124. }
  125. if (outputTensorIdxes.size() != 1) {
  126. MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
  127. << "should has 1 output, in fact: " << outputTensorIdxes.size();
  128. return RET_ERROR;
  129. }
  130. auto inDataTensorIdx = inputTensorIdxes.front();
  131. auto outDataTensorIdx = outputTensorIdxes.front();
  132. MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
  133. auto &gOutTensorIdx = graphT->outputIndex;
  134. for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
  135. if (*iter == outDataTensorIdx) {
  136. *iter = inDataTensorIdx;
  137. break;
  138. }
  139. }
  140. // find poseNode
  141. auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
  142. for (auto postNodeIdx : postNodeIdxes) {
  143. MS_ASSERT(graphT->nodes.size() > postNodeIdx);
  144. auto &postNode = graphT->nodes.at(postNodeIdx);
  145. MS_ASSERT(postNode != nullptr);
  146. for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
  147. if (*iter == outDataTensorIdx) {
  148. *iter = inDataTensorIdx;
  149. break;
  150. }
  151. }
  152. }
  153. // whether need to remove weightInputTensores
  154. // remove all node's outputTensors
  155. RemoveTensor(graphT, outputTensorIdxes);
  156. node->inputIndex.clear();
  157. node->outputIndex.clear();
  158. return RET_OK;
  159. }
  160. STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
  161. MS_ASSERT(graph != nullptr);
  162. return IsolateOneWayNode(graph, nodeIdx, removeTensor);
  163. }
  164. STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
  165. MS_ASSERT(graphT != nullptr);
  166. if (graphT->nodes.size() <= nodeIdx) {
  167. MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
  168. return RET_PARAM_INVALID;
  169. }
  170. CNodeT *node = graphT->nodes.at(nodeIdx).get();
  171. auto inputTensorIdxes = node->inputIndex;
  172. auto outputTensorIdxes = node->outputIndex;
  173. auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
  174. if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
  175. MS_LOG(ERROR) << "Only support node who has no more than one input and one output";
  176. return RET_ERROR;
  177. }
  178. if (inputTensorIdxes.empty()) {
  179. MS_LOG(ERROR) << "Error, " << nodeIdx << "th node has no input tensor";
  180. return RET_ERROR;
  181. }
  182. auto inDataTensorIdx = inputTensorIdxes.front();
  183. if (!outputTensorIdxes.empty()) {
  184. auto outDataTensorIdx = outputTensorIdxes.front();
  185. MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
  186. MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
  187. auto &gOutTensorIdx = graphT->outputIndex;
  188. for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
  189. if (*iter == outDataTensorIdx) {
  190. *iter = inDataTensorIdx;
  191. break;
  192. }
  193. }
  194. // find poseNode
  195. auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
  196. for (auto postNodeIdx : postNodeIdxes) {
  197. MS_ASSERT(graphT->nodes.size() > postNodeIdx);
  198. auto &postNode = graphT->nodes.at(postNodeIdx);
  199. MS_ASSERT(postNode != nullptr);
  200. for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
  201. if (*iter == outDataTensorIdx) {
  202. *iter = inDataTensorIdx;
  203. break;
  204. }
  205. }
  206. }
  207. }
  208. if (removeTensor) {
  209. // now all node's outputTensors are useless
  210. // remove all node's outputTensors
  211. auto status = RemoveTensor(graphT, outputTensorIdxes);
  212. if (status != RET_OK) {
  213. MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << "failed";
  214. return RET_ERROR;
  215. }
  216. }
  217. node->inputIndex.clear();
  218. node->outputIndex.clear();
  219. return RET_OK;
  220. }
  221. STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
  222. MS_ASSERT(graphT != nullptr);
  223. MS_ASSERT(node != nullptr);
  224. bool isSubNode = false;
  225. size_t nodeIdx = 0;
  226. for (size_t i = 0; i < graphT->nodes.size(); i++) {
  227. auto &inNode = graphT->nodes.at(i);
  228. if (inNode->name == node->name) {
  229. isSubNode = true;
  230. nodeIdx = i;
  231. break;
  232. }
  233. }
  234. if (!isSubNode) {
  235. MS_LOG(ERROR) << "Node " << node->name.c_str() << "is not in graphT " << graphT->name.c_str();
  236. return RET_PARAM_INVALID;
  237. } else {
  238. return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
  239. }
  240. }
  241. STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
  242. for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
  243. uint32_t deleteIdx = *iter;
  244. if (!forceDelete) {
  245. if (GetRefCount(graphT, deleteIdx) > 1) {
  246. iter++;
  247. continue;
  248. }
  249. }
  250. // update graph input indexes
  251. for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
  252. if (*gInIdx > deleteIdx) {
  253. (*gInIdx)--;
  254. }
  255. }
  256. // update graph output indexes
  257. for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
  258. if (*gOutIdx > deleteIdx) {
  259. (*gOutIdx)--;
  260. }
  261. }
  262. // update nodes indexes
  263. for (auto nodeIter = graphT->nodes.begin(); nodeIter != graphT->nodes.end(); nodeIter++) {
  264. // update nodes input indexes
  265. UpdateNodeIndex((*nodeIter).get(), deleteIdx);
  266. }
  267. // update deleteTensorIdx
  268. for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
  269. if (*selfIt > deleteIdx) {
  270. (*selfIt)--;
  271. }
  272. }
  273. graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
  274. iter = toDeleteTensorIdxes.erase(iter);
  275. }
  276. return RET_OK;
  277. }
  278. STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
  279. for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
  280. if (*inIdxIt == deleteIdx) {
  281. inIdxIt = node->inputIndex.erase(inIdxIt);
  282. } else {
  283. if (*inIdxIt > deleteIdx) {
  284. (*inIdxIt)--;
  285. }
  286. inIdxIt++;
  287. }
  288. }
  289. // update nodes output indexes
  290. for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
  291. if (*outIdxIt == deleteIdx) {
  292. outIdxIt = node->outputIndex.erase(outIdxIt);
  293. } else {
  294. if (*outIdxIt > deleteIdx) {
  295. (*outIdxIt)--;
  296. }
  297. outIdxIt++;
  298. }
  299. }
  300. return RET_OK;
  301. }
  302. STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
  303. InsertPlace place) {
  304. if (nodeIdx >= graphT->nodes.size()) {
  305. MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
  306. return RET_PARAM_INVALID;
  307. }
  308. graphT->allTensors.emplace_back(std::move(tensor));
  309. uint32_t newTensorIdx = graphT->allTensors.size() - 1;
  310. auto node = graphT->nodes.at(nodeIdx).get();
  311. if (place == kBefore) {
  312. node->inputIndex.emplace_back(newTensorIdx);
  313. } else {
  314. node->outputIndex.emplace_back(newTensorIdx);
  315. }
  316. return RET_OK;
  317. }
  318. STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
  319. std::unique_ptr<TensorT> tensor) {
  320. if (nodeIdx >= graphT->nodes.size()) {
  321. MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
  322. return RET_PARAM_INVALID;
  323. }
  324. auto node = graphT->nodes.at(nodeIdx).get();
  325. if (inTensorIdx >= graphT->allTensors.size()) {
  326. MS_LOG(ERROR) << "inTensorIdx out of range: " << nodeIdx;
  327. return RET_PARAM_INVALID;
  328. }
  329. if (!IsContain(node->inputIndex, inTensorIdx)) {
  330. MS_LOG(ERROR) << "inTensorIdx(" << inTensorIdx << ") is not a inputIdx of node(" << nodeIdx << ")";
  331. return RET_PARAM_INVALID;
  332. }
  333. graphT->allTensors.at(inTensorIdx).swap(tensor);
  334. return RET_OK;
  335. }
  336. NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
  337. std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  338. if (existNodeIdx >= graphT->nodes.size()) {
  339. MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
  340. return graphT->nodes.end();
  341. }
  342. auto nodeIter = graphT->nodes.begin() + existNodeIdx;
  343. MS_ASSERT(nodeIter != graphT->nodes.begin());
  344. MS_ASSERT((*nodeIter) != nullptr);
  345. return InsertNode(graphT, nodeIter, place, inoutIndex, std::move(toAddNode), errorCode);
  346. }
  347. NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
  348. std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  349. if (place == kBefore) {
  350. return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
  351. } else if (place == kAfter) {
  352. return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
  353. } else {
  354. MS_LOG(ERROR) << "Invalid InsertPlace : " << place;
  355. return graphT->nodes.end();
  356. }
  357. }
  358. NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
  359. std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  360. auto &existNode = *existNodeIter;
  361. MS_ASSERT(existNode != nullptr);
  362. MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
  363. MS_ASSERT(toAddNodeIn != nullptr);
  364. auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
  365. MS_ASSERT(graphT->allTensors.size() > preTensorIdx);
  366. auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode.get()), inputIndexIdx);
  367. if (preNodeIdxes.empty()) {
  368. auto &preTensor = graphT->allTensors.at(preTensorIdx);
  369. MS_ASSERT(preTensor != nullptr);
  370. auto toAddTensor = CopyTensorDefT(preTensor);
  371. if (toAddTensor == nullptr) {
  372. MS_LOG(ERROR) << "Copy TensorT failed";
  373. *errorCode = RET_NULL_PTR;
  374. return graphT->nodes.end();
  375. }
  376. preTensor->refCount = 0;
  377. preTensor->data.clear();
  378. if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
  379. preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
  380. toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
  381. }
  382. graphT->allTensors.emplace_back(std::move(toAddTensor));
  383. size_t toAddTensorIdx = graphT->allTensors.size() - 1;
  384. auto toAddNode = opDefCopyer(toAddNodeIn.get());
  385. if (toAddNode == nullptr) {
  386. MS_LOG(ERROR) << "copy toAddNodeIn failed";
  387. *errorCode = RET_NULL_PTR;
  388. return graphT->nodes.end();
  389. }
  390. toAddNode->inputIndex.clear();
  391. toAddNode->inputIndex.push_back(preTensorIdx);
  392. toAddNode->outputIndex.clear();
  393. toAddNode->outputIndex.push_back(toAddTensorIdx);
  394. for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
  395. if (*iter == preTensorIdx) {
  396. *iter = toAddTensorIdx;
  397. break;
  398. }
  399. }
  400. existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
  401. existNodeIter++;
  402. } else {
  403. std::vector<std::unique_ptr<CNodeT>> toAddNodes;
  404. for (size_t i = 0; i < preNodeIdxes.size(); i++) {
  405. MS_ASSERT(graphT->nodes.size() > preNodeIdxes.at(i));
  406. auto &preTensor = graphT->allTensors.at(preTensorIdx);
  407. MS_ASSERT(preTensor != nullptr);
  408. auto toAddTensor = CopyTensorDefT(preTensor);
  409. if (toAddTensor == nullptr) {
  410. *errorCode = RET_NULL_PTR;
  411. MS_LOG(ERROR) << "Copy TensorT failed";
  412. return graphT->nodes.end();
  413. }
  414. if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
  415. preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
  416. toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
  417. }
  418. graphT->allTensors.emplace_back(std::move(toAddTensor));
  419. size_t toAddTensorIdx = graphT->allTensors.size() - 1;
  420. auto toAddNode = opDefCopyer(toAddNodeIn.get());
  421. if (toAddNode == nullptr) {
  422. MS_LOG(ERROR) << "copy toAddNodeIn failed";
  423. *errorCode = RET_NULL_PTR;
  424. return graphT->nodes.end();
  425. }
  426. toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
  427. toAddNode->inputIndex.clear();
  428. toAddNode->inputIndex.push_back(preTensorIdx);
  429. toAddNode->outputIndex.clear();
  430. toAddNode->outputIndex.push_back(toAddTensorIdx);
  431. for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
  432. if (*iter == preTensorIdx) {
  433. *iter = toAddTensorIdx;
  434. break;
  435. }
  436. }
  437. toAddNodes.emplace_back(std::move(toAddNode));
  438. }
  439. for (auto &toAddNode : toAddNodes) {
  440. existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
  441. existNodeIter++;
  442. }
  443. }
  444. *errorCode = RET_OK;
  445. return existNodeIter;
  446. }
  447. NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
  448. std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode, OpDefCopyer opDefCopyer) {
  449. auto &existNode = *existNodeIter;
  450. MS_ASSERT(existNode != nullptr);
  451. MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
  452. MS_ASSERT(toAddNodeIn != nullptr);
  453. auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
  454. MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
  455. auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode.get()), outputIndexIdx);
  456. if (postNodeIdxes.empty()) {
  457. auto &postTensor = graphT->allTensors.at(postTensorIdx);
  458. MS_ASSERT(postTensor != nullptr);
  459. auto toAddTensor = CopyTensorDefT(postTensor);
  460. if (toAddTensor == nullptr) {
  461. MS_LOG(ERROR) << "Copy TensorT failed";
  462. *errorCode = RET_NULL_PTR;
  463. return graphT->nodes.end();
  464. }
  465. if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
  466. postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
  467. toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
  468. }
  469. graphT->allTensors.emplace_back(std::move(toAddTensor));
  470. size_t toAddTensorIdx = graphT->allTensors.size() - 1;
  471. auto toAddNode = opDefCopyer(toAddNodeIn.get());
  472. if (toAddNode == nullptr) {
  473. MS_LOG(ERROR) << "copy toAddNodeIn failed";
  474. *errorCode = RET_NULL_PTR;
  475. return graphT->nodes.end();
  476. }
  477. toAddNode->inputIndex.clear();
  478. toAddNode->inputIndex.push_back(postTensorIdx);
  479. toAddNode->outputIndex.clear();
  480. toAddNode->outputIndex.push_back(toAddTensorIdx);
  481. for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
  482. if (*iter == postTensorIdx) {
  483. *iter = toAddTensorIdx;
  484. break;
  485. }
  486. }
  487. existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
  488. existNodeIter++;
  489. } else {
  490. std::vector<std::unique_ptr<schema::CNodeT>> toAddNodes;
  491. int i = 0;
  492. for (size_t postNodeIdx : postNodeIdxes) {
  493. MS_ASSERT(graphT->nodes.size() > postNodeIdx);
  494. auto &postNode = graphT->nodes.at(postNodeIdx);
  495. MS_ASSERT(postNode != nullptr);
  496. auto &postTensor = graphT->allTensors.at(postTensorIdx);
  497. MS_ASSERT(postTensor != nullptr);
  498. // for multioutput,when one outpout as other node input,need add one more node
  499. if (IsContain(graphT->outputIndex, postTensorIdx)) {
  500. auto toAddTensor = CopyTensorDefT(postTensor);
  501. if (toAddTensor == nullptr) {
  502. MS_LOG(ERROR) << "Copy TensorT failed";
  503. *errorCode = RET_NULL_PTR;
  504. return graphT->nodes.end();
  505. }
  506. graphT->allTensors.emplace_back(std::move(toAddTensor));
  507. size_t toAddTensorIdx = graphT->allTensors.size() - 1;
  508. auto toAddNode = opDefCopyer(toAddNodeIn.get());
  509. toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
  510. toAddNode->inputIndex.clear();
  511. toAddNode->inputIndex.push_back(postTensorIdx);
  512. toAddNode->outputIndex.clear();
  513. toAddNode->outputIndex.push_back(toAddTensorIdx);
  514. for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
  515. if (*iter == postTensorIdx) {
  516. *iter = toAddTensorIdx;
  517. break;
  518. }
  519. }
  520. toAddNodes.emplace_back(std::move(toAddNode));
  521. }
  522. auto toAddTensor = CopyTensorDefT(postTensor);
  523. if (toAddTensor == nullptr) {
  524. MS_LOG(ERROR) << "Copy TensorT failed";
  525. *errorCode = RET_NULL_PTR;
  526. return graphT->nodes.end();
  527. }
  528. if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
  529. postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
  530. toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
  531. }
  532. graphT->allTensors.emplace_back(std::move(toAddTensor));
  533. size_t toAddTensorIdx = graphT->allTensors.size() - 1;
  534. auto toAddNode = opDefCopyer(toAddNodeIn.get());
  535. if (toAddNode == nullptr) {
  536. MS_LOG(ERROR) << "copy toAddNodeIn failed";
  537. *errorCode = RET_NULL_PTR;
  538. return graphT->nodes.end();
  539. }
  540. toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
  541. toAddNode->inputIndex.clear();
  542. toAddNode->inputIndex.push_back(postTensorIdx);
  543. toAddNode->outputIndex.clear();
  544. toAddNode->outputIndex.push_back(toAddTensorIdx);
  545. MS_ASSERT(IsContain(postNode->inputIndex, postTensorIdx));
  546. for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
  547. if (*iter == postTensorIdx) {
  548. *iter = toAddTensorIdx;
  549. break;
  550. }
  551. }
  552. toAddNodes.emplace_back(std::move(toAddNode));
  553. }
  554. for (auto &toAddNode : toAddNodes) {
  555. existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
  556. existNodeIter++;
  557. }
  558. }
  559. *errorCode = RET_OK;
  560. return existNodeIter;
  561. }
  562. STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) {
  563. if (modelFile.size() > fileType.size()) {
  564. if (modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
  565. return RET_OK;
  566. } else {
  567. return RET_ERROR;
  568. }
  569. } else {
  570. return RET_ERROR;
  571. }
  572. }
  573. std::string GetModelName(const std::string &modelFile) {
  574. std::string modelName = modelFile;
  575. modelName = modelName.substr(modelName.find_last_of('/') + 1);
  576. modelName = modelName.substr(0, modelName.find_last_of('.'));
  577. srand((unsigned)time(NULL));
  578. modelName = modelName + std::to_string(rand());
  579. return modelName;
  580. }
  581. } // namespace lite
  582. } // namespace mindspore