graph_util.cc 32 kB
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/common/graph_util.h"
#include <ctime>
#include <utility>
#include <set>
#include "schema/inner/model_generated.h"
#include "tools/common/tensor_util.h"
#include "tools/common/node_util.h"
#include "src/common/log_adapter.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
OpDefCopyer GetSimpleOpCopyer() {
  return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
    std::unique_ptr<CNodeT> newCNode = std::make_unique<CNodeT>();
    if (newCNode == nullptr) {
      return nullptr;
    }
    newCNode->name = inCNode->name;
    newCNode->quantType = inCNode->quantType;
    newCNode->primitive = std::make_unique<schema::PrimitiveT>();
    newCNode->primitive->value.type = inCNode->primitive->value.type;
    return newCNode;
  };
}
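// GetInputNodeIdx: collect the indices of the nodes that produce the given node's input tensors
// (all inputs when inputIndexIdx is -1, otherwise only the input selected by inputIndexIdx).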
std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx, const int inputIndexIdx) {
  return GetInputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), inputIndexIdx);
}

std::vector<size_t> GetInputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int inputIndexIdx) {
  std::vector<uint32_t> inputIndexes;
  if (inputIndexIdx == -1) {
    inputIndexes = node.inputIndex;
  } else {
    MS_ASSERT(node.inputIndex.size() > inputIndexIdx);
    inputIndexes.emplace_back(node.inputIndex.at(inputIndexIdx));
  }
  std::set<size_t> inputNodeIdx;
  for (uint32_t inputIdx : inputIndexes) {
    auto linkedPreIdx = GetLinkedPreIdx(graphT, inputIdx);
    inputNodeIdx.insert(linkedPreIdx.begin(), linkedPreIdx.end());
  }
  std::vector<size_t> ret;
  ret.insert(ret.end(), inputNodeIdx.begin(), inputNodeIdx.end());
  return ret;
}

std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const size_t &nodeIdx,
                                     const int outputIndexIdx) {
  return GetOutputNodeIdx(graphT, *(graphT.nodes.at(nodeIdx).get()), outputIndexIdx);
}

std::vector<size_t> GetOutputNodeIdx(const schema::MetaGraphT &graphT, const CNodeT &node, const int outputIndexIdx) {
  std::vector<uint32_t> outputIndexes;
  if (outputIndexIdx == -1) {
    outputIndexes = node.outputIndex;
  } else {
    MS_ASSERT(node.outputIndex.size() > outputIndexIdx);
    outputIndexes.emplace_back(node.outputIndex.at(outputIndexIdx));
  }
  std::set<size_t> outputNodeIdx;
  for (uint32_t outputIdx : outputIndexes) {
    auto linkedPostIdx = GetLinkedPostIdx(graphT, outputIdx);
    outputNodeIdx.insert(linkedPostIdx.begin(), linkedPostIdx.end());
  }
  std::vector<size_t> ret;
  ret.insert(ret.end(), outputNodeIdx.begin(), outputNodeIdx.end());
  return ret;
}
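// GetLinkedPreIdx / GetLinkedPostIdx: find the nodes that write to, respectively read from,
// the tensor at tensorIdx.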
std::vector<size_t> GetLinkedPreIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> preNodeIdx;
  for (size_t i = 0; i < graphT.nodes.size(); i++) {
    auto &oldNode = graphT.nodes.at(i);
    if (oldNode == nullptr) {
      continue;
    }
    auto outputIndexes = oldNode->outputIndex;
    if (IsContain<uint32_t>(outputIndexes, tensorIdx)) {
      preNodeIdx.emplace_back(i);
    }
  }
  return preNodeIdx;
}

std::vector<size_t> GetLinkedPostIdx(const schema::MetaGraphT &graphT, const size_t &tensorIdx) {
  std::vector<size_t> postNodeIdx;
  for (size_t i = 0; i < graphT.nodes.size(); i++) {
    auto &oldNode = graphT.nodes.at(i);
    if (oldNode == nullptr) {
      continue;
    }
    auto inputIndexes = oldNode->inputIndex;
    if (IsContain<uint32_t>(inputIndexes, tensorIdx)) {
      postNodeIdx.emplace_back(i);
    }
  }
  return postNodeIdx;
}
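// IsolateNode: detach a node with exactly one output from the graph by rerouting the graph outputs
// and the consumers of its output tensor to its first input tensor, then removing the output tensor.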
STATUS IsolateNode(schema::MetaGraphT *graphT, CNodeT *node) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(node != nullptr);
  size_t nodeIdx = 0;
  for (size_t i = 0; i < graphT->nodes.size(); i++) {
    auto &inNode = graphT->nodes.at(i);
    MS_ASSERT(inNode != nullptr);
    if (inNode->name == node->name) {
      nodeIdx = i;
      break;
    }
  }
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  if (inputTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Node " << node->name.c_str() << " has no inputs";
    return RET_ERROR;
  }
  if (outputTensorIdxes.size() != 1) {
    MS_LOG(ERROR) << "FakeQuantNode " << node->name.c_str()
                  << " should have exactly 1 output, in fact: " << outputTensorIdxes.size();
    return RET_ERROR;
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  auto outDataTensorIdx = outputTensorIdxes.front();
  MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
  auto &gOutTensorIdx = graphT->outputIndex;
  for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
    if (*iter == outDataTensorIdx) {
      *iter = inDataTensorIdx;
      break;
    }
  }
  // find post nodes and reroute their inputs to the node's input tensor
  auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
  for (auto postNodeIdx : postNodeIdxes) {
    MS_ASSERT(graphT->nodes.size() > postNodeIdx);
    auto &postNode = graphT->nodes.at(postNodeIdx);
    MS_ASSERT(postNode != nullptr);
    for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
      if (*iter == outDataTensorIdx) {
        *iter = inDataTensorIdx;
        break;
      }
    }
  }
  RemoveTensor(graphT, outputTensorIdxes);
  node->inputIndex.clear();
  node->outputIndex.clear();
  return RET_OK;
}
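// IsolateOneWayNode: isolate a node that has at most one producing node and at most one output tensor,
// rerouting graph outputs and downstream consumers to its input tensor.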
STATUS IsolateOneWayNode(schema::MetaGraphT *graph, size_t subGraphIdx, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graph != nullptr);
  return IsolateOneWayNode(graph, nodeIdx, removeTensor);
}

STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, size_t nodeIdx, bool removeTensor) {
  MS_ASSERT(graphT != nullptr);
  if (graphT->nodes.size() <= nodeIdx) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  CNodeT *node = graphT->nodes.at(nodeIdx).get();
  if (node == nullptr) {
    MS_LOG(ERROR) << "node is null";
    return RET_NULL_PTR;
  }
  auto inputTensorIdxes = node->inputIndex;
  auto outputTensorIdxes = node->outputIndex;
  auto preNodeIdxes = GetInputNodeIdx(*graphT, nodeIdx);
  if (preNodeIdxes.size() > 1 || outputTensorIdxes.size() > 1) {
    MS_LOG(ERROR) << "Only nodes with at most one input and one output are supported";
    return RET_ERROR;
  }
  if (inputTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Error: node " << nodeIdx << " has no input tensor";
    return RET_ERROR;
  }
  auto inDataTensorIdx = inputTensorIdxes.front();
  if (!outputTensorIdxes.empty()) {
    auto outDataTensorIdx = outputTensorIdxes.front();
    MS_ASSERT(graphT->allTensors.size() > inDataTensorIdx);
    MS_ASSERT(graphT->allTensors.at(inDataTensorIdx) != nullptr);
    auto &gOutTensorIdx = graphT->outputIndex;
    for (auto iter = gOutTensorIdx.begin(); iter != gOutTensorIdx.end(); iter++) {
      if (*iter == outDataTensorIdx) {
        *iter = inDataTensorIdx;
        break;
      }
    }
    // find post nodes and reroute their inputs to the node's input tensor
    auto postNodeIdxes = GetOutputNodeIdx(*graphT, nodeIdx, 0);
    for (auto postNodeIdx : postNodeIdxes) {
      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
      auto &postNode = graphT->nodes.at(postNodeIdx);
      MS_ASSERT(postNode != nullptr);
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == outDataTensorIdx) {
          *iter = inDataTensorIdx;
          break;
        }
      }
    }
  }
  if (removeTensor) {
    // the node's output tensors are no longer referenced, so remove them
    auto status = RemoveTensor(graphT, outputTensorIdxes);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "RemoveOutputTensors of node " << node->name.c_str() << " failed";
      return RET_ERROR;
    }
  }
  node->inputIndex.clear();
  node->outputIndex.clear();
  return RET_OK;
}
STATUS IsolateOneWayNode(schema::MetaGraphT *graphT, CNodeT *node, bool removeTensor) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(node != nullptr);
  bool isSubNode = false;
  size_t nodeIdx = 0;
  for (size_t i = 0; i < graphT->nodes.size(); i++) {
    auto &inNode = graphT->nodes.at(i);
    MS_ASSERT(inNode != nullptr);
    if (inNode->name == node->name) {
      isSubNode = true;
      nodeIdx = i;
      break;
    }
  }
  if (!isSubNode) {
    MS_LOG(ERROR) << "Node " << node->name.c_str() << " is not in graphT " << graphT->name.c_str();
    return RET_PARAM_INVALID;
  } else {
    return IsolateOneWayNode(graphT, nodeIdx, removeTensor);
  }
}
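// RemoveTensor: erase the listed tensors from allTensors (skipping tensors that are still referenced
// more than once unless forceDelete is set) and shift every stored tensor index past a removed slot.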
STATUS RemoveTensor(schema::MetaGraphT *graphT, std::vector<uint32_t> toDeleteTensorIdxes, bool forceDelete) {
  MS_ASSERT(graphT != nullptr);
  for (auto iter = toDeleteTensorIdxes.begin(); iter != toDeleteTensorIdxes.end();) {
    uint32_t deleteIdx = *iter;
    if (!forceDelete) {
      if (GetRefCount(graphT, deleteIdx) > 1) {
        iter++;
        continue;
      }
    }
    // update graph input indices
    for (auto gInIdx = graphT->inputIndex.begin(); gInIdx != graphT->inputIndex.end(); gInIdx++) {
      if (*gInIdx > deleteIdx) {
        (*gInIdx)--;
      }
    }
    // update graph output indices
    for (auto gOutIdx = graphT->outputIndex.begin(); gOutIdx != graphT->outputIndex.end(); gOutIdx++) {
      if (*gOutIdx > deleteIdx) {
        (*gOutIdx)--;
      }
    }
    for (auto &subgraph : graphT->subGraph) {
      // update subgraph input indices
      for (auto gInIdx = subgraph->inputIndices.begin(); gInIdx != subgraph->inputIndices.end(); gInIdx++) {
        if (*gInIdx > deleteIdx) {
          (*gInIdx)--;
        }
      }
      // update subgraph output indices
      for (auto gOutIdx = subgraph->outputIndices.begin(); gOutIdx != subgraph->outputIndices.end(); gOutIdx++) {
        if (*gOutIdx > deleteIdx) {
          (*gOutIdx)--;
        }
      }
      // update subgraph tensor indices
      for (auto idx = subgraph->tensorIndices.begin(); idx != subgraph->tensorIndices.end(); idx++) {
        if (*idx > deleteIdx) {
          (*idx)--;
        }
      }
    }
    // update node input/output indices
    for (auto node_iter = graphT->nodes.begin(); node_iter != graphT->nodes.end(); node_iter++) {
      UpdateNodeIndex((*node_iter).get(), deleteIdx);
    }
    // update the remaining indices in toDeleteTensorIdxes
    for (auto selfIt = toDeleteTensorIdxes.begin(); selfIt != toDeleteTensorIdxes.end(); selfIt++) {
      if (*selfIt > deleteIdx) {
        (*selfIt)--;
      }
    }
    graphT->allTensors.erase(graphT->allTensors.begin() + deleteIdx);
    iter = toDeleteTensorIdxes.erase(iter);
  }
  return RET_OK;
}
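// UpdateNodeIndex: drop references to the deleted tensor from a node and decrement every
// input/output index that pointed behind it.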
STATUS UpdateNodeIndex(CNodeT *node, uint32_t deleteIdx) {
  MS_ASSERT(node != nullptr);
  // update node input indices
  for (auto inIdxIt = node->inputIndex.begin(); inIdxIt != node->inputIndex.end();) {
    if (*inIdxIt == deleteIdx) {
      inIdxIt = node->inputIndex.erase(inIdxIt);
    } else {
      if (*inIdxIt > deleteIdx) {
        (*inIdxIt)--;
      }
      inIdxIt++;
    }
  }
  // update node output indices
  for (auto outIdxIt = node->outputIndex.begin(); outIdxIt != node->outputIndex.end();) {
    if (*outIdxIt == deleteIdx) {
      outIdxIt = node->outputIndex.erase(outIdxIt);
    } else {
      if (*outIdxIt > deleteIdx) {
        (*outIdxIt)--;
      }
      outIdxIt++;
    }
  }
  return RET_OK;
}
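// AddTensor2Node: append a tensor to allTensors and register it as an extra input (kBefore)
// or output of the node at nodeIdx.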
STATUS AddTensor2Node(schema::MetaGraphT *graphT, uint32_t nodeIdx, std::unique_ptr<TensorT> tensor,
                      InsertPlace place) {
  if (nodeIdx >= graphT->nodes.size()) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  graphT->allTensors.emplace_back(std::move(tensor));
  uint32_t newTensorIdx = graphT->allTensors.size() - 1;
  auto node = graphT->nodes.at(nodeIdx).get();
  MS_ASSERT(node != nullptr);
  if (place == kBefore) {
    node->inputIndex.emplace_back(newTensorIdx);
  } else {
    node->outputIndex.emplace_back(newTensorIdx);
  }
  return RET_OK;
}
STATUS ReplaceTensorOfNode(schema::MetaGraphT *graphT, uint32_t nodeIdx, uint32_t inTensorIdx,
                           std::unique_ptr<TensorT> tensor) {
  MS_ASSERT(graphT != nullptr);
  if (nodeIdx >= graphT->nodes.size()) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << nodeIdx;
    return RET_PARAM_INVALID;
  }
  auto node = graphT->nodes.at(nodeIdx).get();
  MS_ASSERT(node != nullptr);
  if (inTensorIdx >= graphT->allTensors.size()) {
    MS_LOG(ERROR) << "inTensorIdx out of range: " << inTensorIdx;
    return RET_PARAM_INVALID;
  }
  if (!IsContain(node->inputIndex, inTensorIdx)) {
    MS_LOG(ERROR) << "inTensorIdx(" << inTensorIdx << ") is not an input of node(" << nodeIdx << ")";
    return RET_PARAM_INVALID;
  }
  graphT->allTensors.at(inTensorIdx).swap(tensor);
  return RET_OK;
}
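// InsertNode: insert toAddNode before or after an existing node; InsertPlace selects which of the
// two helpers below performs the rewiring.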
NodeIter InsertNode(schema::MetaGraphT *graphT, uint32_t existNodeIdx, InsertPlace place, size_t inoutIndex,
                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  if (existNodeIdx >= graphT->nodes.size()) {
    MS_LOG(ERROR) << "nodeIdx out of range: " << existNodeIdx;
    return graphT->nodes.end();
  }
  auto node_iter = graphT->nodes.begin() + existNodeIdx;
  MS_ASSERT(node_iter != graphT->nodes.begin());
  MS_ASSERT((*node_iter) != nullptr);
  // forward the caller-supplied opDefCopyer instead of silently dropping it
  return InsertNode(graphT, node_iter, place, inoutIndex, std::move(toAddNode), errorCode, opDefCopyer);
}

NodeIter InsertNode(schema::MetaGraphT *graphT, NodeIter existNodeIter, InsertPlace place, size_t inoutIndexIdx,
                    std::unique_ptr<CNodeT> toAddNode, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  if (place == kBefore) {
    return InsertNodeBefore(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
  } else if (place == kAfter) {
    return InsertNodeAfter(graphT, existNodeIter, inoutIndexIdx, std::move(toAddNode), errorCode, opDefCopyer);
  } else {
    MS_LOG(ERROR) << "Invalid InsertPlace : " << place;
    return graphT->nodes.end();
  }
}
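// InsertNodeBefore: splice a copy of toAddNodeIn between the selected input tensor of existNode and
// its producer(s); when the input has several producing nodes, one copy is inserted per producer.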
NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t inputIndexIdx,
                          std::unique_ptr<CNodeT> toAddNodeIn, STATUS *errorCode, const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->inputIndex.size() > inputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto preTensorIdx = existNode->inputIndex.at(inputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > preTensorIdx);
  auto preNodeIdxes = GetInputNodeIdx(*graphT, *(existNode), inputIndexIdx);
  if (preNodeIdxes.empty()) {
    auto &preTensor = graphT->allTensors.at(preTensorIdx);
    MS_ASSERT(preTensor != nullptr);
    auto toAddTensor = CopyTensorDefT(preTensor);
    if (toAddTensor == nullptr) {
      MS_LOG(ERROR) << "Copy TensorT failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    toAddTensor->nodeType = schema::NodeType_CNode;
    preTensor->refCount = 0;
    preTensor->data.clear();
    MS_ASSERT(toAddNodeIn->primitive != nullptr);
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
      MS_ASSERT(prim != nullptr);
      preTensor->dataType = prim->srcT;
      toAddTensor->dataType = prim->dstT;
      if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
        preTensor->quantParams.front()->zeroPoint += 128;
      } else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
        toAddTensor->quantParams.front()->zeroPoint += 128;
      }
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(toAddNodeIn.get());
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(preTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
      if (*iter == preTensorIdx) {
        *iter = toAddTensorIdx;
        break;
      }
    }
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
  } else {
    std::vector<std::unique_ptr<CNodeT>> toAddNodes;
    for (size_t i = 0; i < preNodeIdxes.size(); i++) {
      MS_ASSERT(graphT->nodes.size() > preNodeIdxes.at(i));
      auto &preTensor = graphT->allTensors.at(preTensorIdx);
      MS_ASSERT(preTensor != nullptr);
      auto toAddTensor = CopyTensorDefT(preTensor);
      if (toAddTensor == nullptr) {
        *errorCode = RET_NULL_PTR;
        MS_LOG(ERROR) << "Copy TensorT failed";
        return graphT->nodes.end();
      }
      toAddTensor->nodeType = schema::NodeType_CNode;
      MS_ASSERT(toAddNodeIn->primitive != nullptr);
      if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
        auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
        MS_ASSERT(prim != nullptr);
        preTensor->dataType = prim->srcT;
        toAddTensor->dataType = prim->dstT;
        if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
          preTensor->quantParams.front()->zeroPoint += 128;
        } else if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
          toAddTensor->quantParams.front()->zeroPoint += 128;
        }
      }
      graphT->allTensors.emplace_back(std::move(toAddTensor));
      size_t toAddTensorIdx = graphT->allTensors.size() - 1;
      auto toAddNode = opDefCopyer(toAddNodeIn.get());
      if (toAddNode == nullptr) {
        MS_LOG(ERROR) << "copy toAddNodeIn failed";
        *errorCode = RET_NULL_PTR;
        return graphT->nodes.end();
      }
      // the for statement already increments i, so do not increment it again here
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i);
      toAddNode->inputIndex.clear();
      toAddNode->inputIndex.push_back(preTensorIdx);
      toAddNode->outputIndex.clear();
      toAddNode->outputIndex.push_back(toAddTensorIdx);
      for (auto iter = existNode->inputIndex.begin(); iter != existNode->inputIndex.end(); iter++) {
        if (*iter == preTensorIdx) {
          *iter = toAddTensorIdx;
          break;
        }
      }
      toAddNodes.emplace_back(std::move(toAddNode));
    }
    for (auto &toAddNode : toAddNodes) {
      existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
      existNodeIter++;
    }
  }
  *errorCode = RET_OK;
  return existNodeIter;
}
NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, size_t outputIndexIdx,
                         std::unique_ptr<schema::CNodeT> toAddNodeIn, STATUS *errorCode,
                         const OpDefCopyer &opDefCopyer) {
  MS_ASSERT(graphT != nullptr);
  MS_ASSERT(errorCode != nullptr);
  auto &existNode = *existNodeIter;
  MS_ASSERT(existNode != nullptr);
  MS_ASSERT(existNode->outputIndex.size() > outputIndexIdx);
  MS_ASSERT(toAddNodeIn != nullptr);
  auto postTensorIdx = existNode->outputIndex.at(outputIndexIdx);
  MS_ASSERT(graphT->allTensors.size() > postTensorIdx);
  auto postNodeIdxes = GetOutputNodeIdx(*graphT, *(existNode), outputIndexIdx);
  if (postNodeIdxes.empty()) {
    auto &postTensor = graphT->allTensors.at(postTensorIdx);
    MS_ASSERT(postTensor != nullptr);
    auto toAddTensor = CopyTensorDefT(postTensor);
    if (toAddTensor == nullptr) {
      MS_LOG(ERROR) << "Copy TensorT failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    toAddTensor->nodeType = schema::NodeType_CNode;
    MS_ASSERT(toAddNodeIn->primitive != nullptr);
    if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
      auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
      MS_ASSERT(prim != nullptr);
      postTensor->dataType = prim->srcT;
      toAddTensor->dataType = prim->dstT;
      if (prim->srcT == TypeId::kNumberTypeInt8 && prim->dstT == TypeId::kNumberTypeUInt8) {
        toAddTensor->quantParams.front()->zeroPoint += 128;
      } else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
        postTensor->quantParams.front()->zeroPoint += 128;
      }
    }
    graphT->allTensors.emplace_back(std::move(toAddTensor));
    size_t toAddTensorIdx = graphT->allTensors.size() - 1;
    auto toAddNode = opDefCopyer(toAddNodeIn.get());
    if (toAddNode == nullptr) {
      MS_LOG(ERROR) << "copy toAddNodeIn failed";
      *errorCode = RET_NULL_PTR;
      return graphT->nodes.end();
    }
    toAddNode->inputIndex.clear();
    toAddNode->inputIndex.push_back(postTensorIdx);
    toAddNode->outputIndex.clear();
    toAddNode->outputIndex.push_back(toAddTensorIdx);
    for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
      if (*iter == postTensorIdx) {
        *iter = toAddTensorIdx;
        break;
      }
    }
    existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
    existNodeIter++;
  } else {
    std::vector<std::unique_ptr<schema::CNodeT>> toAddNodes;
    int i = 0;
    for (size_t postNodeIdx : postNodeIdxes) {
      MS_ASSERT(graphT->nodes.size() > postNodeIdx);
      auto &postNode = graphT->nodes.at(postNodeIdx);
      MS_ASSERT(postNode != nullptr);
      auto &postTensor = graphT->allTensors.at(postTensorIdx);
      MS_ASSERT(postTensor != nullptr);
      // for multi-output: when the output tensor is both a graph output and another node's input,
      // one extra node has to be added for the graph output
      if (IsContain(graphT->outputIndex, postTensorIdx)) {
        auto toAddTensor = CopyTensorDefT(postTensor);
        if (toAddTensor == nullptr) {
          MS_LOG(ERROR) << "Copy TensorT failed";
          *errorCode = RET_NULL_PTR;
          return graphT->nodes.end();
        }
        toAddTensor->nodeType = schema::NodeType_CNode;
        graphT->allTensors.emplace_back(std::move(toAddTensor));
        size_t toAddTensorIdx = graphT->allTensors.size() - 1;
        auto toAddNode = opDefCopyer(toAddNodeIn.get());
        toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
        toAddNode->inputIndex.clear();
        toAddNode->inputIndex.push_back(postTensorIdx);
        toAddNode->outputIndex.clear();
        toAddNode->outputIndex.push_back(toAddTensorIdx);
        for (auto iter = graphT->outputIndex.begin(); iter != graphT->outputIndex.end(); iter++) {
          if (*iter == postTensorIdx) {
            *iter = toAddTensorIdx;
            break;
          }
        }
        toAddNodes.emplace_back(std::move(toAddNode));
      }
      auto toAddTensor = CopyTensorDefT(postTensor);
      if (toAddTensor == nullptr) {
        MS_LOG(ERROR) << "Copy TensorT failed";
        *errorCode = RET_NULL_PTR;
        return graphT->nodes.end();
      }
      MS_ASSERT(toAddNodeIn->primitive != nullptr);
      if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
        auto prim = toAddNodeIn->primitive->value.AsQuantDTypeCast();
        MS_ASSERT(prim != nullptr);
        postTensor->dataType = prim->srcT;
        toAddTensor->dataType = prim->dstT;
        if (prim->dstT == TypeId::kNumberTypeUInt8 && prim->srcT == TypeId::kNumberTypeInt8) {
          toAddTensor->quantParams.front()->zeroPoint += 128;
        } else if (prim->srcT == TypeId::kNumberTypeUInt8 && prim->dstT == TypeId::kNumberTypeInt8) {
          postTensor->quantParams.front()->zeroPoint += 128;
        }
      }
      graphT->allTensors.emplace_back(std::move(toAddTensor));
      size_t toAddTensorIdx = graphT->allTensors.size() - 1;
      auto toAddNode = opDefCopyer(toAddNodeIn.get());
      if (toAddNode == nullptr) {
        MS_LOG(ERROR) << "copy toAddNodeIn failed";
        *errorCode = RET_NULL_PTR;
        return graphT->nodes.end();
      }
      toAddNode->name = toAddNodeIn->name + "_" + std::to_string(i++);
      toAddNode->inputIndex.clear();
      toAddNode->inputIndex.push_back(postTensorIdx);
      toAddNode->outputIndex.clear();
      toAddNode->outputIndex.push_back(toAddTensorIdx);
      MS_ASSERT(IsContain(postNode->inputIndex, postTensorIdx));
      for (auto iter = postNode->inputIndex.begin(); iter != postNode->inputIndex.end(); iter++) {
        if (*iter == postTensorIdx) {
          *iter = toAddTensorIdx;
          break;
        }
      }
      toAddNodes.emplace_back(std::move(toAddNode));
    }
    for (auto &toAddNode : toAddNodes) {
      existNodeIter = graphT->nodes.insert(existNodeIter, std::move(toAddNode));
      existNodeIter++;
    }
  }
  *errorCode = RET_OK;
  return existNodeIter;
}
STATUS ValidateFileStr(const std::string &modelFile, const std::string &fileType) {
  if (modelFile.size() > fileType.size() && modelFile.substr(modelFile.size() - fileType.size()) == fileType) {
    return RET_OK;
  } else {
    return RET_ERROR;
  }
}
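// TransformAttrByAxes: reorder the attribute values in place so that they follow NHWC axis order
// instead of the original NCHW order.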
void TransformAttrByAxes(int *origin_attr, int *axes, int element_size) {
  if (origin_attr == nullptr || axes == nullptr || element_size == 0) {
    MS_LOG(INFO) << "Attr data is from other nodes.";
    return;
  }
  auto axis_map = GetNc2NhAxisMap();
  std::vector<int> cur_attr;
  for (int dim = 0; dim < 4; ++dim) {
    for (int index = 0; index < element_size; ++index) {
      int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]];
      if (nhwc_dim == dim || (nhwc_dim + 4) == dim) {
        cur_attr.push_back(origin_attr[index]);
      }
    }
  }
  for (int index = 0; index < element_size; ++index) {
    origin_attr[index] = cur_attr[index];
  }
}
STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
  auto type = node->primitive->value.type;
  if (type == schema::PrimitiveType_StridedSlice) {
    // ONNX StridedSlice always has 5 inputs.
    if (node->inputIndex.size() == 5) {
      for (int index = 1; index < 5; ++index) {
        if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) {
          MS_LOG(INFO) << "Inputs that come from other nodes are not handled here.";
          return RET_NOT_SUPPORT;
        }
      }
      int element_num = graph->allTensors[node->inputIndex[1]]->dims[0];
      auto axes = graph->allTensors[node->inputIndex[3]]->data;
      for (int index = 1; index < 5; ++index) {
        TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()),
                            reinterpret_cast<int *>(axes.data()), element_num);
      }
    }
  }
  if (type == schema::PrimitiveType_Slice) {
    auto attr = node->primitive->value.AsSlice();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr.";
      return RET_NULL_PTR;
    }
    // transform attr
    attr->format = schema::Format_NHWC;
    if (attr->begin.empty() || attr->size.empty()) {
      MS_LOG(INFO) << "Attributes that come from other nodes are not handled here.";
      return RET_NOT_SUPPORT;
    }
    int element_num = attr->begin.size();
    if (attr->axes.empty()) {
      for (int index = 0; index < element_num; ++index) {
        attr->axes.push_back(index);
      }
    }
    TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num);
    TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num);
    TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num);
  }
  return RET_OK;
}
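// ChangeOpAxis: rewrite the axis-related attributes of Concat/Split/Crop/Slice/StridedSlice nodes
// from NCHW to NHWC layout.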
STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) {
  MS_ASSERT(node->primitive != nullptr);
  auto type = node->primitive->value.type;
  auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size();
  if (input1_ndim != 4) {
    if (node->inputIndex.size() > 1) {
      auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size();
      if (input2_ndim != 4 && input2_ndim != 0) {
        MS_LOG(ERROR) << "ChangeOpAxis only supports 4-dim tensors";
        return RET_NOT_SUPPORT;
      }
    } else {
      MS_LOG(ERROR) << "ChangeOpAxis only supports 4-dim tensors";
      return RET_NOT_SUPPORT;
    }
  }
  if (type == schema::PrimitiveType_Concat) {
    MS_ASSERT(node->primitive->value.AsConcat() != nullptr);
    auto origin_axis = node->primitive->value.AsConcat()->axis;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsConcat() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsConcat()->axis = axis_map[origin_axis < 0 ? origin_axis + 4 : origin_axis];
  }
  if (type == schema::PrimitiveType_Split) {
    MS_ASSERT(node->primitive->value.AsSplit() != nullptr);
    auto origin_axis = node->primitive->value.AsSplit()->splitDim;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsSplit() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr";
      return RET_NULL_PTR;
    }
    node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
  }
  if (type == schema::PrimitiveType_Crop) {
    MS_ASSERT(node->primitive->value.AsCrop() != nullptr);
    auto origin_axis = node->primitive->value.AsCrop()->axis;
    auto offsets = node->primitive->value.AsCrop()->offsets;
    auto axis_map = GetNc2NhAxisMap();
    if (node->primitive->value.AsCrop() == nullptr) {
      MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
      return RET_NULL_PTR;
    }
    // nchw -> nhwc: reorder the offsets, or pad them with 0 where needed
    if (axis_map[origin_axis] == 0) {
      offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
    } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) {
      // origin_axis == 2 or origin_axis == 3
      offsets.push_back(0);
    } else if (axis_map[origin_axis] == -1) {
      // origin_axis == 1
      offsets = {offsets[1], offsets[2], offsets[0]};
    } else {
      // invalid axis
      MS_LOG(ERROR) << "Crop error";
      return RET_ERROR;
    }
    node->primitive->value.AsCrop()->offsets = offsets;
  }
  if (type == schema::PrimitiveType_Slice || type == schema::PrimitiveType_StridedSlice) {
    return ChangeOpAttrForSlice(graph, node);
  }
  return RET_OK;
}
std::string GetModelName(const std::string &modelFile) {
  std::string modelName = modelFile;
  modelName = modelName.substr(modelName.find_last_of('/') + 1);
  modelName = modelName.substr(0, modelName.find_last_of('.'));
  return modelName;
}
int SetSubgraphTensorIndices(schema::MetaGraphT *meta_graphT) {
  for (auto &subgraph : meta_graphT->subGraph) {
    std::vector<uint32_t> subgraph_indices{};
    for (auto &node_idx : subgraph->nodeIndices) {
      auto &node = meta_graphT->nodes.at(node_idx);
      for (auto &input_idx : node->inputIndex) {
        if (IsContain(subgraph_indices, input_idx)) {
          continue;
        } else {
          subgraph_indices.push_back(input_idx);
        }
      }
      for (auto &output_idx : node->outputIndex) {
        if (IsContain(subgraph_indices, output_idx)) {
          continue;
        } else {
          subgraph_indices.push_back(output_idx);
        }
      }
    }
    subgraph->tensorIndices.assign(subgraph_indices.begin(), subgraph_indices.end());
  }
  return RET_OK;
}
}  // namespace lite
}  // namespace mindspore