You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

multi_batch_copy_graph.cc 82 kB

6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
5 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
5 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/preprocess/multi_batch_copy_graph.h"
  17. #include <queue>
  18. #include <set>
  19. #include <string>
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "common/ge/ge_util.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/common/ge_inner_error_codes.h"
  25. #include "framework/common/string_util.h"
  26. #include "framework/common/types.h"
  27. #include "framework/omg/omg_inner_types.h"
  28. #include "graph/debug/ge_attr_define.h"
  29. #include "graph/ge_context.h"
  30. #include "graph/passes/data_pass.h"
  31. #include "graph/passes/multi_batch_clone_pass.h"
  32. #include "graph/passes/prune_pass.h"
  33. #include "graph/preprocess/multi_batch_options.h"
  34. #include "graph/utils/attr_utils.h"
  35. #include "graph/utils/graph_utils.h"
  36. #include "graph/utils/node_utils.h"
  37. #include "graph/utils/tensor_utils.h"
  38. #include "graph/utils/type_utils.h"
  39. #include "inc/pass_manager.h"
  40. #include "graph/common/local_context.h"
  41. #include "graph/common/omg_util.h"
  42. using std::set;
  43. using std::string;
  44. using std::vector;
  45. using std::map;
  46. using std::queue;
  47. namespace ge {
  48. namespace multibatch {
  49. namespace {
  // ---- File-local constants -------------------------------------------------
  // String constants.
  // kGetNextName is compared against a node's original type in IsGetNextType().
  constexpr const char *const kMbatchSwitchnName = "mbatch-switch-name";
  constexpr const char *const kGetNextName = "IteratorV2";

  // Anchor / tensor indices.
  constexpr int kSwitchNDataIndex = 0;
  constexpr int kSwitchNPredIndex = 1;
  constexpr int kDataOutIndex = 0;
  constexpr int kDataInIndex = 0;
  constexpr int kMergeDataOutIndex = 0;
  constexpr int kStaticOutput = -1;

  // Divisor applied to the output count of a getnext-sink node
  // (see LabelStatusForGetNextSink).
  constexpr int kDivisionConst = 2;

  // Compared with the number of input data nodes, and with the result of
  // set::count(), in AllInDataNodesUnchangeAndNoMergeOut.
  constexpr int32_t kOneInDataNode = 1;
  constexpr int32_t kFindNoMatch = 0;
  61. inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); }
  62. inline bool IsEnterType(const string &node_type) { return (node_type == ENTER) || (node_type == REFENTER); }
  63. const set<string> unchange_types({CONSTANT, CONSTANTOP, ENTER, REFENTER});
  64. inline bool IsGetNextType(const NodePtr &node) {
  65. std::string original_type;
  66. GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
  67. GELOGW("Get original type failed"); return false);
  68. return (original_type == kGetNextName);
  69. }
  70. NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) {
  71. OpDescPtr desc = MakeShared<OpDesc>();
  72. if (desc == nullptr) {
  73. GELOGE(OUT_OF_MEMORY, "Failed to insert merge node, name %s", name.c_str());
  74. return nullptr;
  75. }
  76. desc->SetName(name);
  77. desc->SetType(MERGE);
  78. GeTensorDesc tensor_desc;
  79. for (size_t i = 0; i < input_num; ++i) {
  80. auto ret = desc->AddInputDesc("x" + std::to_string(i), tensor_desc);
  81. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  82. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add input %zu, error-code %u",
  83. name.c_str(), i, ret);
  84. return nullptr);
  85. }
  86. auto ret = desc->AddOutputDesc("y", tensor_desc);
  87. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  88. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'y', error-code %u",
  89. name.c_str(), ret);
  90. return nullptr);
  91. tensor_desc.SetDataType(DT_INT32);
  92. ret = desc->AddOutputDesc("value_index", tensor_desc);
  93. if (ret != GRAPH_SUCCESS) {
  94. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'value_index', error-code %u",
  95. name.c_str(), ret);
  96. return nullptr;
  97. }
  98. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  99. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add attr", name.c_str());
  100. return nullptr;
  101. }
  102. return graph->AddNode(desc);
  103. }
  104. NodePtr InsertCopyNode(const NodePtr &node, size_t n) {
  105. const std::string &name = node->GetName() + "_ascend_mbatch_batch_" + std::to_string(n);
  106. auto src_op_desc = node->GetOpDesc();
  107. GE_IF_BOOL_EXEC(src_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "Failed to copy node %s to %s, the OpDesc is null",
  108. node->GetName().c_str(), name.c_str());
  109. return nullptr);
  110. auto desc = AttrUtils::CopyOpDesc(src_op_desc);
  111. GE_IF_BOOL_EXEC(desc == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to create op desc for copy node for node %s name %s",
  112. node->GetName().c_str(), name.c_str());
  113. return nullptr);
  114. desc->SetName(name);
  115. desc->CopyAttrsFrom(*src_op_desc);
  116. for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
  117. auto input_desc = desc->MutableInputDesc(i);
  118. GE_IF_BOOL_EXEC(input_desc == nullptr,
  119. GELOGW("Get null input desc by index %u from node %s when copy from %s", i,
  120. desc->GetName().c_str(), node->GetName().c_str());
  121. continue);
  122. input_desc->CopyAttrsFrom(src_op_desc->GetInputDesc(i));
  123. }
  124. for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
  125. auto output_desc = desc->MutableOutputDesc(i);
  126. GE_IF_BOOL_EXEC(output_desc == nullptr,
  127. GELOGE(INTERNAL_ERROR, "Failed to get output desc by index %u from node %s when copy from %s", i,
  128. desc->GetName().c_str(), node->GetName().c_str());
  129. return nullptr);
  130. output_desc->CopyAttrsFrom(src_op_desc->GetOutputDesc(i));
  131. }
  132. const std::string &batch_label = "Batch_" + std::to_string(n);
  133. if (!AttrUtils::SetStr(desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  134. GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", name.c_str());
  135. return nullptr;
  136. }
  137. (void)AttrUtils::SetListStr(desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()});
  138. auto graph = node->GetOwnerComputeGraph();
  139. return graph->AddNode(desc);
  140. }
  // Returns false as soon as a negative dimension is found; returns true
  // otherwise. An empty list and zero-valued dimensions both yield true.
  bool IsAllDimsPositive(const std::vector<int64_t> &dims) {
    for (auto it = dims.begin(); it != dims.end(); ++it) {
      const bool is_negative = (*it < 0);
      if (is_negative) {
        return false;
      }
    }
    return true;
  }
  149. NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) {
  150. auto desc = MakeShared<OpDesc>();
  151. if (desc == nullptr) {
  152. GELOGE(OUT_OF_MEMORY, "Failed to create const op %s, out of memory", name.c_str());
  153. return nullptr;
  154. }
  155. desc->SetName(name);
  156. desc->SetType(CONSTANT);
  157. GeTensor tensor;
  158. tensor.SetData(std::vector<uint8_t>({0}));
  159. if (!AttrUtils::SetTensor(desc, ATTR_NAME_WEIGHTS, tensor)) {
  160. GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", name.c_str());
  161. return nullptr;
  162. }
  163. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  164. GELOGE(OUT_OF_MEMORY, "Failed to set insert flag for const node %s", name.c_str());
  165. return nullptr;
  166. }
  167. if (desc->AddOutputDesc(GeTensorDesc()) != GRAPH_SUCCESS) {
  168. GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", name.c_str());
  169. return nullptr;
  170. }
  171. return graph->AddNode(desc);
  172. }
  173. bool IsOnlyOutputToAipp(const NodePtr &node) {
  174. for (const auto &out_node : node->GetOutDataNodes()) {
  175. if (out_node->GetType() != AIPP) {
  176. return false;
  177. }
  178. }
  179. return true;
  180. }
  181. } // namespace
  182. Status MultiBatchGraphCopyer::CopyGraph() {
  183. auto ret = Init();
  184. if (ret != SUCCESS) {
  185. return ret;
  186. }
  187. if (LabelStatus() != SUCCESS) {
  188. GELOGE(INTERNAL_ERROR, "Failed to label status for all nodes.");
  189. return INTERNAL_ERROR;
  190. }
  191. ret = CheckAndParseDynamicData();
  192. if (ret != SUCCESS) {
  193. return ret;
  194. }
  195. ret = CreateNewNodes();
  196. if (ret != SUCCESS) {
  197. return ret;
  198. }
  199. ret = LinkEdges();
  200. if (ret != SUCCESS) {
  201. return ret;
  202. }
  203. GELOGI("Begin to remove useless nodes by prune pass after copy process");
  204. PrunePass prune_pass;
  205. ret = prune_pass.Run(graph_);
  206. if (ret != SUCCESS) {
  207. GELOGE(ret, "Failed to prune");
  208. return ret;
  209. }
  210. return CheckCopyResult(origin_data_nodes_);
  211. }
  212. Status MultiBatchGraphCopyer::Init() {
  213. auto ret = CheckArguments();
  214. if (ret != SUCCESS) {
  215. return ret;
  216. }
  217. ret = RelinkConstCtrlEdge();
  218. if (ret != SUCCESS) {
  219. GELOGE(FAILED, "Relink const's control edge failed.");
  220. return FAILED;
  221. }
  222. ret = ExtractUnchangedStructureOutofCycle();
  223. if (ret != SUCCESS) {
  224. GELOGE(FAILED, "Extract unchanged structure out of cycle failed.");
  225. return FAILED;
  226. }
  227. for (auto &node : graph_->GetAllNodes()) {
  228. origin_all_nodes_.emplace_back(node);
  229. if (IsDataLikeType(node->GetType())) {
  230. origin_data_nodes_.emplace_back(node);
  231. }
  232. if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) {
  233. origin_data_nodes_.emplace_back(node);
  234. }
  235. }
  236. return SUCCESS;
  237. }
  238. Status MultiBatchGraphCopyer::RelinkConstCtrlEdge() {
  239. for (auto &node : graph_->GetAllNodes()) {
  240. GE_CHECK_NOTNULL(node);
  241. if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
  242. if (node->GetOutDataNodes().empty()) {
  243. continue;
  244. }
  245. if (!node->GetInControlNodes().empty()) {
  246. auto in_ctrl_nodes = node->GetInControlNodes();
  247. auto out_nodes = node->GetOutAllNodes();
  248. bool has_merge_out = false;
  249. for (const auto &out_node : out_nodes) {
  250. GE_CHECK_NOTNULL(out_node);
  251. if (out_node->GetType() == MERGE || out_node->GetType() == REFMERGE) {
  252. has_merge_out = true;
  253. break;
  254. }
  255. }
  256. if (has_merge_out) {
  257. continue;
  258. }
  259. auto in_ctrl_anchor = node->GetInControlAnchor();
  260. GE_CHECK_NOTNULL(in_ctrl_anchor);
  261. in_ctrl_anchor->UnlinkAll();
  262. for (auto &in_ctrl_node : in_ctrl_nodes) {
  263. auto out_ctrl_anchor_of_in_ctrl_node = in_ctrl_node->GetOutControlAnchor();
  264. GE_CHECK_NOTNULL(out_ctrl_anchor_of_in_ctrl_node);
  265. for (auto &out_node : out_nodes) {
  266. if (IsEnterType(out_node->GetType())) {
  267. continue;
  268. }
  269. if (!out_ctrl_anchor_of_in_ctrl_node->IsLinkedWith(out_node->GetInControlAnchor())) {
  270. GE_CHK_STATUS_RET(out_ctrl_anchor_of_in_ctrl_node->LinkTo(out_node->GetInControlAnchor()))
  271. }
  272. }
  273. }
  274. }
  275. auto out_ctrl_anchor = node->GetOutControlAnchor();
  276. if (out_ctrl_anchor != nullptr) {
  277. out_ctrl_anchor->UnlinkAll();
  278. }
  279. }
  280. }
  281. return SUCCESS;
  282. }
  283. Status MultiBatchGraphCopyer::ExtractUnchangedStructureOutofCycle() {
  284. map<string, vector<NodePtr>> frame_enter;
  285. if (GetEnterNodesGroupByFrame(frame_enter) != SUCCESS) {
  286. GELOGE(FAILED, "Get enter nodes grouped by frame_name failed.");
  287. return FAILED;
  288. }
  289. queue<NodePtr> nodes_to_extract;
  290. if (GetNodeNeedExtract(frame_enter, nodes_to_extract) != SUCCESS) {
  291. GELOGE(FAILED, "Get nodes needed to extract failed.");
  292. return FAILED;
  293. }
  294. while (!nodes_to_extract.empty()) {
  295. auto node = nodes_to_extract.front();
  296. nodes_to_extract.pop();
  297. OpDescPtr enter_desc = nullptr;
  298. if (MoveInEntersInDataAnchorDown(node, enter_desc) != SUCCESS) {
  299. GELOGE(FAILED, "Move in enter nodes' in data anchors down of %s failed.", node->GetName().c_str());
  300. return FAILED;
  301. }
  302. set<NodePtr> out_nodes;
  303. if (InsertEnterAfterNode(node, enter_desc, out_nodes) != SUCCESS) {
  304. GELOGE(FAILED, "Insert enter node after %s failed.", node->GetName().c_str());
  305. return FAILED;
  306. }
  307. if (MoveCtrlEdgeToOutNodes(node, out_nodes) != SUCCESS) {
  308. GELOGE(FAILED, "Move %s's control edge to out nodes failed.", node->GetName().c_str());
  309. return FAILED;
  310. }
  311. for (auto &out_node : out_nodes) {
  312. GE_CHECK_NOTNULL(out_node);
  313. if (AllInDataNodesUnchangeAndNoMergeOut(out_node)) {
  314. nodes_to_extract.push(out_node);
  315. }
  316. }
  317. }
  318. if (DeleteEnterWithoutDataOut() != SUCCESS) {
  319. GELOGE(FAILED, "Delete enter node without out data nodes failed.");
  320. return FAILED;
  321. }
  322. return SUCCESS;
  323. }
  324. Status MultiBatchGraphCopyer::GetEnterNodesGroupByFrame(map<string, vector<NodePtr>> &frame_enter) {
  325. for (auto &node : graph_->GetAllNodes()) {
  326. GE_CHECK_NOTNULL(node);
  327. if (IsEnterType(node->GetType())) {
  328. if (!node->GetInControlNodes().empty() || !node->GetOutControlNodes().empty()) {
  329. continue;
  330. }
  331. auto op_desc = node->GetOpDesc();
  332. GE_CHECK_NOTNULL(op_desc);
  333. string frame_name;
  334. if (!AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
  335. GELOGE(FAILED, "Get attr frame_name of enter[%] failed.", node->GetName().c_str());
  336. return FAILED;
  337. }
  338. frame_enter[frame_name].emplace_back(node);
  339. }
  340. }
  341. return SUCCESS;
  342. }
  343. Status MultiBatchGraphCopyer::GetNodeNeedExtract(const map<string, vector<NodePtr>> &frame_enter,
  344. queue<NodePtr> &nodes_to_extract) {
  345. for (const auto &one_group : frame_enter) {
  346. auto enters = one_group.second;
  347. for (const auto &enter : enters) {
  348. auto out_data_nodes = enter->GetOutDataNodes();
  349. for (const auto &out_data_node : out_data_nodes) {
  350. GE_CHECK_NOTNULL(out_data_node);
  351. if (AllInDataNodesUnchangeAndNoMergeOut(out_data_node)) {
  352. nodes_to_extract.push(out_data_node);
  353. }
  354. }
  355. }
  356. }
  357. return SUCCESS;
  358. }
  359. bool MultiBatchGraphCopyer::AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node) {
  360. auto out_data_nodes = node->GetOutDataNodes();
  361. for (const auto &out_data_node : out_data_nodes) {
  362. if (out_data_node == nullptr) {
  363. return false;
  364. }
  365. if (out_data_node->GetType() == MERGE || out_data_node->GetType() == REFMERGE) {
  366. return false;
  367. }
  368. }
  369. auto in_data_nodes = node->GetInDataNodes();
  370. if (in_data_nodes.size() == kOneInDataNode) {
  371. return true;
  372. }
  373. for (const auto &in_data_node : in_data_nodes) {
  374. if (in_data_node == nullptr) {
  375. return false;
  376. }
  377. if (unchange_types.count(in_data_node->GetType()) == kFindNoMatch) {
  378. return false;
  379. }
  380. }
  381. return true;
  382. }
  383. Status MultiBatchGraphCopyer::MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc) {
  384. auto in_data_anchors = node->GetAllInDataAnchors();
  385. for (auto &in_data_anchor : in_data_anchors) {
  386. auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor();
  387. GE_CHECK_NOTNULL(peer_out_data_anchor);
  388. auto peer_in_data_node = peer_out_data_anchor->GetOwnerNode();
  389. if (IsEnterType(peer_in_data_node->GetType())) {
  390. GE_CHK_STATUS_RET(peer_out_data_anchor->Unlink(in_data_anchor))
  391. GELOGD("Unlink data edge from %s to %s.", peer_in_data_node->GetName().c_str(), node->GetName().c_str());
  392. auto enter_in_data_anchors = peer_in_data_node->GetAllInDataAnchors();
  393. for (auto &enter_in_data_anchor : enter_in_data_anchors) {
  394. auto peer_out_data_anchor_of_enter = enter_in_data_anchor->GetPeerOutAnchor();
  395. GE_CHECK_NOTNULL(peer_out_data_anchor_of_enter);
  396. if (peer_out_data_anchor_of_enter->IsLinkedWith(in_data_anchor)) {
  397. continue;
  398. }
  399. GE_CHK_STATUS_RET(peer_out_data_anchor_of_enter->LinkTo(in_data_anchor))
  400. GELOGD("Relink data edge from %s to %s.", peer_out_data_anchor_of_enter->GetOwnerNode()->GetName().c_str(),
  401. node->GetName().c_str());
  402. }
  403. enter_desc = peer_in_data_node->GetOpDesc();
  404. GE_CHECK_NOTNULL(enter_desc);
  405. }
  406. }
  407. return SUCCESS;
  408. }
  409. Status MultiBatchGraphCopyer::InsertEnterAfterNode(NodePtr &node, const OpDescPtr &copy_desc, set<NodePtr> &out_nodes) {
  410. if (copy_desc == nullptr) {
  411. return SUCCESS;
  412. }
  413. map<OutDataAnchorPtr, vector<std::pair<InDataAnchorPtr, NodePtr>>> outanchors_inanchors_nodes;
  414. auto out_data_anchors = node->GetAllOutDataAnchors();
  415. for (auto &out_data_anchor : out_data_anchors) {
  416. auto peer_in_data_anchors = out_data_anchor->GetPeerInDataAnchors();
  417. for (auto peer_in_data_anchor : peer_in_data_anchors) {
  418. GE_CHECK_NOTNULL(peer_in_data_anchor);
  419. auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
  420. out_nodes.emplace(peer_in_data_node);
  421. outanchors_inanchors_nodes[out_data_anchor].emplace_back(std::make_pair(peer_in_data_anchor, peer_in_data_node));
  422. }
  423. }
  424. int32_t i = 0;
  425. auto node_desc = node->GetOpDesc();
  426. GE_CHECK_NOTNULL(node_desc);
  427. // Insert one enter node after node's per out data anchor
  428. for (auto &outanchor_inanchors_nodes : outanchors_inanchors_nodes) {
  429. string name = node->GetName() + "_" + ENTER + "_" + std::to_string(i++);
  430. GELOGD("Create Enter op %s after %s.", name.c_str(), node->GetName().c_str());
  431. auto enter_desc = AttrUtils::CopyOpDesc(copy_desc);
  432. enter_desc->SetName(name);
  433. GE_CHK_STATUS_RET(
  434. enter_desc->UpdateInputDesc("x", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
  435. GE_CHK_STATUS_RET(
  436. enter_desc->UpdateOutputDesc("y", node_desc->GetOutputDesc(outanchor_inanchors_nodes.first->GetIdx())))
  437. auto enter_node = graph_->AddNode(enter_desc);
  438. GE_CHECK_NOTNULL(enter_node);
  439. GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->LinkTo(enter_node->GetInDataAnchor(kDataInIndex)))
  440. GE_CHECK_NOTNULL(enter_node->GetOutDataAnchor(kDataInIndex));
  441. for (auto &inanchor_node : outanchor_inanchors_nodes.second) {
  442. GE_CHK_STATUS_RET(outanchor_inanchors_nodes.first->Unlink(inanchor_node.first))
  443. GE_CHK_STATUS_RET(enter_node->GetOutDataAnchor(kDataInIndex)->LinkTo(inanchor_node.first))
  444. GELOGD("Unlink from %s to %s, link from %s to %s then to %s.", node->GetName().c_str(),
  445. inanchor_node.second->GetName().c_str(), node->GetName().c_str(), enter_node->GetName().c_str(),
  446. inanchor_node.second->GetName().c_str());
  447. }
  448. }
  449. return SUCCESS;
  450. }
  451. // Move node's in control edges to out data nodes
  452. Status MultiBatchGraphCopyer::MoveCtrlEdgeToOutNodes(NodePtr &node, set<NodePtr> &out_nodes) {
  453. auto in_ctrl_anchor = node->GetInControlAnchor();
  454. GE_CHECK_NOTNULL(in_ctrl_anchor);
  455. auto peer_out_ctrl_anchors = in_ctrl_anchor->GetPeerOutControlAnchors();
  456. for (auto &peer_out_ctrl_anchor : peer_out_ctrl_anchors) {
  457. GE_CHK_STATUS_RET(peer_out_ctrl_anchor->Unlink(in_ctrl_anchor))
  458. GELOGD("Unlink control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  459. node->GetName().c_str());
  460. for (auto &out_node : out_nodes) {
  461. auto in_ctrl_anchor_of_out_node = out_node->GetInControlAnchor();
  462. GE_CHECK_NOTNULL(in_ctrl_anchor_of_out_node);
  463. if (!peer_out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor_of_out_node)) {
  464. GE_CHK_STATUS_RET(peer_out_ctrl_anchor->LinkTo(in_ctrl_anchor_of_out_node))
  465. GELOGD("Link control edge from %s to %s.", peer_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
  466. out_node->GetName().c_str());
  467. }
  468. }
  469. }
  470. return SUCCESS;
  471. }
  472. Status MultiBatchGraphCopyer::DeleteEnterWithoutDataOut() {
  473. for (auto &node : graph_->GetAllNodes()) {
  474. GE_CHECK_NOTNULL(node);
  475. if (IsEnterType(node->GetType())) {
  476. auto out_nodes = node->GetOutAllNodes();
  477. if (out_nodes.empty()) {
  478. GELOGD("Delete enter node: %s which has no output.", node->GetName().c_str());
  479. GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {}))
  480. GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node))
  481. }
  482. }
  483. }
  484. return SUCCESS;
  485. }
  486. void MultiBatchGraphCopyer::LabelStatusForData(const NodePtr &data) {
  487. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  488. GELOGI("Label status for %s, shape_dims is %s.", data->GetName().c_str(),
  489. formats::JoinToString(data_shape.GetDims()).c_str());
  490. if (!IsAllDimsPositive(data_shape.GetDims())) {
  491. origin_nodes_status_[data.get()] = kNodeInBatchBranch;
  492. }
  493. }
  494. void MultiBatchGraphCopyer::LabelStatusForGetNextSink(const NodePtr &data) {
  495. auto op_desc = data->GetOpDesc();
  496. GELOGI("Out count of %s is %zu.", data->GetName().c_str(), op_desc->GetOutputsSize());
  497. size_t data_count = op_desc->GetOutputsSize() / kDivisionConst;
  498. for (size_t i = 0; i < data_count; ++i) {
  499. GeTensorDesc output_desc = op_desc->GetOutputDesc(i);
  500. GELOGD("The %zu data shape from getnext sink is %s.", i,
  501. formats::JoinToString(output_desc.GetShape().GetDims()).c_str());
  502. const auto &out_data_anchor = data->GetOutDataAnchor(i);
  503. if (out_data_anchor == nullptr) {
  504. continue;
  505. }
  506. size_t reference_times = out_data_anchor->GetPeerInDataAnchors().size();
  507. GELOGD("The %zu data has %zu referenced times.", i, reference_times);
  508. getnext_sink_dynamic_out_mapping_.emplace_back(std::make_pair(i, reference_times));
  509. if (!IsAllDimsPositive(output_desc.GetShape().GetDims())) {
  510. getnext_sink_dynamic_dims_ = true;
  511. }
  512. }
  513. if (getnext_sink_dynamic_dims_) {
  514. origin_nodes_status_[data.get()] = kNodeInBatchBranch;
  515. }
  516. }
  517. Status MultiBatchGraphCopyer::LabelInBatchBranchStatus() {
  518. GELOGD("Start label in batch branch status.");
  519. for (const auto &data : origin_data_nodes_) {
  520. auto op_desc = data->GetOpDesc();
  521. GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(PARAM_INVALID, "Op desc is nullptr.");
  522. return PARAM_INVALID);
  523. LabelStatusForData(data);
  524. if (!GetLocalOmgContext().dynamic_node_type.empty()) {
  525. LabelStatusForGetNextSink(data);
  526. }
  527. }
  528. map<string, vector<NodePtr>> frame_enters;
  529. InitStatus(frame_enters);
  530. bool changed = true;
  531. // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch
  532. while (changed) {
  533. changed = false;
  534. for (const auto &node : origin_all_nodes_) {
  535. auto iter = origin_nodes_status_.find(node.get());
  536. if (iter != origin_nodes_status_.end()) {
  537. continue;
  538. }
  539. for (auto &in_node : node->GetInDataNodes()) {
  540. if (origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end()) {
  541. if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
  542. origin_nodes_status_[node.get()] == kNodeInBatchBranch;
  543. ResetEnterStatus(frame_enters, node);
  544. changed = true;
  545. }
  546. break;
  547. }
  548. }
  549. }
  550. }
  551. return SUCCESS;
  552. }
  553. void MultiBatchGraphCopyer::InitStatus(map<string, vector<NodePtr>> &frame_enters) {
  554. for (const auto &node : origin_all_nodes_) {
  555. if (!IsEnterType(node->GetType())) {
  556. continue;
  557. }
  558. auto op_desc = node->GetOpDesc();
  559. if (op_desc == nullptr) {
  560. continue;
  561. }
  562. string frame_name;
  563. if (AttrUtils::GetStr(op_desc, ENTER_ATTR_FRAME_NAME, frame_name)) {
  564. frame_enters[frame_name].emplace_back(node);
  565. }
  566. }
  567. for (const auto &data : origin_data_nodes_) {
  568. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  569. if (!IsAllDimsPositive(data_shape.GetDims())) {
  570. origin_nodes_status_[data.get()] = kNodeInBatchBranch;
  571. }
  572. }
  573. }
  574. void MultiBatchGraphCopyer::ResetEnterStatus(map<string, vector<NodePtr>> &frame_enters, const NodePtr &node) {
  575. if (!IsEnterType(node->GetType())) {
  576. return;
  577. }
  578. for (const auto &frame_enter : frame_enters) {
  579. auto &enters = frame_enter.second;
  580. if (std::find(enters.begin(), enters.end(), node) != enters.end()) {
  581. for (const auto &enter : enters) {
  582. origin_nodes_status_[enter.get()] = kNodeInBatchBranch;
  583. }
  584. break;
  585. }
  586. }
  587. }
  588. Status MultiBatchGraphCopyer::LabelStatus() {
  589. if (LabelInBatchBranchStatus() != SUCCESS) {
  590. GELOGE(PARAM_INVALID, "Failed to label no in batch branch");
  591. return PARAM_INVALID;
  592. }
  593. for (const auto &node : origin_all_nodes_) {
  594. if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) {
  595. origin_nodes_status_[node.get()] = kNodeNotSupportNode;
  596. continue;
  597. }
  598. if (node->GetType() == NETOUTPUT) {
  599. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  600. continue;
  601. }
  602. if (GetLocalOmgContext().dynamic_node_type.empty()) {
  603. if (IsDataLikeType(node->GetType())) {
  604. if (IsOnlyOutputToAipp(node)) {
  605. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  606. } else {
  607. origin_nodes_status_[node.get()] = kNodeStartNode;
  608. }
  609. continue;
  610. }
  611. } else {
  612. if (IsDataLikeType(node->GetType())) {
  613. origin_nodes_status_[node.get()] = kNodeStartNode;
  614. continue;
  615. }
  616. if (IsGetNextType(node)) {
  617. origin_nodes_status_[node.get()] = kNodeStartNode;
  618. continue;
  619. }
  620. }
  621. if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
  622. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  623. }
  624. }
  625. return SUCCESS;
  626. }
  627. Status MultiBatchGraphCopyer::CheckAndParseDynamicData(){
  628. size_t unknown_shape_count = 0;
  629. auto data_name_and_shape = GetLocalOmgContext().user_input_dims;
  630. GELOGD("raw data_name_and_shape size: %zu", data_name_and_shape.size());
  631. if (!getnext_sink_dynamic_dims_) {
  632. for (const auto &node : origin_all_nodes_) {
  633. auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex);
  634. auto data_shape = data_desc.GetShape();
  635. auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" :
  636. data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others";
  637. auto data_name = node->GetName();
  638. auto branch_status = GetNodeStatus(node);
  639. if (branch_status != kNodeStartNode) {
  640. continue;
  641. }
  642. GELOGI("CheckAndParseDynamicData shape_dims is %s.", formats::JoinToString(data_shape.GetDims()).c_str());
  643. if (IsAllDimsPositive(data_shape.GetDims())) {
  644. continue;
  645. }
  646. std::vector<int64_t> data_shape_dims = data_shape.GetDims();
  647. ++unknown_shape_count;
  648. auto iter = find(data_name_order_.begin(), data_name_order_.end(), data_name);
  649. if (iter == data_name_order_.end()) {
  650. if (dynamic_type_ == DynamicType::kDynamicBatch) {
  651. auto ret = CheckDynamicBatchShape(data_shape_dims, data_name);
  652. GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic batch shape of %s.",
  653. data_name.c_str()); return PARAM_INVALID);
  654. } else if (dynamic_type_ == DynamicType::kDynamicImageSize) {
  655. auto ret = CheckDynamicImageSizeShape(data_shape_dims, data_name, data_format);
  656. GE_IF_BOOL_EXEC(ret == false, GELOGE(PARAM_INVALID, "Failed to check dynamic image size shape of %s.",
  657. data_name.c_str()); return PARAM_INVALID);
  658. } else if (dynamic_type_ == DynamicType::kDynamicDims) {
  659. ErrorManager::GetInstance().ATCReportErrMessage("E10001",
  660. {"parameter", "reason"},
  661. {"--input_shape",
  662. "all dynamic data must be set in --input_shape"});
  663. GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape",
  664. node->GetName().c_str(), data_shape.ToString().c_str());
  665. return INTERNAL_ERROR;
  666. }
  667. GELOGI("Data shape of %s is %s", data_name.c_str(), formats::JoinToString(data_shape_dims).c_str());
  668. data_name_and_shape.emplace_back(data_name, data_shape_dims);
  669. }
  670. }
  671. }
  672. auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_);
  673. GE_CHK_STATUS_RET(ret, "Failed to parse data to dynamic info.");
  674. if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) {
  675. ErrorManager::GetInstance().ATCReportErrMessage("E10040");
  676. GELOGE(PARAM_INVALID,
  677. "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
  678. return PARAM_INVALID;
  679. }
  680. return SUCCESS;
  681. }
  682. Status MultiBatchGraphCopyer::CreateNewNodes() {
  683. if (!getnext_sink_dynamic_dims_) {
  684. shape_data_ = InsertShapeDataNode();
  685. } else {
  686. shape_data_ = InsertGetDynamicDimsNode();
  687. }
  688. GE_IF_BOOL_EXEC(shape_data_ == nullptr, GELOGE(INTERNAL_ERROR, "Failed to create the shape node for multi batch");
  689. return INTERNAL_ERROR);
  690. GE_CHECK_NOTNULL(shape_data_->GetOpDesc());
  691. for (const auto &node : origin_all_nodes_) {
  692. GE_CHECK_NOTNULL(node->GetOpDesc());
  693. auto node_type = node->GetType();
  694. Status ret = INTERNAL_ERROR;
  695. auto branch_status = GetNodeStatus(node);
  696. GELOGD("Process node %s, status %d", node->GetName().c_str(), static_cast<int>(branch_status));
  697. switch (branch_status) {
  698. case kNodeStartNode:
  699. GELOGD("Name: %s, type: %s, status: kNodeStartNode.", node->GetName().c_str(), node->GetType().c_str());
  700. ret = InsertSwitchNAndUpdateMaxShape(node);
  701. break;
  702. case kNodeInBatchBranch:
  703. GELOGD("Name: %s, type: %s, status: kNodeInBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
  704. ret = CopyNodeInBatchBranch(node);
  705. break;
  706. case kNodeOutBatchBranch:
  707. GELOGD("Name: %s, type: %s, status: kNodeOutBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
  708. ret = InsertMergeForEdgeNode(node);
  709. if (ret == SUCCESS) {
  710. ret = LinkGetDynamicDimsToNetOutput(node);
  711. }
  712. break;
  713. case kNodeNotSupportNode:
  714. GELOGD("Name: %s, type: %s, status: kNodeNotSupportNode.", node->GetName().c_str(), node->GetType().c_str());
  715. break;
  716. default:
  717. GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast<int>(branch_status),
  718. node->GetName().c_str());
  719. break;
  720. }
  721. if (ret != SUCCESS) {
  722. GELOGE(ret, "Failed to deal with node %s in multi-batch process", node->GetName().c_str());
  723. return ret;
  724. }
  725. }
  726. return SUCCESS;
  727. }
  728. NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) {
  729. if (index < 0) {
  730. // the merge node must has data inputs, if origin connection is a control
  731. // edge, we use data edge instead
  732. index = 0;
  733. }
  734. auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  735. if (merge_nodes.empty()) {
  736. auto count = node->GetAllOutDataAnchorsSize();
  737. if (count == 0) {
  738. count = 1;
  739. }
  740. merge_nodes.resize(count, nullptr);
  741. }
  742. if (merge_nodes.at(index) != nullptr) {
  743. return merge_nodes[index];
  744. }
  745. auto merge_node_name = node->GetName() + "_ascend_mbatch_merge_" + std::to_string(index);
  746. auto merge_node = InsertMergeNodeToGraph(merge_node_name, shapes_.size(), node->GetOwnerComputeGraph());
  747. GE_IF_BOOL_EXEC(merge_node == nullptr, GELOGE(INTERNAL_ERROR, "Failed to create merge node for node %s, out index %d",
  748. node->GetName().c_str(), index);
  749. return nullptr);
  750. merge_nodes[index] = merge_node;
  751. GELOGI("Create merge node %s for node %s index %d", merge_node_name.c_str(), node->GetName().c_str(), index);
  752. return merge_node;
  753. }
  754. NodePtr MultiBatchGraphCopyer::FindSwitchnNodeForDataEdge(const OutDataAnchorPtr &data_out_anchor,
  755. const NodePtr &origin_node) {
  756. auto data_node = data_out_anchor->GetOwnerNode();
  757. GELOGD("Start find switchn node insert between %s and %s", data_node->GetName().c_str(),
  758. origin_node->GetName().c_str());
  759. NodePtr switchn = nullptr;
  760. if (!getnext_sink_dynamic_dims_ && data_nodes_to_switchn_.count(data_node.get()) > 0) {
  761. switchn = data_nodes_to_switchn_[data_node.get()];
  762. return switchn;
  763. }
  764. bool is_getnext_sink_data = false;
  765. for (size_t i = 0; i < getnext_nodes_to_switchn_.size(); ++i) {
  766. for (size_t j = 0; j < getnext_nodes_to_switchn_.at(i).size(); ++j) {
  767. if (getnext_nodes_to_switchn_.at(i).at(j).first == data_node.get()) {
  768. is_getnext_sink_data = true;
  769. break;
  770. }
  771. }
  772. }
  773. // get output_idx of origin_node(getnext)
  774. if (is_getnext_sink_data) {
  775. auto output_idx = data_out_anchor->GetIdx();
  776. size_t referenced_index = 0;
  777. GELOGI("The output idx %zu has %zu referenced nums.", output_idx, data_out_anchor->GetPeerInDataAnchors().size());
  778. for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) {
  779. if (peer_in_anchor->GetOwnerNode()->GetOpDesc() == nullptr) {
  780. GELOGE(INTERNAL_ERROR, "Op desc should not be nullptr.");
  781. return nullptr;
  782. }
  783. if (getnext_nodes_to_switchn_.at(output_idx).empty()) {
  784. GELOGI("Output idx %zu of %s is static output.", output_idx, data_node->GetName().c_str());
  785. return nullptr;
  786. }
  787. if (output_idx >= static_cast<int>(getnext_nodes_to_switchn_.size()) ||
  788. referenced_index >= getnext_nodes_to_switchn_.at(output_idx).size()) {
  789. GELOGE(INTERNAL_ERROR, "Output idx is %zu, referenced index is %zu", output_idx, referenced_index);
  790. return nullptr;
  791. }
  792. if (peer_in_anchor->GetOwnerNode()->GetOpDesc()->GetName() == origin_node->GetName()) {
  793. switchn = getnext_nodes_to_switchn_.at(output_idx).at(referenced_index).second;
  794. GELOGI("Name of switchn is %s.", switchn->GetName().c_str());
  795. return switchn;
  796. }
  797. referenced_index++;
  798. }
  799. }
  800. return switchn;
  801. }
  802. Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr &copyed_node) {
  803. GELOGI("Start copy data edges for %s and %s.", origin_node->GetName().c_str(), copyed_node->GetName().c_str());
  804. for (auto &in_anchor : origin_node->GetAllInDataAnchors()) {
  805. auto origin_src_anchor = in_anchor->GetPeerOutAnchor();
  806. if (origin_src_anchor == nullptr) {
  807. GELOGD("The node %s does not have input on index %d", origin_node->GetName().c_str(), in_anchor->GetIdx());
  808. continue;
  809. }
  810. auto origin_src_node = origin_src_anchor->GetOwnerNode();
  811. auto dst_anchor = copyed_node->GetInDataAnchor(in_anchor->GetIdx());
  812. GE_CHECK_NOTNULL(dst_anchor);
  813. auto switchn = FindSwitchnNodeForDataEdge(origin_src_anchor, origin_node);
  814. if (switchn != nullptr) {
  815. auto ret = GraphUtils::AddEdge(switchn->GetOutDataAnchor(batch_num), dst_anchor);
  816. if (ret != GRAPH_SUCCESS) {
  817. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
  818. switchn->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(),
  819. ret);
  820. return INTERNAL_ERROR;
  821. }
  822. GELOGD("Add data edge from %s(%d) to %s(%d)", switchn->GetName().c_str(), batch_num,
  823. copyed_node->GetName().c_str(), in_anchor->GetIdx());
  824. continue;
  825. }
  826. auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
  827. if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
  828. auto src_batch_node = batch_branch_iter->second.at(batch_num);
  829. auto ret = GraphUtils::AddEdge(src_batch_node->GetOutDataAnchor(origin_src_anchor->GetIdx()), dst_anchor);
  830. if (ret != GRAPH_SUCCESS) {
  831. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
  832. src_batch_node->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(), ret);
  833. return INTERNAL_ERROR;
  834. }
  835. GELOGD("Add data edge from %s(%d) to %s(%d)", src_batch_node->GetName().c_str(), batch_num,
  836. copyed_node->GetName().c_str(), in_anchor->GetIdx());
  837. continue;
  838. }
  839. auto ret = GraphUtils::AddEdge(origin_src_anchor, dst_anchor);
  840. if (ret != GRAPH_SUCCESS) {
  841. GELOGE(INTERNAL_ERROR, "Failed to add data edge between origin node %s(%d) to copyed %s(%d)",
  842. origin_src_node->GetName().c_str(), origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(),
  843. dst_anchor->GetIdx());
  844. return INTERNAL_ERROR;
  845. }
  846. GELOGD("Add data edge between branch-out %s(%d) to branch-in %s(%d)", origin_src_node->GetName().c_str(),
  847. origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(), dst_anchor->GetIdx());
  848. }
  849. return SUCCESS;
  850. }
  851. Status MultiBatchGraphCopyer::CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr &copyed_node) {
  852. GELOGI("Start copy control edge for %s and %s.", node->GetName().c_str(), copyed_node->GetName().c_str());
  853. for (auto &origin_src_node : node->GetInControlNodes()) {
  854. auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get());
  855. if (switchn_iter != data_nodes_to_switchn_.end()) {
  856. // reconnect data node
  857. auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  858. if (ret != GRAPH_SUCCESS) {
  859. GELOGE(INTERNAL_ERROR, "Failed to add control edge between %s to %s, error-code %u",
  860. switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  861. return INTERNAL_ERROR;
  862. }
  863. GELOGD("Add control edge from %s to %s", switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str());
  864. continue;
  865. }
  866. auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
  867. if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
  868. // reconnect node in batch branch
  869. auto src_batch_node = batch_branch_iter->second.at(batch_num);
  870. auto ret = GraphUtils::AddEdge(src_batch_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  871. if (ret != GRAPH_SUCCESS) {
  872. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s to %s, error-code %u",
  873. src_batch_node->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  874. return INTERNAL_ERROR;
  875. }
  876. GELOGD("Add control edge from %s to %s", src_batch_node->GetName().c_str(), copyed_node->GetName().c_str());
  877. continue;
  878. }
  879. auto ret = GraphUtils::AddEdge(origin_src_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  880. if (ret != GRAPH_SUCCESS) {
  881. GELOGE(INTERNAL_ERROR, "Failed to add control edge from origin %s to copyed %s",
  882. origin_src_node->GetName().c_str(), copyed_node->GetName().c_str());
  883. return INTERNAL_ERROR;
  884. }
  885. GELOGD("Add control edge between branch-out %s to branch-in %s", origin_src_node->GetName().c_str(),
  886. copyed_node->GetName().c_str());
  887. }
  888. return SUCCESS;
  889. }
  890. NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() {
  891. auto desc = MakeShared<OpDesc>();
  892. if (desc == nullptr) {
  893. GELOGE(OUT_OF_MEMORY, "Failed to create shape data node, out of memory");
  894. return nullptr;
  895. }
  896. string node_name = "ascend_mbatch_shape_data";
  897. // Only flush subgraph name
  898. if (graph_->GetParentGraph() != nullptr) {
  899. node_name = graph_->GetName() + "_" + node_name;
  900. }
  901. desc->SetName(node_name);
  902. desc->SetType(DATA);
  903. // input and output of DATA is gear_info
  904. GeTensorDesc tensor_desc(GeShape({static_cast<int64_t>(shapes_.at(0).size())}), FORMAT_ND, DT_INT64);
  905. auto ret = desc->AddInputDesc(tensor_desc);
  906. if (ret != GRAPH_SUCCESS) {
  907. GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
  908. return nullptr;
  909. }
  910. ret = desc->AddOutputDesc(tensor_desc);
  911. if (ret != GRAPH_SUCCESS) {
  912. GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
  913. return nullptr;
  914. }
  915. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  916. GELOGE(INTERNAL_ERROR, "Failed to add attr for created data");
  917. return nullptr;
  918. }
  919. auto data_node = graph_->AddNode(desc);
  920. if (data_node == nullptr) {
  921. GELOGE(INTERNAL_ERROR, "Failed to add shape data node to graph");
  922. return nullptr;
  923. }
  924. ret = GraphUtils::AppendInputNode(graph_, data_node);
  925. if (ret != GRAPH_SUCCESS) {
  926. GELOGE(INTERNAL_ERROR, "Failed to append data node %s as input to graph", data_node->GetName().c_str());
  927. return nullptr;
  928. }
  929. return data_node;
  930. }
  931. NodePtr MultiBatchGraphCopyer::InsertGetDynamicDimsNode() {
  932. GELOGD("Start insert getdynamicdims node to get shape info.");
  933. auto desc = MakeShared<OpDesc>();
  934. if (desc == nullptr) {
  935. GELOGE(OUT_OF_MEMORY, "Failed to create shape data node, out of memory");
  936. return nullptr;
  937. }
  938. string node_name = "ascend_mbatch_get_dynamic_dims_node";
  939. // Only flush subgraph name
  940. if (graph_->GetParentGraph() != nullptr) {
  941. node_name = graph_->GetName() + "_" + node_name;
  942. }
  943. desc->SetName(node_name);
  944. desc->SetType(GETDYNAMICDIMS);
  945. // input of GetDynamicDims is shape_of_each_data, output is gear_info
  946. for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
  947. size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size();
  948. if (input_shape_dims == 1 && GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
  949. GeTensorDesc tensor_desc;
  950. tensor_desc.SetFormat(FORMAT_ND);
  951. tensor_desc.SetDataType(DT_INT64);
  952. auto ret = desc->AddInputDesc(tensor_desc);
  953. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
  954. return nullptr);
  955. continue;
  956. }
  957. GeTensorDesc tensor_desc(GeShape({static_cast<int64_t>(input_shape_dims)}), FORMAT_ND, DT_INT64);
  958. auto ret = desc->AddInputDesc(tensor_desc);
  959. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
  960. return nullptr);
  961. }
  962. GeTensorDesc tensor_desc(GeShape({static_cast<int64_t>(shapes_.at(0).size())}), FORMAT_ND, DT_INT64);
  963. auto ret = desc->AddOutputDesc(tensor_desc);
  964. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
  965. return nullptr);
  966. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  967. GELOGE(INTERNAL_ERROR, "Failed to add attr for created data");
  968. return nullptr;
  969. }
  970. auto data_node = graph_->AddNode(desc);
  971. if (data_node == nullptr) {
  972. GELOGE(INTERNAL_ERROR, "Failed to add shape data node to graph");
  973. return nullptr;
  974. }
  975. ret = GraphUtils::AppendInputNode(graph_, data_node);
  976. if (ret != GRAPH_SUCCESS) {
  977. GELOGE(INTERNAL_ERROR, "Failed to append data node %s as input to graph", data_node->GetName().c_str());
  978. return nullptr;
  979. }
  980. return data_node;
  981. }
  982. Status MultiBatchGraphCopyer::CheckArguments() {
  983. if (graph_ == nullptr) {
  984. GELOGE(PARAM_INVALID, "Failed to copy graph, the graph is null");
  985. return PARAM_INVALID;
  986. }
  987. return CheckDynamicParams(shapes_);
  988. }
  989. Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_nodes) {
  990. for (auto &node : start_nodes) {
  991. if (IsOnlyOutputToAipp(node)) {
  992. continue;
  993. }
  994. auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
  995. if (!IsAllDimsPositive(dims)) {
  996. ErrorManager::GetInstance().ATCReportErrMessage("E15004", {"opname", "shape"},
  997. {node->GetName(), formats::ShapeToString(dims)});
  998. GELOGE(INTERNAL_ERROR, "Failed to copy multi batch graph, the node %s still has unknown shape %s",
  999. node->GetName().c_str(), formats::ShapeToString(dims).c_str());
  1000. return INTERNAL_ERROR;
  1001. }
  1002. }
  1003. return SUCCESS;
  1004. }
  1005. bool MultiBatchGraphCopyer::IsInBatchBranch(const NodePtr &node) {
  1006. if (!getnext_sink_dynamic_dims_) {
  1007. return (nodes_to_batch_nodes_.count(node.get()) > 0) || (data_nodes_to_switchn_.count(node.get()) > 0);
  1008. } else {
  1009. for (size_t i = 0; i < getnext_nodes_to_switchn_.size(); ++i) {
  1010. for (size_t j = 0; j < getnext_nodes_to_switchn_.at(i).size(); ++j) {
  1011. if (getnext_nodes_to_switchn_.at(i).at(j).first == node.get()) {
  1012. return true;
  1013. }
  1014. }
  1015. }
  1016. return nodes_to_batch_nodes_.count(node.get()) > 0;
  1017. }
  1018. }
  1019. Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr &merge, const NodePtr &switchn) {
  1020. // The caller should make sure that the there is a SwitchN node in the map
  1021. GELOGI("Link edge between data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(),
  1022. switchn->GetName().c_str());
  1023. for (size_t i = 0; i < shapes_.size(); ++i) {
  1024. auto ret = GraphUtils::AddEdge(switchn->GetOutDataAnchor(i), merge->GetInDataAnchor(i));
  1025. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  1026. GELOGE(INTERNAL_ERROR, "Failed to add edge between switchn %s(%zu) to merge %s(%zu), error-code %u",
  1027. switchn->GetName().c_str(), i, merge->GetName().c_str(), i, ret);
  1028. return INTERNAL_ERROR);
  1029. }
  1030. return SUCCESS;
  1031. }
  1032. Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge) {
  1033. auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  1034. if (copyed_nodes.size() != shapes_.size()) {
  1035. GELOGE(INTERNAL_ERROR,
  1036. "Failed to create merge node for node %s, the copyed nodes for it count %zu different with shape %zu",
  1037. node->GetName().c_str(), copyed_nodes.size(), shapes_.size());
  1038. return INTERNAL_ERROR;
  1039. }
  1040. for (size_t i = 0; i < copyed_nodes.size(); ++i) {
  1041. auto src_node = copyed_nodes[i];
  1042. if (src_node->GetAllOutDataAnchorsSize() == 0) {
  1043. // if the node does not has any data output, we should create an const for it, like this:
  1044. // c d
  1045. // node ---> const ---> merge
  1046. auto const_name = src_node->GetName() + "_merge_const";
  1047. GELOGI("The node %s on the batch branch edge does not have any data output, create a const %s for it",
  1048. src_node->GetName().c_str(), const_name.c_str());
  1049. auto const_node = InsertConst(const_name, graph_);
  1050. GE_IF_BOOL_EXEC(const_node == nullptr,
  1051. GELOGE(OUT_OF_MEMORY, "Failed to create const for node %s to connect to a merge node",
  1052. src_node->GetName().c_str());
  1053. return OUT_OF_MEMORY);
  1054. auto ret = GraphUtils::AddEdge(src_node->GetOutControlAnchor(), const_node->GetInControlAnchor());
  1055. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
  1056. src_node->GetName().c_str(), const_node->GetName().c_str());
  1057. return INTERNAL_ERROR);
  1058. src_node = const_node;
  1059. }
  1060. auto ret = GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_index), merge->GetInDataAnchor(i));
  1061. if (ret != GRAPH_SUCCESS) {
  1062. GELOGE(INTERNAL_ERROR,
  1063. "Failed to add edge between copyed node %s(%d) to inserted merge node %s(%zu), error-code %u",
  1064. copyed_nodes[i]->GetName().c_str(), out_index, merge->GetName().c_str(), i, ret);
  1065. return INTERNAL_ERROR;
  1066. }
  1067. }
  1068. return SUCCESS;
  1069. }
  1070. Status MultiBatchGraphCopyer::InsertSwitchNAndUpdateMaxShape(const NodePtr &node) {
  1071. std::vector<std::pair<Node *, NodePtr>> dynamic_out_to_switchn;
  1072. if (!getnext_sink_dynamic_dims_) {
  1073. if (InsertSwitchNForData(node, kDataOutIndex, kDataOutIndex, dynamic_out_to_switchn) != SUCCESS) {
  1074. GELOGE(PARAM_INVALID, "Failed to insert switchn for %s.", node->GetName().c_str());
  1075. return PARAM_INVALID;
  1076. }
  1077. if (UpdateMaxShapeToData(node, kDataOutIndex) != SUCCESS) {
  1078. GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", node->GetName().c_str());
  1079. return PARAM_INVALID;
  1080. }
  1081. } else {
  1082. if (!IsGetNextType(node)) {
  1083. GELOGI("No need to insert switchn and update max shape for %s when get sink dynamic.", node->GetName().c_str());
  1084. return SUCCESS;
  1085. }
  1086. for (size_t i = 0; i < getnext_sink_dynamic_out_mapping_.size(); ++i) {
  1087. dynamic_out_to_switchn.clear();
  1088. for (size_t j = 0; j < getnext_sink_dynamic_out_mapping_.at(i).second; ++j) {
  1089. GELOGI("The %zu data_index has %zu referenced nums.", getnext_sink_dynamic_out_mapping_.at(i).first,
  1090. getnext_sink_dynamic_out_mapping_.at(i).second);
  1091. if (InsertSwitchNForData(node, getnext_sink_dynamic_out_mapping_.at(i).first, j, dynamic_out_to_switchn) !=
  1092. SUCCESS) {
  1093. GELOGE(PARAM_INVALID, "Failed to insert switchn for %s of %zu out anchor when referenced index is %zu",
  1094. node->GetName().c_str(), getnext_sink_dynamic_out_mapping_.at(i).first, j);
  1095. return PARAM_INVALID;
  1096. }
  1097. }
  1098. getnext_nodes_to_switchn_.emplace_back(dynamic_out_to_switchn);
  1099. }
  1100. for (size_t i = 0; i < getnext_sink_dynamic_out_mapping_.size(); ++i) {
  1101. if(UpdateMaxShapeToData(node, i) != SUCCESS) {
  1102. GELOGE(PARAM_INVALID, "Failed to update max shape of %zu out anchor", node->GetName().c_str(), i);
  1103. return PARAM_INVALID;
  1104. }
  1105. }
  1106. }
  1107. return SUCCESS;
  1108. }
  1109. Status MultiBatchGraphCopyer::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) {
  1110. auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  1111. size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst);
  1112. GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index);
  1113. std::vector<int64_t> output_dims = {static_cast<int64_t>(data_shape.GetDims().size())};
  1114. GeShape output_shape(output_dims);
  1115. output_desc.SetShape(output_shape);
  1116. if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) {
  1117. GELOGE(FAILED, "Update output desc fail.");
  1118. return FAILED;
  1119. }
  1120. return SUCCESS;
  1121. }
  1122. Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index) {
  1123. GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index);
  1124. auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  1125. string data_name = node->GetName();
  1126. if (getnext_sink_dynamic_dims_) {
  1127. data_name.append("_").append(std::to_string(out_anchor_index));
  1128. }
  1129. GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(),
  1130. formats::JoinToString(data_shape.GetDims()).c_str());
  1131. if (!getnext_sink_dynamic_dims_) {
  1132. if (IsAllDimsPositive(data_shape.GetDims())) {
  1133. GELOGD("No need to do anything for static data.");
  1134. return SUCCESS;
  1135. }
  1136. } else {
  1137. if (IsAllDimsPositive(data_shape.GetDims())) {
  1138. // need to update shape of Shape_node
  1139. GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node");
  1140. return SUCCESS;
  1141. }
  1142. }
  1143. size_t max_shape_index = 0;
  1144. int64_t max_size = 0;
  1145. for (size_t i = 0; i < shapes_.size(); ++i) {
  1146. int64_t size = 1;
  1147. for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
  1148. if (INT64_MAX / dim < size) {
  1149. GELOGE(PARAM_INVALID, "The shape %s size overflow",
  1150. formats::ShapeToString(data_to_dynamic_info_[data_name].at(i)).c_str());
  1151. return PARAM_INVALID;
  1152. }
  1153. size *= dim;
  1154. }
  1155. if (size > max_size) {
  1156. max_size = size;
  1157. max_shape_index = i;
  1158. }
  1159. }
  1160. // must not be error, the calc result has been checked in function InsertSwitchNForData
  1161. (void)CalcShape(data_to_dynamic_info_.at(data_name).at(max_shape_index), data_shape);
  1162. auto ret = NodeUtils::UpdateOutputShape(*node, out_anchor_index, data_shape);
  1163. GE_CHK_STATUS_RET(ret, "Failed to update output shape for data %s", node->GetName().c_str());
  1164. // getnext_sink not has input
  1165. if (!getnext_sink_dynamic_dims_) {
  1166. ret = NodeUtils::UpdateInputShape(*node, kDataInIndex, data_shape);
  1167. GE_CHK_STATUS_RET(ret, "Failed to update input shape for data %s", node->GetName().c_str());
  1168. } else {
  1169. // need to update shape of Shape_node when getnext_sink_dynamic
  1170. GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node");
  1171. }
  1172. GELOGI("Update the data %s input/output shape to the max %s", node->GetName().c_str(),
  1173. formats::ShapeToString(data_shape).c_str());
  1174. return SUCCESS;
  1175. }
  1176. Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index,
  1177. const size_t &peer_in_anchor_index,
  1178. std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn) {
  1179. auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  1180. string data_name = node->GetName();
  1181. if (getnext_sink_dynamic_dims_) {
  1182. data_name.append("_").append(std::to_string(out_anchor_index));
  1183. }
  1184. (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  1185. GELOGI("Insert switchn node of %s, shape dims is %s.", data_name.c_str(),
  1186. formats::JoinToString(data_shape.GetDims()).c_str());
  1187. if (IsAllDimsPositive(data_shape.GetDims())) {
  1188. GELOGI("The shape of data %s are positive(%s), skip the multi batch process", node->GetName().c_str(),
  1189. data_shape.ToString().c_str());
  1190. return SUCCESS;
  1191. }
  1192. auto switchn_desc = MakeShared<OpDesc>();
  1193. GE_IF_BOOL_EXEC(switchn_desc == nullptr,
  1194. GELOGE(OUT_OF_MEMORY, "Failed to create switchn for data %s", node->GetName().c_str());
  1195. return OUT_OF_MEMORY);
  1196. string switchn_name = node->GetName() + "_ascend_mbatch_switchn";
  1197. if (getnext_sink_dynamic_dims_) {
  1198. switchn_name.append("_").append(std::to_string(out_anchor_index))
  1199. .append("_").append(std::to_string(peer_in_anchor_index));
  1200. }
  1201. GELOGI("name of switchn is %s.", switchn_name.c_str());
  1202. switchn_desc->SetName(switchn_name);
  1203. switchn_desc->SetType(SWITCHN);
  1204. GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, out_anchor_index));
  1205. GE_IF_BOOL_EXEC(switchn_desc->AddInputDesc("data", tensor) != GRAPH_SUCCESS,
  1206. GELOGE(OUT_OF_MEMORY, "Failed to add input tensor desc for %s", switchn_desc->GetName().c_str());
  1207. return OUT_OF_MEMORY);
  1208. GeTensorDesc pred_tensor;
  1209. GE_IF_BOOL_EXEC(switchn_desc->AddInputDesc("pred_value", pred_tensor) != GRAPH_SUCCESS,
  1210. GELOGE(OUT_OF_MEMORY, "Failed to add input pred tensor desc for %s", switchn_desc->GetName().c_str());
  1211. return OUT_OF_MEMORY);
  1212. std::vector<std::string> input_dims_str;
  1213. for (size_t i = 0; i < shapes_.size(); ++i) {
  1214. GELOGI("Start clac shape for data %s, batch shape is %s.", data_name.c_str(),
  1215. formats::JoinToString(data_to_dynamic_info_.at(data_name).at(i)).c_str());
  1216. auto shape = data_shape;
  1217. auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
  1218. if (ret != SUCCESS) {
  1219. GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match",
  1220. node->GetName().c_str());
  1221. return ret;
  1222. }
  1223. tensor.SetShape(shape);
  1224. string input_str;
  1225. int64_t tensor_size = 0;
  1226. (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
  1227. input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
  1228. TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" +
  1229. std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
  1230. formats::JoinToString(tensor.GetShape().GetDims());
  1231. input_dims_str.emplace_back(input_str);
  1232. if (!AttrUtils::SetListInt(tensor, ATTR_NAME_SWITCHN_PRED_VALUE, shapes_.at(i))) {
  1233. GELOGE(INTERNAL_ERROR, "Failed to add attr value on output %zu tensor", i);
  1234. return INTERNAL_ERROR;
  1235. }
  1236. (void) AttrUtils::SetListInt(tensor, ATTR_NAME_COMBINED_DYNAMIC_DIMS, shape.GetDims());
  1237. if (switchn_desc->AddOutputDesc("output" + std::to_string(i), tensor) != GRAPH_SUCCESS) {
  1238. GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed");
  1239. return GRAPH_FAILED;
  1240. }
  1241. GELOGD("The switchn %s output index %zu, shape %s", switchn_desc->GetName().c_str(), i, shape.ToString().c_str());
  1242. }
  1243. (void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
  1244. if (!AttrUtils::SetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
  1245. GELOGE(INTERNAL_ERROR, "Failed to add user designate shape order attr on switchn node %s",
  1246. switchn_desc->GetName().c_str());
  1247. return INTERNAL_ERROR;
  1248. }
  1249. if (!AttrUtils::SetBool(switchn_desc, ATTR_INSERT_BY_MBATCH, true)) {
  1250. GELOGE(INTERNAL_ERROR, "Failed to add insert attr on switchn node %s", switchn_desc->GetName().c_str());
  1251. return INTERNAL_ERROR;
  1252. }
  1253. if (!AttrUtils::SetStr(node->GetOpDesc(), kMbatchSwitchnName, switchn_desc->GetName())) {
  1254. GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", node->GetName().c_str());
  1255. return INTERNAL_ERROR;
  1256. }
  1257. if (StampDynamicType(switchn_desc) != SUCCESS) {
  1258. GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr on switchn node %s", switchn_desc->GetName().c_str());
  1259. return INTERNAL_ERROR;
  1260. }
  1261. auto switchn = graph_->AddNode(switchn_desc);
  1262. GE_IF_BOOL_EXEC(switchn == nullptr,
  1263. GELOGE(OUT_OF_MEMORY, "Failed to create switchn %s from desc", switchn_desc->GetName().c_str());
  1264. return OUT_OF_MEMORY);
  1265. if (!getnext_sink_dynamic_dims_) {
  1266. data_nodes_to_switchn_[node.get()] = switchn;
  1267. } else {
  1268. dynamic_out_to_switchn.emplace_back(std::make_pair(node.get(), switchn));
  1269. GELOGD("Insert %s for %s.", switchn->GetName().c_str(), node->GetName().c_str());
  1270. }
  1271. return SUCCESS;
  1272. }
  1273. Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) {
  1274. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  1275. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  1276. if (src_out_anchor == nullptr) {
  1277. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  1278. continue;
  1279. }
  1280. auto in_node = src_out_anchor->GetOwnerNode();
  1281. if (!IsInBatchBranch(in_node)) {
  1282. continue;
  1283. }
  1284. auto merge_node = InsertMergeNode(in_node, src_out_anchor->GetIdx());
  1285. if (merge_node == nullptr) {
  1286. return INTERNAL_ERROR;
  1287. }
  1288. }
  1289. for (auto &in_node : node->GetInControlNodes()) {
  1290. if (!IsInBatchBranch(in_node)) {
  1291. continue;
  1292. }
  1293. auto merge_node = InsertMergeNode(in_node, -1);
  1294. if (merge_node == nullptr) {
  1295. return INTERNAL_ERROR;
  1296. }
  1297. }
  1298. return SUCCESS;
  1299. }
  1300. Status MultiBatchGraphCopyer::LinkGetDynamicDimsToNetOutput(const NodePtr &node) {
  1301. if (node->GetType() == NETOUTPUT) {
  1302. if (!GetLocalOmgContext().dynamic_node_type.empty()) {
  1303. if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) {
  1304. GELOGE(INTERNAL_ERROR, "Failed to set all gears info attr on netoutput %s.", node->GetName().c_str());
  1305. return INTERNAL_ERROR;
  1306. }
  1307. }
  1308. if (getnext_sink_dynamic_dims_) {
  1309. size_t input_index = node->GetAllInDataAnchors().size();
  1310. if (NodeUtils::AppendInputAnchor(node, input_index + 1) != GRAPH_SUCCESS) {
  1311. GELOGE(INTERNAL_ERROR, "Append input anchor of %s of %zu failed.", node->GetName().c_str(), input_index);
  1312. return INTERNAL_ERROR;
  1313. }
  1314. auto ret =
  1315. ge::GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(input_index));
  1316. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link netoutput %s to getdynamicdims %s",
  1317. node->GetName().c_str(), shape_data_->GetName().c_str());
  1318. return INTERNAL_ERROR);
  1319. if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) {
  1320. GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.", node->GetName().c_str());
  1321. return INTERNAL_ERROR;
  1322. }
  1323. }
  1324. }
  1325. return SUCCESS;
  1326. }
  1327. Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) {
  1328. auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  1329. for (size_t i = 0; i < shapes_.size(); ++i) {
  1330. auto copyed_node = InsertCopyNode(node, i);
  1331. if (copyed_node == nullptr) {
  1332. GELOGE(INTERNAL_ERROR, "Failed to add node to graph when copy node %s", node->GetName().c_str());
  1333. return INTERNAL_ERROR;
  1334. }
  1335. copyed_nodes.emplace_back(copyed_node);
  1336. GELOGI("Copy node %s type %s for shape %s, new node name %s", node->GetName().c_str(), node->GetType().c_str(),
  1337. formats::JoinToString(shapes_.at(i)).c_str(), copyed_node->GetName().c_str());
  1338. }
  1339. return SUCCESS;
  1340. }
  1341. Status MultiBatchGraphCopyer::AddAttrForGetDynamicDims(const NodePtr &node) {
  1342. GELOGD("Add attr for :%s, type is %s:", shape_data_->GetName().c_str(), shape_data_->GetType().c_str());
  1343. size_t data_count = node->GetAllOutDataAnchors().size() / kDivisionConst;
  1344. if (!AttrUtils::SetInt(shape_data_->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count)) {
  1345. GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed");
  1346. return INTERNAL_ERROR;
  1347. }
  1348. vector<int64_t> shape_info;
  1349. for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
  1350. if (GetLocalOmgContext().user_input_dims.at(i).second.size() == 1 &&
  1351. GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
  1352. shape_info.emplace_back(0);
  1353. continue;
  1354. }
  1355. shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size());
  1356. for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) {
  1357. shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j));
  1358. }
  1359. }
  1360. if (!AttrUtils::SetListInt(shape_data_->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) {
  1361. GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed");
  1362. return INTERNAL_ERROR;
  1363. }
  1364. return SUCCESS;
  1365. }
  1366. Status MultiBatchGraphCopyer::AddLinkForGetDynamicDims(const NodePtr &node) {
  1367. GELOGD("Start relink out anchor from shape node to getdynamicdims, and delete link between shape node and identity.");
  1368. size_t input_index = 0;
  1369. GELOGD("Out count of %s is %zu.", node->GetName().c_str(), node->GetAllOutDataAnchors().size());
  1370. size_t data_count = node->GetAllOutDataAnchors().size() / kDivisionConst;
  1371. for (size_t out_index = data_count; out_index < node->GetAllOutDataAnchors().size(); ++out_index, ++input_index) {
  1372. GELOGI("Start add %s of %zu out_anchor to %s of %zu in_anchor.", node->GetName().c_str(), out_index,
  1373. shape_data_->GetName().c_str(), input_index);
  1374. auto out_data_anchor = node->GetOutDataAnchor(out_index);
  1375. auto ret = GraphUtils::AddEdge(out_data_anchor, shape_data_->GetInDataAnchor(input_index));
  1376. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s",
  1377. node->GetName().c_str(), shape_data_->GetName().c_str());
  1378. return INTERNAL_ERROR);
  1379. }
  1380. return SUCCESS;
  1381. }
  1382. Status MultiBatchGraphCopyer::LinkEdges() {
  1383. Status ret;
  1384. for (const auto &node : origin_all_nodes_) {
  1385. GE_CHECK_NOTNULL(node->GetOpDesc());
  1386. if (!getnext_sink_dynamic_dims_) {
  1387. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  1388. auto switchn = data_nodes_to_switchn_[node.get()];
  1389. GE_IF_BOOL_EXEC(switchn == nullptr,
  1390. GELOGE(PARAM_INVALID, "Switchn should not be nullptr for %s.", node->GetName().c_str());
  1391. return OUT_OF_MEMORY);
  1392. ret = LinkDataToSwitchN(node, switchn, kDataOutIndex);
  1393. GE_CHK_STATUS_RET(ret, "Link data to switchn failed.");
  1394. }
  1395. } else {
  1396. if (IsGetNextType(node)) {
  1397. GELOGD("Start add attr and link edge for %s.", node->GetName().c_str());
  1398. GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(node), "Failed to add attr for %s.", node->GetName().c_str());
  1399. GE_CHK_STATUS_RET(AddLinkForGetDynamicDims(node), "Failed to add link for %s.", node->GetName().c_str());
  1400. }
  1401. for (size_t i = 0; i < getnext_nodes_to_switchn_.size(); ++i) {
  1402. for (size_t j = 0; j < getnext_nodes_to_switchn_.at(i).size(); ++j) {
  1403. if (getnext_nodes_to_switchn_.at(i).at(j).first == node.get()) {
  1404. auto switchn = getnext_nodes_to_switchn_.at(i).at(j).second;
  1405. GE_CHK_STATUS_RET(LinkDataToSwitchN(node, switchn, i), "Link %s to %s failed.", node->GetName().c_str(),
  1406. switchn->GetName().c_str());
  1407. }
  1408. }
  1409. }
  1410. }
  1411. if (nodes_to_merge_nodes_.count(node.get()) > 0) {
  1412. GE_CHK_STATUS_RET(LinkToMerge(node), "Link %s to merge failed.", node->GetName().c_str());
  1413. }
  1414. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  1415. ret = LinkToNodeInBranch(node);
  1416. } else {
  1417. ret = LinkToNodeOutBranch(node);
  1418. }
  1419. if (ret != SUCCESS) {
  1420. return ret;
  1421. }
  1422. }
  1423. return SUCCESS;
  1424. }
  1425. Status MultiBatchGraphCopyer::LinkDataToSwitchN(const NodePtr &data, const NodePtr &switchn, const int &out_index) {
  1426. auto ret =
  1427. GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNPredIndex));
  1428. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link shape data %s to switchn %s",
  1429. shape_data_->GetName().c_str(), switchn->GetName().c_str());
  1430. return INTERNAL_ERROR);
  1431. ret = GraphUtils::AddEdge(data->GetOutDataAnchor(out_index), switchn->GetInDataAnchor(kSwitchNDataIndex));
  1432. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link data %s to switchn %s",
  1433. data->GetName().c_str(), switchn->GetName().c_str());
  1434. return INTERNAL_ERROR);
  1435. return SUCCESS;
  1436. }
  1437. Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) {
  1438. auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  1439. for (size_t i = 0; i < merge_nodes.size(); ++i) {
  1440. auto merge_node = merge_nodes[i];
  1441. if (merge_node == nullptr) {
  1442. continue;
  1443. }
  1444. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  1445. auto ret = LinkNodeToMerge(node, i, merge_node);
  1446. if (ret != SUCCESS) {
  1447. return ret;
  1448. }
  1449. continue;
  1450. }
  1451. if (!getnext_sink_dynamic_dims_) {
  1452. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  1453. auto &switchn = data_nodes_to_switchn_[node.get()];
  1454. auto ret = LinkDataToMerge(node, merge_node, switchn);
  1455. if (ret != SUCCESS) {
  1456. return ret;
  1457. }
  1458. continue;
  1459. }
  1460. } else {
  1461. for (size_t j = 0; j < getnext_nodes_to_switchn_.size(); ++j) {
  1462. for (size_t k = 0; k < getnext_nodes_to_switchn_.at(j).size(); ++k) {
  1463. if (getnext_nodes_to_switchn_.at(j).at(k).first == node.get()) {
  1464. auto &switchn = getnext_nodes_to_switchn_.at(j).at(k).second;
  1465. auto ret = LinkDataToMerge(node, merge_node, switchn);
  1466. if (ret != SUCCESS) {
  1467. return ret;
  1468. }
  1469. }
  1470. }
  1471. }
  1472. continue;
  1473. }
  1474. GELOGE(INTERNAL_ERROR, "The merge node %s is created, index %zu, but can not find the src node",
  1475. merge_node->GetName().c_str(), i);
  1476. return INTERNAL_ERROR;
  1477. }
  1478. return SUCCESS;
  1479. }
  1480. Status MultiBatchGraphCopyer::LinkToNodeInBranch(const NodePtr &node) {
  1481. GELOGI("Start LinkToNodeInBranch for %s.", node->GetName().c_str());
  1482. auto &branch_nodes = nodes_to_batch_nodes_[node.get()];
  1483. for (size_t i = 0; i < branch_nodes.size(); ++i) {
  1484. auto ret = CopyInDataEdges(node, i, branch_nodes[i]);
  1485. if (ret != SUCCESS) {
  1486. return ret;
  1487. }
  1488. ret = CopyInControlEdges(node, i, branch_nodes[i]);
  1489. if (ret != SUCCESS) {
  1490. return ret;
  1491. }
  1492. }
  1493. return SUCCESS;
  1494. }
  1495. Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
  1496. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  1497. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  1498. if (src_out_anchor == nullptr) {
  1499. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  1500. continue;
  1501. }
  1502. auto in_node = src_out_anchor->GetOwnerNode();
  1503. if (!IsInBatchBranch(in_node)) {
  1504. continue;
  1505. }
  1506. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  1507. if (iter == nodes_to_merge_nodes_.end()) {
  1508. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  1509. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  1510. return INTERNAL_ERROR;
  1511. }
  1512. auto merge_node = iter->second[src_out_anchor->GetIdx()];
  1513. if (merge_node == nullptr) {
  1514. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  1515. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  1516. return INTERNAL_ERROR;
  1517. }
  1518. auto ret = src_out_anchor->Unlink(in_data_anchor);
  1519. if (ret != GRAPH_SUCCESS) {
  1520. GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s(%d) to %s(%d)", in_node->GetName().c_str(),
  1521. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  1522. return INTERNAL_ERROR;
  1523. }
  1524. ret = GraphUtils::AddEdge(merge_node->GetOutDataAnchor(kMergeDataOutIndex), in_data_anchor);
  1525. if (ret != GRAPH_SUCCESS) {
  1526. GELOGE(INTERNAL_ERROR, "Failed to add data edge from %s(%d) to %s(%d)", merge_node->GetName().c_str(),
  1527. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  1528. return INTERNAL_ERROR;
  1529. }
  1530. GELOGI("Link data edge from merge %s(from %s(%d)) to %s(%d)", merge_node->GetName().c_str(),
  1531. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  1532. }
  1533. for (auto &in_node : node->GetInControlNodes()) {
  1534. if (!IsInBatchBranch(in_node)) {
  1535. continue;
  1536. }
  1537. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  1538. if (iter == nodes_to_merge_nodes_.end()) {
  1539. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  1540. in_node->GetName().c_str(), node->GetName().c_str());
  1541. return INTERNAL_ERROR;
  1542. }
  1543. auto merge_node = iter->second[0];
  1544. if (merge_node == nullptr) {
  1545. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  1546. in_node->GetName().c_str(), node->GetName().c_str());
  1547. return INTERNAL_ERROR;
  1548. }
  1549. GE_IF_BOOL_EXEC(in_node->GetOutControlAnchor() == nullptr,
  1550. GELOGE(INTERNAL_ERROR, "Innode outputControlAnchor is null");
  1551. return INTERNAL_ERROR);
  1552. auto ret = in_node->GetOutControlAnchor()->Unlink(node->GetInControlAnchor());
  1553. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s to %s",
  1554. in_node->GetName().c_str(), node->GetName().c_str());
  1555. return INTERNAL_ERROR);
  1556. ret = GraphUtils::AddEdge(merge_node->GetOutControlAnchor(), node->GetInControlAnchor());
  1557. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
  1558. merge_node->GetName().c_str(), node->GetName().c_str());
  1559. return INTERNAL_ERROR);
  1560. GELOGI("Link control edge from merge %s(from %s) to %s", merge_node->GetName().c_str(), in_node->GetName().c_str(),
  1561. node->GetName().c_str());
  1562. }
  1563. return SUCCESS;
  1564. }
  1565. Status ProcessMultiBatch(ComputeGraphPtr &graph) {
  1566. if (GetLocalOmgContext().dynamic_node_type.empty()) {
  1567. const char *multi_batch_with_switchn = std::getenv("MULTI_BATCH_WITH_SWITCHN");
  1568. if (multi_batch_with_switchn == nullptr) {
  1569. PassManager pass_manager;
  1570. GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
  1571. GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass));
  1572. return pass_manager.Run(graph);
  1573. }
  1574. }
  1575. if (!GetLocalOmgContext().need_multi_batch) {
  1576. GELOGI("No need to process_multi for no_train graph.");
  1577. return SUCCESS;
  1578. }
  1579. std::vector<NodePtr> data_nodes;
  1580. std::vector<NodePtr> getnext_nosink_nodes;
  1581. std::vector<NodePtr> getnext_sink_nodes;
  1582. if (CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
  1583. GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed.");
  1584. return PARAM_INVALID;
  1585. }
  1586. if (UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
  1587. GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameForInputShapeOfOption failed.");
  1588. return PARAM_INVALID;
  1589. }
  1590. if (DeleteIdentityInsertByAdapter(graph) != SUCCESS) {
  1591. GELOGE(PARAM_INVALID, "DeleteIdentityInsertByAdapter failed.");
  1592. return PARAM_INVALID;
  1593. }
  1594. std::vector<std::vector<int64_t>> shapes;
  1595. if (!InitDynamicParams(shapes)) {
  1596. GELOGD("There is no multi-batch options, no need to process multi-batch copy");
  1597. return SUCCESS;
  1598. }
  1599. if (CheckNegativeCountOfOptions(shapes) != SUCCESS) {
  1600. GELOGE(PARAM_INVALID, "Input_shape and dynamic_dims should set correct params.");
  1601. return PARAM_INVALID;
  1602. }
  1603. DynamicType dynamic_type = DynamicType::kDynamicUnknown;
  1604. if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
  1605. dynamic_type = DynamicType::kDynamicBatch;
  1606. } else if (!GetLocalOmgContext().dynamic_image_size.empty()) {
  1607. dynamic_type = DynamicType::kDynamicImageSize;
  1608. } else if (!GetLocalOmgContext().dynamic_dims.empty()) {
  1609. dynamic_type = DynamicType::kDynamicDims;
  1610. }
  1611. std::vector<std::pair<std::string, std::vector<int64_t>>> user_designate_shape;
  1612. user_designate_shape = GetLocalOmgContext().user_input_dims;
  1613. GELOGI("Begin to copy graph for multi-batch");
  1614. multibatch::MultiBatchGraphCopyer copyer(graph);
  1615. for (auto &shape : shapes) {
  1616. copyer.AddShape(shape);
  1617. }
  1618. copyer.SetDynamicType(dynamic_type);
  1619. copyer.SetUserDesignateShape(user_designate_shape);
  1620. return copyer.CopyGraph();
  1621. }
  1622. // +-----------+
  1623. // | Data | +-----------+ +-----------+ +-----------+
  1624. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  1625. // \ /. +-----------+ +-----------+ +-----------+
  1626. // \ /.
  1627. // +-----------+ +-----------+ /. +-----------+ +-----------+ +-----------+
  1628. // | Data | ----> | Case | S--- | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  1629. // +-----------+ +-----------+ \. +-----------+ +-----------+ +-----------+
  1630. // \ \.
  1631. // \ \. +-----------+ +-----------+ +-----------+
  1632. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  1633. // | NetOutput | +-----------+ +-----------+ +-----------+
  1634. // +-----------+
  1635. // +-----------+ /
  1636. // | Data | --------------->/
  1637. // +-----------+
  1638. void GetDynamicShapeByGraph(const ComputeGraphPtr &graph, const NodePtr &node,
  1639. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1640. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1641. const auto &func_desc = node->GetOpDesc();
  1642. if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
  1643. GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1644. return;
  1645. }
  1646. const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  1647. for (size_t i = 0; i < func_desc->GetOutputsSize(); ++i) {
  1648. for (size_t j = 0; j < dynamic_branch_names.size(); ++j) {
  1649. const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[j]);
  1650. if (subgraph == nullptr) {
  1651. GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", dynamic_branch_names[j].c_str());
  1652. dynamic_output_dims.clear();
  1653. return;
  1654. }
  1655. const auto &out_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
  1656. if (out_node == nullptr) {
  1657. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "NetOutput not found, name: %s", dynamic_branch_names[j].c_str());
  1658. dynamic_output_dims.clear();
  1659. return;
  1660. }
  1661. GELOGI("Find the subgraph Output node %s and the index is %zu", out_node->GetName().c_str(), i);
  1662. const auto &out_desc = out_node->GetOpDesc();
  1663. if (out_desc == nullptr || out_desc->GetInputsSize() <= i) {
  1664. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Get Input desc failed, name: %s, index: %zu", out_node->GetName().c_str(), i);
  1665. dynamic_output_dims.clear();
  1666. return;
  1667. }
  1668. const auto &input_tensor = out_desc->GetInputDesc(i);
  1669. const auto &shape_msg = input_tensor.GetShape().ToString();
  1670. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1671. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1672. dynamic_output_dims.emplace_back(output_shape);
  1673. uint32_t parent_index = 0;
  1674. (void)AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  1675. dynamic_output_index.insert(parent_index);
  1676. }
  1677. }
  1678. }
  1679. // +-----------+ +-----------+ i = 0
  1680. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> \.
  1681. // / +-----------+ +-----------+ \.
  1682. // / \.
  1683. // +-----------+ +-----------+ +-----------+ +-----------+ i = 1 +-----------+
  1684. // | Data | ----> | SwitchN | ----> | SoftmaxV2 | ----> |MemcpyAsync| ----> | Merge |
  1685. // +-----------+ +-----------+ +-----------+ +-----------+ +-----------+
  1686. // \ / \. j = 0
  1687. // \ +-----------+ +-----------+ i = 2 / \.
  1688. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> / +-----------+
  1689. // +-----------+ +-----------+ | NetOutput |
  1690. // +-----------+
  1691. // +-----------+ /.
  1692. // | Data | --------------------------------------------------------------------------->/. j = 1
  1693. // +-----------+
  1694. void GetDynamicShapeByMerge(const ComputeGraphPtr &graph, const NodePtr &node,
  1695. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1696. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1697. const auto &netoutput_desc = node->GetOpDesc();
  1698. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1699. GELOGI("Train_Dynamic Find the merge node size is %zu.", inputnode_to_netoutput.size());
  1700. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1701. bool insert_by_mbatch = false;
  1702. (void)AttrUtils::GetBool(inputnode_to_netoutput.at(i)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, insert_by_mbatch);
  1703. GELOGI("Train_Dynamic type is %s", inputnode_to_netoutput.at(i)->GetType().c_str());
  1704. if (inputnode_to_netoutput.at(i)->GetType() == MERGE && insert_by_mbatch) {
  1705. GELOGI("Find the merge node %s with mbatch attr and the index is %zu",
  1706. inputnode_to_netoutput.at(i)->GetName().c_str(), i);
  1707. dynamic_output_index.insert(i);
  1708. for (size_t j = 0; j < inputnode_to_netoutput.at(i)->GetInNodes().size(); ++j) {
  1709. auto input_desc = inputnode_to_netoutput.at(i)->GetOpDesc();
  1710. auto input_tensor_desc = input_desc->GetInputDesc(j);
  1711. auto shape_msg = input_tensor_desc.GetShape().ToString();
  1712. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1713. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1714. dynamic_output_dims.emplace_back(output_shape);
  1715. }
  1716. }
  1717. }
  1718. }
  1719. // Connect NetOutput directly
  1720. void GetDirectOutputShape(const ComputeGraphPtr &graph, const NodePtr &node,
  1721. const set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1722. if (!GetLocalOmgContext().dynamic_node_type.empty()) {
  1723. GELOGD("No need to get directly shape info of %s when train.", node->GetName().c_str());
  1724. return;
  1725. }
  1726. GELOGD("Try get directly shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1727. const auto &netoutput_desc = node->GetOpDesc();
  1728. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1729. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1730. if (dynamic_output_index.count(i) > 0) {
  1731. continue;
  1732. }
  1733. auto tensor_desc = netoutput_desc->GetInputDesc(i);
  1734. auto shape = tensor_desc.GetShape().ToString();
  1735. string static_output_shape = std::to_string(kStaticOutput) + "," + std::to_string(i) + "," + shape;
  1736. GELOGI("The static output shape msg is %s", static_output_shape.c_str());
  1737. dynamic_output_dims.emplace_back(static_output_shape);
  1738. }
  1739. }
  1740. Status GetDynamicOutputShape(ComputeGraphPtr &graph) {
  1741. GE_CHECK_NOTNULL(graph);
  1742. GELOGI("Start to get output dynamic batch shape message");
  1743. NodePtr net_output;
  1744. set<size_t> dynamic_output_index;
  1745. vector<string> dynamic_output_dims;
  1746. for (auto &node : graph->GetDirectNode()) {
  1747. if (node->GetType() == NETOUTPUT) {
  1748. net_output = node;
  1749. GetDynamicShapeByMerge(graph, node, dynamic_output_index, dynamic_output_dims);
  1750. } else if (node->GetType() == CASE) {
  1751. GetDynamicShapeByGraph(graph, node, dynamic_output_index, dynamic_output_dims);
  1752. }
  1753. }
  1754. if ((net_output != nullptr) && !dynamic_output_dims.empty()) {
  1755. GetDirectOutputShape(graph, net_output, dynamic_output_index, dynamic_output_dims);
  1756. if (!AttrUtils::SetListStr(net_output->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims)) {
  1757. GELOGE(FAILED, "Set dynamic output dims attr failed");
  1758. return FAILED;
  1759. }
  1760. }
  1761. return SUCCESS;
  1762. }
  1763. } // namespace multibatch
  1764. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户对此并无感知。