kernel_graph.cc (55 kB)

/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/session/kernel_graph.h"
#include <algorithm>
#include <queue>
#include <set>
#include <exception>
#include "utils/hash_set.h"
#include "base/core_ops.h"
#include "ir/param_info.h"
#include "utils/utils.h"
#include "utils/check_convert_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "backend/kernel_compiler/common_utils.h"

namespace mindspore {
namespace session {
namespace {
constexpr auto kIsFeatureMapOutput = "IsFeatureMapOutput";
constexpr auto kIsFeatureMapInputList = "IsFeatureMapInputList";
constexpr size_t k5dDims = 5;
const std::set<std::string> kOpAssignKernelNameList = {prim::kPrimAssign->name(), prim::kPrimAssignAdd->name(),
                                                       prim::kPrimAssignSub->name()};

void PushNoVisitedNode(const AnfNodePtr &node, std::queue<AnfNodePtr> *que,
                       mindspore::HashSet<AnfNodePtr> *visited_nodes) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(que);
  MS_EXCEPTION_IF_NULL(visited_nodes);
  if (visited_nodes->find(node) == visited_nodes->end()) {
    que->push(node);
    (void)visited_nodes->insert(node);
    MS_LOG(DEBUG) << "Push que:" << node->DebugString();
  }
}
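
// Resolve the real outputs of a call node: skip TupleGetItem/MakeTuple
// wrappers, then recursively follow the outputs of every child graph of the
// call until non-call nodes are reached.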
std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
  auto item_with_index =
    AnfAlgo::VisitKernelWithReturnType(call_node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
  AnfNodePtr node = item_with_index.first;
  MS_EXCEPTION_IF_NULL(node);
  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
    auto outputs = AnfAlgo::GetAllOutput(node);
    std::set<AnfNodePtr> memo;
    std::vector<AnfNodePtr> new_output;
    for (auto &output : outputs) {
      if (memo.find(output) != memo.end()) {
        continue;
      }
      memo.insert(output);
      new_output.push_back(output);
    }
    if (new_output.size() == 1 && AnfAlgo::CheckPrimitiveType(new_output[0], prim::kPrimCall)) {
      node = new_output[0];
    }
  }
  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimCall)) {
    return {node};
  }
  std::vector<AnfNodePtr> real_inputs;
  auto child_graphs = AnfAlgo::GetCallSwitchKernelGraph(node->cast<CNodePtr>());
  for (const auto &child_graph : child_graphs) {
    MS_EXCEPTION_IF_NULL(child_graph);
    auto real_input = child_graph->output();
    auto child_real_inputs = GetCallRealOutputs(real_input);
    std::copy(child_real_inputs.begin(), child_real_inputs.end(), std::back_inserter(real_inputs));
  }
  return real_inputs;
}
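
// Two label nodes are considered the same when they share the same primitive
// and the same kAttrLabelIndex attribute value.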
bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
  if (left == right) {
    return true;
  }
  if (left == nullptr || right == nullptr) {
    return false;
  }
  if (!IsPrimitiveCNode(left, GetCNodePrimitive(right))) {
    return false;
  }
  if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, left) && AnfAlgo::HasNodeAttr(kAttrLabelIndex, right)) {
    return AnfAlgo::GetNodeAttr<uint32_t>(left, kAttrLabelIndex) ==
           AnfAlgo::GetNodeAttr<uint32_t>(right, kAttrLabelIndex);
  }
  return false;
}
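
// If the tensors held by a value node already carry device addresses, copy
// the device formats and types from those addresses; otherwise fall back to
// the default format and an unknown type.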
void SyncDeviceInfoToValueNode(const ValueNodePtr &value_node, std::vector<std::string> *device_formats,
                               std::vector<TypeId> *device_types) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(device_formats);
  MS_EXCEPTION_IF_NULL(device_types);
  ValuePtr value = value_node->value();
  std::vector<tensor::TensorPtr> tensors;
  TensorValueToTensor(value, &tensors);
  if (!tensors.empty()) {
    device_formats->clear();
    device_types->clear();
    for (const auto &tensor : tensors) {
      MS_EXCEPTION_IF_NULL(tensor);
      auto device_sync = tensor->device_address();
      if (device_sync != nullptr) {
        auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(device_sync);
        MS_EXCEPTION_IF_NULL(device_address);
        device_formats->emplace_back(device_address->format());
        device_types->emplace_back(device_address->type_id());
        continue;
      }
      device_formats->emplace_back(kOpFormat_DEFAULT);
      device_types->emplace_back(kTypeUnknown);
    }
  }
}

std::string GetNodeGroup(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  if (AnfAlgo::HasNodeAttr(kAttrGroup, cnode)) {
    return AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrGroup);
  }
  return "";
}
}  // namespace

AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  auto value_node = node->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    return nullptr;
  }
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
  MS_EXCEPTION_IF_NULL(new_value_node);
  new_value_node->set_abstract(value_node->abstract());
  this->SetKernelInfoForNode(new_value_node);
  return new_value_node;
}

std::vector<AnfNodePtr> KernelGraph::outputs() const {
  auto graph_output = output();
  if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
    auto make_tuple = output()->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    auto &inputs = make_tuple->inputs();
    return std::vector<AnfNodePtr>(inputs.begin() + 1, inputs.end());
  }
  return std::vector<AnfNodePtr>(1, graph_output);
}
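
// Decrease the pending-input count of every user of `node`; a user whose
// count reaches zero becomes active and is queued. With `comm_first` set,
// ready communication ops are queued ahead of other ready nodes (and the
// other way round when it is unset).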
void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
                                     mindspore::HashSet<AnfNodePtr> *visited_nodes, bool comm_first) {
  MS_EXCEPTION_IF_NULL(visit_queue);
  MS_EXCEPTION_IF_NULL(visited_nodes);
  auto it = node_output_edges_.find(node);
  if (it == node_output_edges_.end()) {
    // Value nodes and parameters have no inputs, no need to print a log.
    if (node->isa<CNode>()) {
      MS_LOG(DEBUG) << "Can not find node [" << node->DebugString() << "]";
    }
    return;
  }
  // Visit all-reduce nodes first, then the other nodes.
  std::vector<AnfNodePtr> active_nodes;
  for (const auto &output_edge : it->second) {
    auto next_node = output_edge.first;
    MS_EXCEPTION_IF_NULL(next_node);
    if (node_input_num_.find(next_node) == node_input_num_.end()) {
      MS_LOG(EXCEPTION) << "Can't find node[" << next_node->DebugString() << "]";
    }
    MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ", node:" << node->DebugString()
                  << ", num:" << node_input_num_[next_node] << ", decrease num:" << output_edge.second;
    if (node_input_num_[next_node] < output_edge.second) {
      MS_LOG(DEBUG) << "Input node:" << next_node->DebugString() << ", node_output_num:" << node_input_num_[next_node]
                    << ", depend edge:" << output_edge.second;
      continue;
    }
    node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
    // AllReduce first.
    if (node_input_num_[next_node] == 0 && visited_nodes->find(next_node) == visited_nodes->end()) {
      (void)visited_nodes->insert(next_node);
      bool is_comm_node = AnfAlgo::IsCommunicationOp(next_node);
      if (AnfAlgo::CheckPrimitiveType(next_node, prim::kPrimLoad)) {
        EnqueueActiveNodes(next_node, visit_queue, visited_nodes);
      } else if ((is_comm_node && comm_first) || (!is_comm_node && !comm_first)) {
        MS_LOG(DEBUG) << "Visit node:" << next_node->DebugString();
        visit_queue->push(next_node);
      } else {
        active_nodes.emplace_back(next_node);
      }
    }
  }
  for (auto &active_node : active_nodes) {
    visit_queue->push(active_node);
  }
}
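
// Build the default execution order: a topological sort over the dependency
// edges collected by UpdateNodeEdgeList(). Fused communication ops are parked
// on delay_comm_stack so that the descendants of the optimized comm group can
// be interleaved with them; CheckLoop() then verifies that every node was
// consumed (i.e. the graph is acyclic).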
void KernelGraph::SetExecOrderByDefault() {
  std::queue<AnfNodePtr> seed_nodes;
  UpdateNodeEdgeList(&seed_nodes);
  execution_order_.clear();
  mindspore::HashSet<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> zero_input_nodes;
  std::queue<AnfNodePtr> delay_comm_stack;
  std::queue<AnfNodePtr> communication_descendants;
  std::string optimized_comm_group;
  while (!seed_nodes.empty() || !delay_comm_stack.empty()) {
    // Seed nodes first, then delayed comm nodes.
    if (seed_nodes.empty()) {
      EnqueueActiveNodes(delay_comm_stack.front(), &communication_descendants, &visited_nodes, false);
      delay_comm_stack.pop();
    } else {
      zero_input_nodes.push(seed_nodes.front());
      seed_nodes.pop();
    }
    // Comm descendants first, then the common queue.
    while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
      AnfNodePtr node = nullptr;
      bool is_communication_descendant = false;
      if (communication_descendants.empty()) {
        node = zero_input_nodes.front();
        zero_input_nodes.pop();
      } else {
        node = communication_descendants.front();
        communication_descendants.pop();
        is_communication_descendant = true;
      }
      // Add the node to the execution order.
      MS_EXCEPTION_IF_NULL(node);
      if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
        execution_order_.push_back(node->cast<CNodePtr>());
      }
      // Delay the execution of comm ops that need optimization.
      bool is_fused_comm = AnfAlgo::IsFusedCommunicationOp(node);
      bool optimize_comm = false;
      if (is_fused_comm && optimized_comm_group.empty()) {
        auto node_group = GetNodeGroup(node);
        if (node_group.find(kSyncBnGroup) == string::npos) {
          optimized_comm_group = node_group;
          optimize_comm = true;
        }
      }
      if (optimize_comm) {
        while (!delay_comm_stack.empty()) {
          EnqueueActiveNodes(delay_comm_stack.front(), &communication_descendants, &visited_nodes, false);
          delay_comm_stack.pop();
        }
        delay_comm_stack.push(node);
      } else if (is_fused_comm) {
        delay_comm_stack.push(node);
      } else if (is_communication_descendant) {
        EnqueueActiveNodes(node, &communication_descendants, &visited_nodes);
      } else {
        EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes);
      }
    }
  }
  CheckLoop();
  // Re-sort start label / end goto.
  execution_order_ = SortStartLabelAndEndGoto();
}

std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
  std::vector<CNodePtr> re_order;
  if (start_label_ != nullptr) {
    re_order.push_back(start_label_);
  }
  for (auto &node : execution_order_) {
    if (node == start_label_ || node == end_goto_) {
      continue;
    }
    if (IsSameLabel(node, end_goto_)) {
      end_goto_ = node;
      MS_LOG(INFO) << "Replace end_goto_ in kernel graph:" << graph_id();
      continue;
    }
    if (IsSameLabel(node, start_label_)) {
      start_label_ = node;
      MS_LOG(INFO) << "Replace start_label_ in kernel graph:" << graph_id();
      continue;
    }
    //
    // Re-order:
    //   u = LabelGoto(...)
    //   x = Mul(...)
    //   LabelSet(u)
    // To:
    //   u = LabelGoto(...)
    //   LabelSet(u)
    //   x = Mul(...)
    // This prevents Mul from being skipped.
    //
    if (IsPrimitiveCNode(node, prim::kPrimLabelSet) && (re_order.back() != node->input(1))) {
      auto iter = std::find(re_order.rbegin() + 1, re_order.rend(), node->input(1));
      if (iter != re_order.rend()) {
        re_order.insert(iter.base(), node);
        continue;
      }
    }
    re_order.push_back(node);
  }
  if (end_goto_ != nullptr) {
    re_order.push_back(end_goto_);
  }
  return re_order;
}
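
// Depth-first search along the input edges. When an already-visited node is
// reached again, the path recorded in edge_to_ forms a cycle; its nodes are
// logged one by one and the cycle is counted in loop_num.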
void KernelGraph::GetLoopNodesByDFS(const AnfNodePtr &node, uint32_t *loop_num) {
  MS_EXCEPTION_IF_NULL(node);
  auto node_input_it = node_input_edges_.find(node);
  if (node_input_it == node_input_edges_.end()) {
    MS_LOG(DEBUG) << "Node [" << node->DebugString() << "] doesn't have input edges.";
    return;
  }
  if (*loop_num != 0) {
    return;
  }
  (void)visited_nodes_.insert(node);
  for (auto &input_edge : node_input_edges_[node]) {
    size_t input_num = node_input_num_[input_edge.first];
    if (input_num == 0) {
      continue;
    }
    if (find(visited_nodes_.begin(), visited_nodes_.end(), input_edge.first) == visited_nodes_.end()) {
      MS_EXCEPTION_IF_NULL(input_edge.first);
      edge_to_[input_edge.first] = node;
      GetLoopNodesByDFS(input_edge.first, loop_num);
    } else {
      AnfNodePtr node_iter = node;
      MS_EXCEPTION_IF_NULL(node_iter);
      MS_LOG(INFO) << "Print loop nodes start:";
      for (; node_iter != input_edge.first && node_iter != nullptr; node_iter = edge_to_[node_iter]) {
        loop_nodes_.push(node_iter);
        node_input_num_[node_iter]--;
        MS_LOG(INFO) << "Get loop node:" << node_iter->DebugString();
      }
      if (node_iter != nullptr) {
        loop_nodes_.push(node_iter);
        loop_nodes_.push(node);
        (*loop_num)++;
        node_input_num_[node_iter]--;
        MS_LOG(INFO) << "Get loop node:" << node_iter->DebugString();
        MS_LOG(INFO) << "Get loop node:" << node->DebugString();
        MS_LOG(INFO) << "Print loop nodes end, loop num:" << *loop_num;
        while (!loop_nodes_.empty()) {
          loop_nodes_.pop();
        }
        return;
      }
    }
  }
}

uint32_t KernelGraph::GetLoopNum(const std::map<AnfNodePtr, size_t> &none_zero_nodes) {
  uint32_t loop_num = 0;
  for (auto &iter : none_zero_nodes) {
    auto node = iter.first;
    MS_EXCEPTION_IF_NULL(node);
    if (node_input_num_[node] == 0) {
      continue;
    }
    edge_to_.clear();
    visited_nodes_.clear();
    GetLoopNodesByDFS(node, &loop_num);
  }
  return loop_num;
}
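
// After the topological sort every node's pending-input count must be zero;
// any node with a non-zero count indicates a cycle, which is reported and
// raised as an exception.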
void KernelGraph::CheckLoop() {
  std::map<AnfNodePtr, size_t> none_zero_nodes;
  if (node_input_edges_.size() != node_input_num_.size()) {
    MS_LOG(EXCEPTION) << "node_input_edges_ size: " << node_input_edges_.size()
                      << " is not equal to node_input_num_ size: " << node_input_num_.size();
  }
  for (auto &it : node_input_num_) {
    MS_EXCEPTION_IF_NULL(it.first);
    string str;
    auto node_input_it = node_input_edges_.find(it.first);
    if (node_input_it == node_input_edges_.end()) {
      MS_LOG(EXCEPTION) << "Can't find node [" << it.first->DebugString() << "]";
    }
    if (it.second != 0) {
      for (const auto &input_edge : node_input_edges_[it.first]) {
        MS_EXCEPTION_IF_NULL(input_edge.first);
        str = str.append(input_edge.first->DebugString()).append("|");
      }
      MS_LOG(WARNING) << "Node:" << it.first->DebugString() << ", inputs:" << str << ", input num:" << it.second;
      none_zero_nodes[it.first] = it.second;
    }
  }
  // If loops are not taken into account, an exception will be thrown.
  if (!none_zero_nodes.empty()) {
    MS_LOG(WARNING) << "Nums of loop:" << GetLoopNum(none_zero_nodes);
    MS_LOG(EXCEPTION) << "Nodes have loop, left node num:" << none_zero_nodes.size();
  }
}

CNodePtr KernelGraph::NewCNode(std::vector<AnfNodePtr> &&inputs) {
  auto cnode = FuncGraph::NewCNode(std::move(inputs));
  PostNewCNode(cnode);
  return cnode;
}

CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
  auto cnode = FuncGraph::NewCNode(inputs);
  PostNewCNode(cnode);
  return cnode;
}

void KernelGraph::PostNewCNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
  if (AnfAlgo::IsGraphKernel(cnode)) {
    CreateKernelInfoFromNewParameter(cnode);
  }
  if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
    AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
  }
  SetKernelInfoForNode(cnode);
  AnfAlgo::SetGraphId(graph_id_, cnode.get());
}

CNodePtr KernelGraph::NewCNodeWithInfos(const std::vector<AnfNodePtr> &inputs, const CNodePtr &ori_cnode) {
  auto cnode = NewCNode(inputs);
  if (ori_cnode != nullptr) {
    cnode->set_attrs(ori_cnode->attrs());
    cnode->set_primal_attrs(ori_cnode->primal_attrs());
    cnode->set_primal_debug_infos(ori_cnode->primal_debug_infos());
  }
  return cnode;
}

void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
  auto func_graph = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list;
  std::vector<AnfNodePtr> input_list;
  std::vector<AnfNodePtr> output_list;
  kernel::GetValidKernelNodes(func_graph, &node_list, &input_list, &output_list);
  for (auto &anf_node : node_list) {
    MS_EXCEPTION_IF_NULL(anf_node);
    if (anf_node->kernel_info() == nullptr) {
      anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
    }
    auto anf_cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(anf_cnode);
    size_t input_num = AnfAlgo::GetInputTensorNum(anf_cnode);
    for (size_t i = 0; i < input_num; ++i) {
      auto input_node = anf_cnode->input(i + 1);
      MS_EXCEPTION_IF_NULL(input_node);
      if (IsValueNode<tensor::Tensor>(input_node)) {
        auto new_input_node = MakeValueNode(input_node);
        if (new_input_node != nullptr) {
          anf_cnode->set_input(i + 1, new_input_node);
        }
      }
    }
  }
  for (auto &anf_node : input_list) {
    MS_EXCEPTION_IF_NULL(anf_node);
    if (anf_node->kernel_info() == nullptr) {
      anf_node->set_kernel_info(std::make_shared<device::KernelInfo>());
    }
  }
}
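
// Assign, AssignAdd and AssignSub write their value input into their target
// input. When the assigned value is a feature map output but the target is
// not, promote the target's feature map flag to match.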
void KernelGraph::ResetAssignInputFeatureMapFlag(const CNodePtr &cnode) const {
  if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(cnode)) == kOpAssignKernelNameList.end()) {
    MS_LOG(EXCEPTION) << "Only the input feature map flag of an [Assign, AssignSub, AssignAdd] node can be changed, "
                         "but got the node: "
                      << cnode->DebugString();
  }
  auto input_node = AnfAlgo::GetInputNode(cnode, 0);
  MS_EXCEPTION_IF_NULL(input_node);
  auto assign_value_node = AnfAlgo::GetInputNode(cnode, 1);
  if (AnfAlgo::IsFeatureMapOutput(input_node)) {
    return;
  }
  if (!AnfAlgo::IsFeatureMapOutput(input_node) && AnfAlgo::IsFeatureMapOutput(assign_value_node)) {
    auto kernel_info = dynamic_cast<device::KernelInfo *>(input_node->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    kernel_info->set_feature_map_flag(true);
  }
}
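
// Attach a fresh KernelInfo to a node. For a CNode this records which inputs
// are feature maps; for a value node or parameter it also builds a default
// kernel build info (output format and device type).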
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = std::make_shared<device::KernelInfo>();
  MS_EXCEPTION_IF_NULL(kernel_info);
  node->set_kernel_info(kernel_info);
  if (node->isa<CNode>()) {
    if (kOpAssignKernelNameList.find(AnfAlgo::GetCNodeName(node)) != kOpAssignKernelNameList.end()) {
      ResetAssignInputFeatureMapFlag(node->cast<CNodePtr>());
    }
#if defined(__APPLE__)
    std::vector<int> feature_map_input_indexs;
#else
    std::vector<size_t> feature_map_input_indexs;
#endif
    kernel_info->set_feature_map_flag(false);
    size_t input_num = AnfAlgo::GetInputTensorNum(node);
    for (size_t index = 0; index < input_num; ++index) {
      if (AnfAlgo::IsFeatureMapInput(node, index)) {
        kernel_info->set_feature_map_flag(true);
        feature_map_input_indexs.push_back(index);
      }
    }
    if (AnfAlgo::GetInputTensorNum(node) == 0) {
      kernel_info->set_feature_map_flag(true);
    }
    if (AnfUtils::IsRealKernel(node)) {
      // If the node only has a primitive (such as GetNext) or one of its inputs is a feature map,
      // then the node's output is a feature map output.
      AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
      AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
    }
    return;
  }
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
  // Set the format of the value_node to DEFAULT_FORMAT.
  std::vector<TypeId> types;
  std::vector<std::string> formats = {kOpFormat_DEFAULT};
  if (node->isa<ValueNode>()) {
    kernel_info->set_feature_map_flag(false);
    (void)types.emplace_back(kTypeUnknown);
    auto value_node = node->cast<ValueNodePtr>();
    SyncDeviceInfoToValueNode(value_node, &formats, &types);
  }
  if (node->isa<Parameter>()) {
    auto parameter = node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(parameter);
    bool is_weight = AnfAlgo::IsParameterWeight(parameter);
    kernel_info->set_feature_map_flag(!is_weight);
    types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
  }
  // Set the parameter's initial device data type.
  kernel_build_info_builder->SetOutputsFormat(formats);
  kernel_build_info_builder->SetOutputsDeviceType(types);
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
}

CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto new_cnode = std::make_shared<CNode>(*cnode);
  // If a cnode was not created from the front end, it won't be in the map, so the map shouldn't be
  // updated when replacing it.
  if (BackendNodeExistInFrontBackendMap(cnode)) {
    FrontBackendlMapUpdate(cnode, new_cnode);
  }
  AnfAlgo::SetGraphId(graph_id_, cnode.get());
  return new_cnode;
}

ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
  auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
  auto new_parameter = NewParameter(abstract);
  // A non-default (non-nullptr) parameter means the new parameter is created from an old one.
  if (parameter != nullptr) {
    new_parameter->set_name(parameter->name());
    if (AnfAlgo::IsParameterWeight(parameter)) {
      new_parameter->set_default_param(parameter->default_param());
    }
  }
  // Create kernel_info for the new parameter.
  SetKernelInfoForNode(new_parameter);
  AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
  return new_parameter;
}

ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
  ParameterPtr new_parameter = add_parameter();
  new_parameter->set_abstract(abstract);
  // Create kernel_info for the new parameter.
  SetKernelInfoForNode(new_parameter);
  AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
  return new_parameter;
}

ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(value_node);
  auto new_value_node = MakeValueNode(value_node)->cast<ValueNodePtr>();
  AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
  return new_value_node;
}

ValueNodePtr KernelGraph::NewValueNode(const AbstractBasePtr &abstract, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(value);
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(value);
  MS_EXCEPTION_IF_NULL(new_value_node);
  new_value_node->set_abstract(abstract);
  SetKernelInfoForNode(new_value_node);
  AnfAlgo::SetGraphId(graph_id(), new_value_node.get());
  return new_value_node;
}

ValueNodePtr KernelGraph::NewValueNode(const tensor::TensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  ValueNodePtr value_node = nullptr;
  if (input_tensor->data_type() == kObjectTypeString) {
    std::string value_string;
    value_string.assign(reinterpret_cast<char *>(input_tensor->data_c()), input_tensor->data().size());
    StringImmPtr string_imm_value = std::make_shared<StringImm>(value_string);
    value_node = std::make_shared<ValueNode>(string_imm_value);
  } else {
    value_node = std::make_shared<ValueNode>(input_tensor);
  }
  MS_EXCEPTION_IF_NULL(value_node);
  // Construct the abstract of the value node.
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  value_node->set_abstract(abstract);
  // Add the value node to the graph.
  auto input_value_node = NewValueNode(value_node);
  AddValueNodeToGraph(input_value_node);
  return input_value_node;
}

AnfNodePtr KernelGraph::TransValueNodeTuple(const AbstractBasePtr &abstract, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(value);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    auto new_value_node = NewValueNode(abstract, value);
    AddValueNodeToGraph(new_value_node);
    return new_value_node;
  }
  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  auto value_tuple = value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  MS_EXCEPTION_IF_NULL(value_tuple);
  if (tuple_abstract->size() != value_tuple->size()) {
    MS_LOG(EXCEPTION) << "Abstract size:" << tuple_abstract->size()
                      << " is not equal to value size:" << value_tuple->size();
  }
  std::vector<AnfNodePtr> make_tuple_inputs = {
    mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
    make_tuple_inputs.push_back(TransValueNodeTuple((*tuple_abstract)[index], (*value_tuple)[index]));
  }
  auto make_tuple = NewCNode(std::move(make_tuple_inputs));
  MS_EXCEPTION_IF_NULL(make_tuple);
  make_tuple->set_abstract(tuple_abstract);
  return make_tuple;
}

AnfNodePtr KernelGraph::TransParameterTuple(const AbstractBasePtr &abstract) {
  MS_EXCEPTION_IF_NULL(abstract);
  if (!abstract->isa<abstract::AbstractTuple>()) {
    return NewParameter(abstract);
  }
  auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  std::vector<AnfNodePtr> make_tuple_inputs = {
    mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))};
  for (size_t index = 0; index < tuple_abstract->size(); ++index) {
    make_tuple_inputs.push_back(TransParameterTuple((*tuple_abstract)[index]));
  }
  auto make_tuple = NewCNode(std::move(make_tuple_inputs));
  make_tuple->set_abstract(tuple_abstract);
  return make_tuple;
}
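
// Build a TupleGetItem(node, output_idx) cnode and copy the inferred type and
// shape of that output onto it.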
AnfNodePtr KernelGraph::CreatTupleGetItemNode(const AnfNodePtr &node, size_t output_idx) {
  auto idx = mindspore::NewValueNode(SizeToLong(output_idx));
  MS_EXCEPTION_IF_NULL(idx);
  auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx));
  auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm);
  idx->set_abstract(abstract_scalar);
  AnfNodePtr tuple_getitem = NewCNode({mindspore::NewValueNode(prim::kPrimTupleGetItem), node, idx});
  MS_EXCEPTION_IF_NULL(tuple_getitem);
  tuple_getitem->set_scope(node->scope());
  std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
  TypeId origin_type = AnfAlgo::GetOutputInferDataType(node, output_idx);
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
  return tuple_getitem;
}

AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::vector<TypeId> types;
  std::vector<std::vector<size_t>> shapes;
  std::vector<AnfNodePtr> make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t tuple_out_index = 0; tuple_out_index < output_num; ++tuple_out_index) {
    make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index));
    types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index));
    shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index));
  }
  auto make_tuple = NewCNode(std::move(make_tuple_inputs_list));
  AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
  return make_tuple;
}
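
// Normalize a tuple-output node into an explicit MakeTuple: a parameter is
// expanded into a tuple of new parameters, a value node into a tuple of value
// nodes, and a cnode into MakeTuple(TupleGetItem(node, i), ...).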
AnfNodePtr KernelGraph::TransTupleToMakeTuple(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!AnfAlgo::IsTupleOutput(node)) {
    return node;
  }
  if (node->isa<Parameter>()) {
    return TransParameterTuple(node->abstract());
  } else if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto make_tuple = TransValueNodeTuple(value_node->abstract(), value_node->value());
    if (!RemoveValueNodeFromGraph(value_node)) {
      MS_LOG(WARNING) << "Failed to remove the value_node " << value_node->DebugString();
    }
    return make_tuple;
  } else if (node->isa<CNode>()) {
    return TransCNodeTuple(node->cast<CNodePtr>());
  } else {
    return nullptr;
  }
}

const std::vector<AnfNodePtr> &KernelGraph::inputs() const {
  MS_EXCEPTION_IF_NULL(inputs_);
  return *inputs_;
}

void KernelGraph::FrontBackendMapAdd(const AnfNodePtr &front_anf, const AnfNodePtr &backend_anf) {
  MS_EXCEPTION_IF_NULL(front_anf);
  MS_EXCEPTION_IF_NULL(backend_anf);
  if (front_backend_anf_map_.find(front_anf) != front_backend_anf_map_.end()) {
    MS_LOG(EXCEPTION) << "Anf " << front_anf->DebugString() << " already exists in the front_backend_anf_map_";
  }
  if (backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end()) {
    auto front_node = front_anf->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(front_node);
    auto attr_input = front_node->input(kAnfPrimitiveIndex);
    MS_EXCEPTION_IF_NULL(attr_input);
    if (!attr_input->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Kernel " << backend_anf->DebugString() << " already exists in the backend_front_anf_map_";
    }
  }
  front_backend_anf_map_[front_anf] = backend_anf;
  backend_front_anf_map_[backend_anf] = front_anf;
}

void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
  MS_EXCEPTION_IF_NULL(old_backend_anf);
  MS_EXCEPTION_IF_NULL(new_backend_anf);
  if (old_backend_anf == new_backend_anf) {
    MS_LOG(DEBUG) << "Old same with new:" << old_backend_anf->DebugString();
    return;
  }
  auto bf_iter = backend_front_anf_map_.find(old_backend_anf);
  if (bf_iter == backend_front_anf_map_.end()) {
    MS_LOG(DEBUG) << "Old_backend_anf " << old_backend_anf->DebugString() << " does not exist in the map";
    return;
  }
  auto front_anf = bf_iter->second;
  auto fb_iter = front_backend_anf_map_.find(front_anf);
  if (fb_iter == front_backend_anf_map_.end()) {
    MS_LOG(EXCEPTION) << "Anf does not exist in the map, old " << old_backend_anf->DebugString();
  }
  fb_iter->second = new_backend_anf;
  // Delete the old kernel; this should be done before adding the new item to the map.
  (void)backend_front_anf_map_.erase(bf_iter);
  backend_front_anf_map_[new_backend_anf] = front_anf;
  if (IsInternalOutput(old_backend_anf)) {
    ReplaceInternalOutput(old_backend_anf, new_backend_anf);
  }
}

// Get the backend kernel by its front anf node.
AnfNodePtr KernelGraph::GetBackendAnfByFrontAnf(const AnfNodePtr &front_anf) {
  auto iter = front_backend_anf_map_.find(front_anf);
  if (iter == front_backend_anf_map_.end()) {
    return nullptr;
  }
  return iter->second;
}

AnfNodePtr KernelGraph::GetFrontAnfByBackendAnf(const AnfNodePtr &backend_anf) {
  auto iter = backend_front_anf_map_.find(backend_anf);
  if (iter == backend_front_anf_map_.end()) {
    return nullptr;
  }
  return iter->second;
}

bool KernelGraph::BackendNodeExistInFrontBackendMap(const AnfNodePtr &backend_anf) {
  return backend_front_anf_map_.find(backend_anf) != backend_front_anf_map_.end();
}

ValueNodePtr KernelGraph::GetValueNodeByTensor(const mindspore::tensor::TensorPtr &tensor) {
  auto iter = tensor_to_value_node_map_.find(tensor);
  if (iter == tensor_to_value_node_map_.end()) {
    return nullptr;
  }
  return iter->second;
}

void KernelGraph::TensorValueNodeMapAdd(const tensor::TensorPtr &tensor, const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(value_node);
  tensor_to_value_node_map_[tensor] = value_node;
}

void KernelGraph::AddDependEdge(const AnfNodePtr &node, const AnfNodePtr &input, size_t depend_edge_num) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(input);
  MS_LOG(DEBUG) << "Input:" << input->DebugString() << ", node:" << node->DebugString() << ", num:" << depend_edge_num;
  // Add an output depend edge of the input.
  node_output_edges_[input].emplace_back(node, depend_edge_num);
  // Add an input depend edge of the output.
  node_input_edges_[node].emplace_back(input, depend_edge_num);
  // Accumulate the node's input depend num.
  node_input_num_[node] += depend_edge_num;
}

std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto it = node_output_edges_.find(node);
  if (it == node_output_edges_.end()) {
    MS_LOG(EXCEPTION) << "Can't find node[" << node->DebugString() << "]";
  }
  std::vector<AnfNodePtr> output_nodes;
  output_nodes.reserve(it->second.size());
  (void)std::transform(it->second.begin(), it->second.end(), std::back_inserter(output_nodes),
                       [](const auto &p) { return p.first; });
  return output_nodes;
}
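
// Rebuild the dependency-edge tables (node_input_edges_, node_output_edges_,
// node_input_num_) by a breadth-first traversal from the Return node, and
// collect parameters and value nodes as seeds for the topological sort.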
void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
  MS_EXCEPTION_IF_NULL(seed_nodes);
  node_output_edges_.clear();
  node_input_num_.clear();
  node_input_edges_.clear();
  mindspore::HashSet<AnfNodePtr> visited_nodes;
  std::queue<AnfNodePtr> que;
  que.push(get_return());
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>() || node->isa<ValueNode>()) {
      seed_nodes->push(node);
      continue;
    }
    auto cnode = dyn_cast<CNode>(node);
    if (cnode == nullptr) {
      continue;
    }
    auto &inputs = cnode->inputs();
    // We push inputs from right to left, so that they can be evaluated from left to right.
    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
      auto &input = *iter;
      PushNoVisitedNode(input, &que, &visited_nodes);
      AddDependEdge(node, input, 1);
    }
  }
}

void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }

bool KernelGraph::IsInRefOutputMap(const AnfWithOutIndex &pair) const { return ref_out_in_map_.count(pair) != 0; }

AnfWithOutIndex KernelGraph::GetRefCorrespondOutput(const AnfWithOutIndex &out_pair) const {
  if (!IsInRefOutputMap(out_pair)) {
    MS_LOG(EXCEPTION) << "Out_pair is not in RefOutputMap, node is " << out_pair.first->DebugString() << ", index is "
                      << out_pair.second;
  }
  return ref_out_in_map_.at(out_pair);
}

void KernelGraph::AddRefCorrespondPairs(const AnfWithOutIndex &final_pair, const AnfWithOutIndex &origin_pair) {
  if (IsInRefOutputMap(final_pair)) {
    MS_LOG(EXCEPTION) << "Out_pair is already in RefOutputMap, node is " << final_pair.first->DebugString()
                      << ", index is " << final_pair.second;
  }
  (void)ref_out_in_map_.emplace(final_pair, origin_pair);
}

bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) {
  return graph_value_nodes_.erase(value_node) != 0;
}

void KernelGraph::ReplaceGraphInput(const AnfNodePtr &old_parameter, const AnfNodePtr &new_parameter) {
  // Update the graph inputs.
  MS_EXCEPTION_IF_NULL(old_parameter);
  MS_EXCEPTION_IF_NULL(new_parameter);
  if (old_parameter == new_parameter) {
    return;
  }
  for (size_t i = 0; i < inputs_->size(); i++) {
    if ((*inputs_)[i] == old_parameter) {
      MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_parameter->DebugString()
                   << ", new graph input:" << new_parameter->DebugString();
      (*inputs_)[i] = new_parameter;
      FrontBackendlMapUpdate(old_parameter, new_parameter);
      break;
    }
  }
}

void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, const AnfNodePtr &new_anf_node) {
  MS_EXCEPTION_IF_NULL(inputs_);
  auto it = node_output_edges_.find(old_anf_node);
  if (it == node_output_edges_.end()) {
    MS_LOG(WARNING) << "Old node not found " << old_anf_node->DebugString();
    return;
  }
  for (auto &user : it->second) {
    auto user_cnode = dyn_cast<CNode>(user.first);
    MS_EXCEPTION_IF_NULL(user_cnode);
    auto &inputs = user_cnode->inputs();
    for (size_t i = 1; i < inputs.size(); i++) {
      if (inputs[i] == old_anf_node) {
        user_cnode->set_input(i, new_anf_node);
      }
    }
  }
}

void KernelGraph::UpdateExecuteKernelStreamLabel() {
  for (auto &kernel : execution_order_) {
    AnfAlgo::SetStreamDistinctionLabel(stream_distinction_label_, kernel.get());
  }
}

std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
  std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
  if (IsLeafGraph()) {
    leaf_graph_order.push_back(shared_from_this()->cast<KernelGraphPtr>());
  } else {
    for (const auto &child_graph : child_graph_order_) {
      std::shared_ptr<KernelGraph> child_graph_ptr = child_graph.lock();
      MS_EXCEPTION_IF_NULL(child_graph_ptr);
      auto child_leaf_graph_order = child_graph_ptr->GetLeafGraphOrder();
      std::copy(child_leaf_graph_order.begin(), child_leaf_graph_order.end(), std::back_inserter(leaf_graph_order));
    }
  }
  return leaf_graph_order;
}

bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }

std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
  std::vector<CNodePtr> result;
  for (const auto &anf : execution_order_) {
    MS_EXCEPTION_IF_NULL(anf);
    if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
      result.push_back(anf->cast<CNodePtr>());
    }
  }
  return result;
}

std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const std::vector<PrimitivePtr> &primitive_list) const {
  std::vector<CNodePtr> result;
  for (const auto &anf : execution_order_) {
    MS_EXCEPTION_IF_NULL(anf);
    for (const auto &primitive : primitive_list) {
      if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
        result.push_back(anf->cast<CNodePtr>());
      }
    }
  }
  return result;
}

void KernelGraph::PrintGraphExecuteOrder() const {
  if (!(IS_OUTPUT_ON(INFO))) {
    return;
  }
  MS_LOG(INFO) << "Graph " << graph_id_ << " execution order:";
  for (size_t i = 0; i < execution_order_.size(); i++) {
    CNodePtr cur_cnode_ptr = execution_order_[i];
    MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
    std::string event_str;
    if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
      event_str = ", event id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
    }
    std::string label_str;
    if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
      label_str = ", label id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
    }
    if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
      auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
      label_str = ", label id[";
      for (size_t j = 0; j < label_list.size(); ++j) {
        label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
      }
    }
    std::string active_stream_str;
    if (AnfAlgo::HasNodeAttr(kAttrActiveStreamList, cur_cnode_ptr)) {
      auto stream_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrActiveStreamList);
      active_stream_str = ", active stream id[";
      for (size_t j = 0; j < stream_list.size(); ++j) {
        active_stream_str += std::to_string(stream_list[j]) + (j + 1 < stream_list.size() ? ", " : "]");
      }
    }
    std::string group_str;
    if (AnfAlgo::GetKernelType(cur_cnode_ptr) == HCCL_KERNEL && AnfAlgo::HasNodeAttr(kAttrGroup, cur_cnode_ptr)) {
      group_str = ", group[" + AnfAlgo::GetNodeAttr<std::string>(cur_cnode_ptr, kAttrGroup) + "]";
    }
    MS_LOG(INFO) << "Index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
                 << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
                 << AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
                 << event_str << label_str << active_stream_str << group_str;
  }
}
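
// Internal outputs are backend nodes whose results are consumed directly by
// front-end nodes. The maps below keep the front/backend association (plus an
// optional cached tensor) so that later graph rewrites can be mirrored on the
// front-end side.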
void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, size_t output_idx,
                                    bool unique_target) {
  if (front_node == nullptr || node == nullptr) {
    MS_LOG(INFO) << "Front node or node is nullptr";
    return;
  }
  MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
  front_to_internal_outputs_map_[front_node] = node;
  if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
  }
  internal_outputs_to_front_map_[node][output_idx] = std::pair<AnfNodePtr, bool>(front_node, unique_target);
}

void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, size_t output_idx, const tensor::TensorPtr &tensor) {
  if (node == nullptr) {
    return;
  }
  internal_outputs_tensor_map_[node][output_idx] = tensor;
}

tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, size_t output_idx) {
  if (node == nullptr) {
    return nullptr;
  }
  auto iter = internal_outputs_tensor_map_.find(node);
  if (iter == internal_outputs_tensor_map_.end()) {
    return nullptr;
  }
  auto idx_iter = iter->second.find(output_idx);
  if (idx_iter == iter->second.end()) {
    return nullptr;
  }
  return idx_iter->second;
}
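
// Re-point the internal-output bookkeeping from `node` to `new_node`: the
// first overload moves every recorded output index, the second moves a single
// one (src_output_idx -> dst_output_idx).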
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
  if (new_node == nullptr || node == nullptr) {
    MS_LOG(INFO) << "New node or node is nullptr";
    return;
  }
  if (node == new_node) {
    MS_LOG(INFO) << "New node and node are the same";
    return;
  }
  auto iter = internal_outputs_to_front_map_.find(node);
  if (iter == internal_outputs_to_front_map_.end()) {
    MS_LOG(INFO) << "Node is not an internal output";
    return;
  }
  MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " to " << new_node->DebugString();
  auto front_nodes = std::move(iter->second);
  // We should do 'erase(iter)' before modifying 'internal_outputs_to_front_map_',
  // since 'iter' may be invalidated after a new item is added.
  internal_outputs_to_front_map_.erase(iter);
  // Move all front nodes to the new node mapping.
  for (const auto &front_node_iter : front_nodes) {
    front_to_internal_outputs_map_[front_node_iter.second.first] = new_node;
  }
  internal_outputs_to_front_map_[new_node] = std::move(front_nodes);
}

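// Per-index overload: move only the front node registered at 'src_output_idx' of 'node' to
// 'dst_output_idx' of 'new_node'; the remaining indices of 'node' keep their mappings.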
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, size_t src_output_idx,
                                        size_t dst_output_idx) {
  if (new_node == nullptr || node == nullptr) {
    MS_LOG(INFO) << "New node or node is nullptr";
    return;
  }
  if (node == new_node) {
    MS_LOG(INFO) << "New node and node are the same";
    return;
  }
  auto iter = internal_outputs_to_front_map_.find(node);
  if (iter == internal_outputs_to_front_map_.end()) {
    MS_LOG(INFO) << "Node is not an internal output";
    return;
  }
  MS_LOG(INFO) << "Replace internal output node " << node->DebugString() << " to " << new_node->DebugString();
  auto &front_nodes = iter->second;
  // Move the specified front node to the new node mapping.
  auto front_node_iter = front_nodes.find(src_output_idx);
  if (front_node_iter == front_nodes.end()) {
    MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString()
                 << " is not an internal output";
    return;
  }
  auto front_node_pair = std::move(front_node_iter->second);
  (void)front_nodes.erase(front_node_iter);
  if (front_nodes.empty()) {
    (void)internal_outputs_to_front_map_.erase(iter);
  }
  // We should do 'erase' before 'insert', since 'iter' may be invalidated after a new item is added.
  front_to_internal_outputs_map_[front_node_pair.first] = new_node;
  internal_outputs_to_front_map_[new_node][dst_output_idx] = std::move(front_node_pair);
}

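// Map a backend parameter to the front node it was created from. The front node with index is first
// normalized through GetAllOutputWithIndex so that a tuple output resolves to its concrete element.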
void KernelGraph::CacheInternalParameterToFrontNode(const AnfNodePtr &parameter,
                                                    const AnfWithOutIndex &front_node_with_index) {
  if ((parameter == nullptr) || (front_node_with_index.first == nullptr)) {
    return;
  }
  auto front_outputs = AnfAlgo::GetAllOutputWithIndex(front_node_with_index.first);
  AnfWithOutIndex new_front_node_with_index;
  if (front_node_with_index.second < front_outputs.size()) {
    new_front_node_with_index = front_outputs[front_node_with_index.second];
  } else {
    new_front_node_with_index = front_node_with_index;
  }
  if (new_front_node_with_index.first == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Cache internal parameter: " << parameter->DebugString()
               << " to front node: " << new_front_node_with_index.first->DebugString()
               << " with index: " << new_front_node_with_index.second
               << ", from front node: " << front_node_with_index.first->DebugString()
               << " with index: " << front_node_with_index.second;
  internal_parameter_to_front_node_map_[parameter] = new_front_node_with_index;
}

AnfWithOutIndex KernelGraph::GetFrontNodeByInternalParameter(const AnfNodePtr &parameter) const {
  auto iter = internal_parameter_to_front_node_map_.find(parameter);
  if (iter != internal_parameter_to_front_node_map_.end()) {
    return iter->second;
  }
  return AnfWithOutIndex();
}

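// Return the func graph of the first mapped front node that has one, or nullptr if none does.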
FuncGraphPtr KernelGraph::GetFuncGraph() {
  for (const auto &front_backend_anf : front_backend_anf_map_) {
    const auto &front_node = front_backend_anf.first;
    const auto &func_graph = front_node->func_graph();
    if (func_graph != nullptr) {
      return func_graph;
    }
  }
  return nullptr;
}

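// Flatten the backend and front graph outputs and record a positional 1:1 mapping between them; if the two
// flattened lists differ in size, the mapping is skipped with a warning.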
void KernelGraph::CacheGraphOutputToFrontNodeWithIndex(const std::vector<AnfNodePtr> &backend_outputs,
                                                       const std::vector<AnfNodePtr> &front_outputs) {
  MS_LOG(INFO) << "Get graph backend output nodes.";
  std::vector<KernelWithIndex> backend_output_nodes;
  for (auto &backend_output : backend_outputs) {
    auto temp_backend_outputs = AnfAlgo::GetAllOutputWithIndex(backend_output);
    backend_output_nodes.insert(backend_output_nodes.end(), temp_backend_outputs.begin(), temp_backend_outputs.end());
  }
  MS_LOG(INFO) << "Get graph front output nodes.";
  std::vector<KernelWithIndex> front_output_nodes;
  for (auto &front_output : front_outputs) {
    auto temp_front_outputs = AnfAlgo::GetAllOutputWithIndex(front_output);
    front_output_nodes.insert(front_output_nodes.end(), temp_front_outputs.begin(), temp_front_outputs.end());
  }
  if (backend_output_nodes.size() != front_output_nodes.size()) {
    MS_LOG(WARNING) << "The size(" << backend_output_nodes.size() << ") of backend outputs is not equal to the size("
                    << front_output_nodes.size() << ") of front outputs.";
    return;
  }
  for (size_t i = 0; i < backend_output_nodes.size(); ++i) {
    auto backend_output_node = backend_output_nodes[i];
    auto front_output_node = front_output_nodes[i];
    graph_output_to_front_node_map_[backend_output_node] = front_output_node;
    MS_LOG(INFO) << "Backend output: " << backend_output_node.first->fullname_with_scope()
                 << " with index: " << backend_output_node.second
                 << " maps to front node: " << front_output_node.first->fullname_with_scope()
                 << " with index: " << front_output_node.second;
  }
}

AnfWithOutIndex KernelGraph::GetFrontNodeWithIndexByGraphOutput(
  const AnfWithOutIndex &backend_graph_output_with_index) const {
  auto iter = graph_output_to_front_node_map_.find(backend_graph_output_with_index);
  if (iter != graph_output_to_front_node_map_.end()) {
    return iter->second;
  }
  return AnfWithOutIndex();
}

AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
  auto iter = front_to_internal_outputs_map_.find(front_node);
  if (iter != front_to_internal_outputs_map_.end()) {
    return iter->second;
  }
  return nullptr;
}

bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
  return internal_outputs_to_front_map_.find(node) != internal_outputs_to_front_map_.end();
}

bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  return front_nodes.find(output_idx) != front_nodes.end();
}

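// True only if (node, output_idx) is an internal output that was registered with unique_target == true.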
bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, size_t output_idx) const {
  auto front_nodes_iter = internal_outputs_to_front_map_.find(node);
  if (front_nodes_iter == internal_outputs_to_front_map_.end()) {
    return false;
  }
  auto &front_nodes = front_nodes_iter->second;
  auto idx_iter = front_nodes.find(output_idx);
  if (idx_iter == front_nodes.end()) {
    return false;
  }
  return idx_iter->second.second;
}

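// Rebuild child_graph_order_ by scanning for Call/Switch/SwitchLayer nodes and collecting the kernel graphs
// they invoke; every child graph other than this graph's parent is re-parented to this graph.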
void KernelGraph::UpdateChildGraphOrder() {
  MS_LOG(INFO) << "Update " << ToString() << " child graph order.";
  SetExecOrderByDefault();
  auto call_nodes = FindNodeByPrimitive({std::make_shared<Primitive>(prim::kPrimCall->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitch->name()),
                                         std::make_shared<Primitive>(prim::kPrimSwitchLayer->name())});
  std::vector<std::weak_ptr<KernelGraph>> child_graph_order;
  for (auto &call_node : call_nodes) {
    MS_EXCEPTION_IF_NULL(call_node);
    auto call_child_graphs = AnfAlgo::GetCallSwitchKernelGraph(call_node->cast<CNodePtr>());
    for (const auto &child_graph : call_child_graphs) {
      MS_EXCEPTION_IF_NULL(child_graph);
      if (child_graph != parent_graph_.lock()) {
        auto shared_this = std::dynamic_pointer_cast<KernelGraph>(shared_from_this());
        MS_EXCEPTION_IF_NULL(shared_this);
        child_graph->set_parent_graph(shared_this);
      }
      child_graph_order.push_back(child_graph);
    }
  }
  for (size_t i = 0; i < child_graph_order.size(); ++i) {
    std::shared_ptr<KernelGraph> child_graph = child_graph_order[i].lock();
    MS_EXCEPTION_IF_NULL(child_graph);
    MS_LOG(INFO) << "Child graph[" << i << "][id:" << child_graph->graph_id() << "]";
  }
  child_graph_order_ = child_graph_order;
}

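// Erase a node from the backend<->front mappings, and from the graph value node set if it is a ValueNode.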
void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto iter = backend_front_anf_map_.find(node);
  if (iter != backend_front_anf_map_.end()) {
    (void)front_backend_anf_map_.erase(iter->second);
    (void)backend_front_anf_map_.erase(iter);
  }
  if (node->isa<ValueNode>()) {
    (void)graph_value_nodes_.erase(node->cast<ValueNodePtr>());
  }
}

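// The graph is marked dynamic-shape if any node in its execution order is dynamic-shape.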
void KernelGraph::UpdateGraphDynamicAttr() {
  for (const auto &cnode : execution_order_) {
    if (AnfAlgo::IsDynamicShape(cnode)) {
      MS_LOG(INFO) << "Update Graph Dynamic Attr";
      is_dynamic_shape_ = true;
      return;
    }
  }
  is_dynamic_shape_ = false;
}

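// Rebuild input_nodes_ from inputs(). An input that expands into multiple outputs (a tuple parameter) is
// mapped back to its front node element-wise, with the tuple index recorded for each expanded parameter.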
void KernelGraph::SetInputNodes() {
  input_nodes_.clear();
  for (const auto &input_node : inputs()) {
    auto params = AnfAlgo::GetAllOutput(input_node);
    if (params.size() == 1) {
      FrontBackendlMapUpdate(input_node, params[0]);
    } else {
      auto front_node = backend_front_anf_map_[input_node];
      for (size_t i = 0; i < params.size(); ++i) {
        FrontBackendlMapUpdate(input_node, params[i]);
        tuple_backend_front_anf_index_map_[params[i]] = AnfWithOutIndex(front_node, i);
      }
    }
    std::copy(params.begin(), params.end(), std::back_inserter(input_nodes_));
  }
}

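// The graph needs to acquire the Python GIL if its execution order contains any PyFunc node.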
void KernelGraph::UpdateGraphAquireGilAttr() {
  for (const auto &cnode : execution_order_) {
    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPyFunc)) {
      MS_LOG(INFO) << "The graph requires the GIL. Graph id: " << graph_id_;
      is_need_gil_ = true;
      return;
    }
  }
}

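// The graph contains an optimizer if any non-async node is a known optimizer op, or if an optimizer/Assign
// node takes a Parameter whose abstract is an AbstractRef as input; such parameters are recorded as updated.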
void KernelGraph::SetOptimizerFlag() {
  has_optimizer_ = false;
  for (const auto &cnode : execution_order_) {
    MS_EXCEPTION_IF_NULL(cnode);
    auto node_name = AnfAlgo::GetCNodeName(cnode);
    if (AnfAlgo::HasNodeAttr(kAttrAsync, cnode) && AnfAlgo::GetNodeAttr<bool>(cnode, kAttrAsync)) {
      continue;
    }
    if (kOptOperatorSet.find(node_name) != kOptOperatorSet.end()) {
      has_optimizer_ = true;
    } else if (node_name.find("Assign") == string::npos) {
      continue;
    }
    for (auto &input : cnode->inputs()) {
      MS_EXCEPTION_IF_NULL(input);
      auto real_node = AnfAlgo::VisitKernel(input, 0).first;
      MS_EXCEPTION_IF_NULL(real_node);
      if (!real_node->isa<Parameter>()) {
        continue;
      }
      auto param = real_node->cast<ParameterPtr>();
      auto abstract = param->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      if (abstract->isa<abstract::AbstractRef>()) {
        has_optimizer_ = true;
        (void)updated_parameters_.insert(param);
      }
    }
  }
}

bool KernelGraph::IsDatasetGraph() const {
  // Check whether there is an InitDataSetQueue node.
  const auto &nodes = execution_order_;
  for (const auto &node : nodes) {
    auto node_name = AnfAlgo::GetCNodeName(node);
    if (node_name == prim::kPrimInitDataSetQueue->name()) {
      return true;
    }
  }
  return false;
}

std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }

bool KernelGraph::IsChildGraphResult(const AnfNodePtr &node) {
  std::vector<AnfNodePtr> child_graph_results;
  for (const auto &child_graph_result : child_graph_result_) {
    MS_EXCEPTION_IF_NULL(child_graph_result);
    if (AnfAlgo::CheckPrimitiveType(child_graph_result, prim::kPrimMakeTuple)) {
      const auto cnode = child_graph_result->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      const auto &inputs = cnode->inputs();
      child_graph_results.insert(child_graph_results.end(), inputs.begin(), inputs.end());
    } else {
      child_graph_results.emplace_back(child_graph_result);
    }
  }
  return find(child_graph_results.begin(), child_graph_results.end(), node) != child_graph_results.end();
}

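// Destructor: release per-kernel resources and the runtime resources bound to this graph id; all exceptions
// are caught and logged, since a destructor must not throw.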
KernelGraph::~KernelGraph() {
  try {
    // Release the kernel resource.
    for (const auto &kernel : execution_order_) {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      if (kernel_mod != nullptr) {
        kernel_mod->ReleaseResource();
      }
    }
    device::KernelRuntimeManager::Instance().ClearGraphResource(graph_id_);
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "KernelGraph call destructor failed: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "KernelGraph call destructor failed";
  }
}

}  // namespace session
}  // namespace mindspore