You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 65 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
5 years ago
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
adapte to remove inline merge me commit for remove inline deal witch multiple cases of switch in ConstructKernelGraph deal with switch and call cases in ConstructKernelGraph fix bug and rebase master ConstructKernelGraph adapte to remove inline fix InsertMultipleAssignToGraph bug add graph input to new graph which is created for switch input replace CreateNewParameterFromCNode to NewParameter in order to set new parameter's abstract and kernel_info avoids create a new switch repeatedly when the cnode is a call switch without real input null pointer check update frontend code Revert "update frontend code" This reverts commit ce1f600d1e9b4b47d9b81122f981bbbe505dd250. update frontend code PR_2948 fix bug of CheckLabalIndex handle switch_layer in ConstructKernelGraph add attr for assign node to avoid erasing by cse pass cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem rebase master Revert "cherry-pick ms commit[59b35f690ddcc94ff35a4f4eaf3816121b32235b]:temporary avoid list getitem problem" This reverts commit 74c258f94260ca0769a1ef69c6ef8e831c301dbf. Revert "handle switch_layer in ConstructKernelGraph" This reverts commit cb5367f02d69facbca8d39e9234c501608aee27f. Revert "update frontend code PR_2948" This reverts commit 234ac583400a96a8ddd641f7a722e1ccd5e056c6. Revert "merge me commit for remove inline" This reverts commit 55c0ebd42b6699c7686f5ce585e745f87dd42280. fix diff after rebase master doing remove inline in me overwrite FindNodePrimitive Revert "doing remove inline in me" This reverts commit b42e893125bc624d323e855ac6ae615333c06e65.
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include "c_ops/primitive_c.h"
  21. #include "ir/manager.h"
  22. #include "ir/param_info.h"
  23. #include "backend/kernel_compiler/common_utils.h"
  24. #include "base/core_ops.h"
  25. #include "common/trans.h"
  26. #include "utils/config_manager.h"
  27. #include "backend/session/anf_runtime_algorithm.h"
  28. #include "backend/session/executor.h"
  29. #include "backend/session/executor_manager.h"
  30. #include "backend/optimizer/common/common_backend_optimization.h"
  31. #include "backend/optimizer/common/helper.h"
  32. #include "runtime/device/kernel_runtime_manager.h"
  33. #include "utils/ms_utils.h"
  34. #include "ir/dtype.h"
  35. #include "ir/anf.h"
  36. #include "ir/func_graph_cloner.h"
  37. #include "utils/utils.h"
  38. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  39. #include "ps/worker.h"
  40. #include "ps/common.h"
  41. #include "ps/util.h"
  42. #endif
  43. namespace mindspore {
  44. namespace session {
  45. static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
  46. void ClearPythonParasMap() { python_paras = nullptr; }
  47. namespace {
  48. const int kSummaryGetItem = 2;
  49. ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
  50. if (node == nullptr) {
  51. return nullptr;
  52. }
  53. auto parameter = node->cast<ParameterPtr>();
  54. if (parameter == nullptr || !parameter->has_default()) {
  55. return nullptr;
  56. }
  57. return parameter->default_param();
  58. }
// Creates the host tensor that will receive the device-side output
// <node, output_index> of `graph`.
// - A unique-target internal output gets a {1}-shaped placeholder tensor that
//   is never synced (kNoNeedSync); real data is bound elsewhere.
// - Otherwise the graph's cached internal-output tensor is reused when one
//   exists, else a new tensor shaped by the inferred output shape is created
//   (and cached when the output is an internal output).
// All returned tensors are marked NeedWait until execution fills them.
tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
                                          const KernelGraphPtr &graph) {
  auto &node = node_output_pair.first;
  auto &output_index = node_output_pair.second;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  // Prefer the selected device data type; fall back to the inferred type when
  // kernel selection has not assigned one yet.
  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (type_id == kTypeUnknown) {
    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  tensor::TensorPtr tensor = nullptr;
  std::vector<int> temp_shape;
  if (graph->IsUniqueTargetInternalOutput(node, output_index)) {
    temp_shape.emplace_back(1);
    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
    tensor->set_sync_status(kNoNeedSync);
    tensor->SetNeedWait(true);
    return tensor;
  }
  tensor = graph->GetInternalOutputTensor(node, output_index);
  if (tensor == nullptr) {
    auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    bool is_internal_output = graph->IsInternalOutput(node, output_index);
    if (is_internal_output) {
      // Cache so later lookups of the same internal output reuse this tensor.
      graph->AddInternalOutputTensor(node, output_index, tensor);
    }
  }
  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
  // In PyNative mode (or on the GPU target) data is only copied to host when
  // the user actually reads it; otherwise sync device->host immediately.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
      ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
    tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
  } else {
    tensor->set_sync_status(kNeedSyncDeviceToHost);
  }
  tensor->SetNeedWait(true);
  return tensor;
}
// Resolves one output <node, output_index> of `graph` to a BaseRef.
// - ValueNode: returns its value directly (nothing to sync from device).
// - Parameter without a device address: returns the input tensor feeding that
//   graph input (graph->inputs() and input_tensors are index-aligned).
// - Otherwise: creates a host tensor for the CNode output and records it in
//   *tensor_to_node so device data can later be synced into it.
BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                               const std::vector<tensor::TensorPtr> &input_tensors,
                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
  auto &node = node_output_pair.first;
  auto &output_index = node_output_pair.second;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << node_output_pair.second << "]";
  // if node is a value node, no need sync addr from device to host
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    return value_node->value();
  }
  if (!AnfAlgo::OutputAddrExist(node, output_index)) {
    // No device address yet: for a Parameter, map back to the input tensor
    // supplied by the caller; any other node type falls through below.
    if (node->isa<Parameter>()) {
      for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
        if (input_idx >= input_tensors.size()) {
          MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
        }
        if (graph->inputs()[input_idx] == node) {
          return input_tensors[input_idx];
        }
      }
      // Parameter not found among the graph inputs — cannot produce a value.
      MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
    }
  }
  auto tensor = CreateCNodeOutputTensor(node_output_pair, graph);
  (*tensor_to_node)[tensor] = node_output_pair;
  return tensor;
}
// Builds the (possibly nested) output structure for front-end node `anf`.
// A MakeTuple fans out recursively into a VectorRef with one entry per tuple
// element; a node with zero outputs yields an empty VectorRef; any other node
// resolves to a single tensor/value via CreateNodeOutputTensor.
BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                const std::vector<tensor::TensorPtr> &input_tensors,
                                std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
  // Skip transparent nodes (e.g. TupleGetItem/Depend) to the real kernel.
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // special handle for maketuple
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    // Input 0 is the MakeTuple primitive itself; recurse over the real inputs.
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node);
      ret.push_back(out);
    }
    return ret;
  }
  // if is graph return nothing ,the function should return a null anylist
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node);
}
  161. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  162. MS_EXCEPTION_IF_NULL(anf);
  163. MS_EXCEPTION_IF_NULL(graph);
  164. auto value_node = anf->cast<ValueNodePtr>();
  165. MS_EXCEPTION_IF_NULL(value_node);
  166. auto value = value_node->value();
  167. MS_EXCEPTION_IF_NULL(value);
  168. if (value->isa<None>()) {
  169. return nullptr;
  170. }
  171. auto new_value_node = graph->NewValueNode(value_node);
  172. graph->FrontBackendlMapAdd(anf, new_value_node);
  173. graph->AddValueNodeToGraph(new_value_node);
  174. return new_value_node;
  175. }
  176. size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  177. MS_EXCEPTION_IF_NULL(graph);
  178. MS_LOG(INFO) << "Load kInputCtrlTensors";
  179. auto inputs_params = graph->input_ctrl_tensors();
  180. if (inputs_params == nullptr) {
  181. return 0;
  182. }
  183. if (inputs_params->size() < 3) {
  184. MS_LOG(EXCEPTION) << "Illegal inputs_params size";
  185. }
  186. // update current loop tensor to 0 per iterator
  187. auto cur_loop_tensor = (*inputs_params)[0];
  188. MS_EXCEPTION_IF_NULL(cur_loop_tensor);
  189. auto *cur_val = static_cast<int32_t *>(cur_loop_tensor->data_c());
  190. MS_EXCEPTION_IF_NULL(cur_val);
  191. *cur_val = 0;
  192. cur_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
  193. // set loop_count to zero
  194. MS_EXCEPTION_IF_NULL(inputs);
  195. inputs->push_back(cur_loop_tensor);
  196. // update next loop tensor to 0 per iterator
  197. auto next_loop_tensor = (*inputs_params)[1];
  198. MS_EXCEPTION_IF_NULL(next_loop_tensor);
  199. auto *next_val = static_cast<int32_t *>(next_loop_tensor->data_c());
  200. MS_EXCEPTION_IF_NULL(next_val);
  201. *next_val = 0;
  202. next_loop_tensor->set_sync_status(kNeedSyncHostToDevice);
  203. // set loop_count to zero
  204. MS_EXCEPTION_IF_NULL(inputs);
  205. inputs->push_back(next_loop_tensor);
  206. auto epoch_tensor = (*inputs_params)[2];
  207. MS_EXCEPTION_IF_NULL(epoch_tensor);
  208. auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
  209. MS_EXCEPTION_IF_NULL(epoch_val);
  210. *epoch_val = graph->current_epoch();
  211. epoch_tensor->set_sync_status(kNeedSyncHostToDevice);
  212. inputs->push_back(epoch_tensor);
  213. MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;
  214. graph->set_current_epoch(graph->current_epoch() + 1);
  215. return inputs_params->size();
  216. }
  217. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  218. MS_EXCEPTION_IF_NULL(graph);
  219. MS_EXCEPTION_IF_NULL(input_tensor);
  220. auto value_node = std::make_shared<ValueNode>(input_tensor);
  221. MS_EXCEPTION_IF_NULL(value_node);
  222. // construct abstract of value node
  223. auto type_of_tensor = input_tensor->Dtype();
  224. auto shape_of_tensor = input_tensor->shape();
  225. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  226. value_node->set_abstract(abstract);
  227. // add value node to graph
  228. auto input_value_node = graph->NewValueNode(value_node);
  229. graph->AddValueNodeToGraph(input_value_node);
  230. return input_value_node;
  231. }
// Builds a Parameter node for a PyNative run-op graph from one input tensor.
// When tensor_mask marks the tensor as a weight, the tensor becomes the
// parameter's default value. The parameter's kernel build info either reuses
// the tensor's existing device address (format/type/reshape-type inherited,
// address bound to output 0) or falls back to the default format.
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
                                     int tensor_mask) {
  MS_EXCEPTION_IF_NULL(graph);
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (tensor_mask == kParameterWeightTensorMask) {
    param->set_default_param(input_tensor);
  }
  // set the kernel info of parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  if (device_address == nullptr) {
    // No device copy yet: default format; weights keep an unknown device type
    // so kernel selection can still choose it later.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    // Reuse the tensor's device address: inherit its format/type and bind the
    // address to the parameter's output 0.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
    kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
    AnfAlgo::SetOutputAddr(device_address, 0, param.get());
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // construct abstract of parameter
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  param->set_abstract(abstract);
  return param;
}
  262. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  263. MS_LOG(INFO) << "Graph outputs:";
  264. const size_t max_deep = 10;
  265. if (recurse_level > max_deep) {
  266. MS_LOG(INFO) << "Recurse too deep";
  267. return;
  268. }
  269. std::string tab_str;
  270. for (size_t i = 0; i < recurse_level; i++) {
  271. tab_str = tab_str.append(" ");
  272. }
  273. if (any.is<AnyList>()) {
  274. (void)tab_str.append("{");
  275. MS_LOG(INFO) << tab_str;
  276. auto any_list = any.cast<AnyList>();
  277. for (auto &it : any_list) {
  278. DumpGraphOutput(it, recurse_level + 1);
  279. }
  280. (void)tab_str.append("}");
  281. MS_LOG(INFO) << tab_str;
  282. }
  283. (void)tab_str.append(any.ToString());
  284. MS_LOG(INFO) << tab_str;
  285. }
  286. bool ExistSummaryNode(const KernelGraph *graph) {
  287. MS_EXCEPTION_IF_NULL(graph);
  288. auto ret = graph->get_return();
  289. MS_EXCEPTION_IF_NULL(ret);
  290. auto all_nodes = DeepLinkedGraphSearch(ret);
  291. for (auto &n : all_nodes) {
  292. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  293. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  294. return true;
  295. }
  296. }
  297. return false;
  298. }
  299. bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
  300. MS_EXCEPTION_IF_NULL(node);
  301. if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimMakeTuple)) {
  302. return false;
  303. }
  304. auto cnode = node->cast<CNodePtr>();
  305. MS_EXCEPTION_IF_NULL(cnode);
  306. const auto &node_inputs = cnode->inputs();
  307. for (size_t i = 1; i < node_inputs.size(); ++i) {
  308. if (!AnfAlgo::CheckPrimitiveType(node_inputs[i], prim::kPrimControlDepend)) {
  309. return false;
  310. }
  311. }
  312. return true;
  313. }
  314. } // namespace
  315. GraphId SessionBasic::graph_sum_ = 0;
// Binds this session to a device: records the device id, creates the device
// context, and fetches the executor for (device_name, device_id).
void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id) {
  device_id_ = device_id;
  context_ = std::make_shared<Context>(device_name, device_id);
  executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
}
  321. KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
  322. auto it = graphs_.find(graph_id);
  323. if (it == graphs_.end()) {
  324. MS_LOG(WARNING) << "Can't find graph " << graph_id;
  325. return nullptr;
  326. }
  327. return it->second;
  328. }
// When `parameter` stands for the output `out_node` of a previously built graph,
// copy that output's device info (address, format, device data type) onto the
// parameter so the producing kernel's memory is shared instead of copied.
// Silently returns when no such internal output exists or it lacks kernel info.
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  // A TupleGetItem output selects one tuple element; everything else uses output 0.
  size_t output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
  }
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  // Only share when the referenced node is a real kernel that is the unique
  // target of this internal output.
  if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    // Nop nodes carry no address of their own, so they are exempt from the address check.
    if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    // Attach a fresh kernel info to the parameter carrying the referenced
    // output's format and device type.
    auto d_kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    parameter->set_kernel_info(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    // Reuse the producing kernel's device memory for this parameter.
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
    AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
  }
}
// Creates backend parameters in `graph` standing in for every output of the
// front-end node `node` (one parameter per element when an output has a tuple
// abstract). Each new parameter is also appended to the graph's input and
// valid-input lists. Returns the created parameters in output order.
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> parameters;
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // A MakeTuple made only of ControlDepend inputs yields no data: create nothing.
  if (IgnoreCreateParameterForMakeTuple(node)) {
    pre_graph_out.clear();
  }
  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  }
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  // Makes one backend parameter with the given abstract and registers it.
  // A throwaway parameter is built first so NewParameter(parameter) can pick
  // up the abstract (and kernel info) from it.
  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
    auto parameter = graph->NewParameter();
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(abstract);
    auto new_parameter = graph->NewParameter(parameter);
    parameters.push_back(new_parameter);
    valid_inputs->push_back(true);
    graph_inputs->push_back(new_parameter);
  };
  for (const auto &out_node : pre_graph_out) {
    MS_EXCEPTION_IF_NULL(out_node);
    auto abstract = out_node->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    // create multiple parameters if is a tuple output real kernel
    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_abstract);
      MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
        create_parameter((*tuple_abstract)[output_idx]);
      }
      continue;
    }
    // create single parameter if is a abstract real kernel
    create_parameter(out_node->abstract());
    // Let the just-created parameter share the previous graph's output memory.
    InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
  }
  return parameters;
}
  420. ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  421. MS_EXCEPTION_IF_NULL(anf);
  422. if (!anf->isa<Parameter>()) {
  423. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  424. }
  425. MS_EXCEPTION_IF_NULL(graph);
  426. auto param_value = GetParamDefaultValue(anf);
  427. auto valid_inputs = graph->MutableValidInputs();
  428. MS_EXCEPTION_IF_NULL(valid_inputs);
  429. auto graph_inputs = graph->MutableInputs();
  430. MS_EXCEPTION_IF_NULL(graph_inputs);
  431. ParameterPtr new_parameter = nullptr;
  432. // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
  433. if (python_paras == nullptr) {
  434. python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
  435. }
  436. auto iter = python_paras->find(param_value);
  437. if (iter != python_paras->end()) {
  438. new_parameter = iter->second;
  439. } else {
  440. TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
  441. new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  442. if (param_value != nullptr) {
  443. (*python_paras)[param_value] = new_parameter;
  444. }
  445. TraceManager::EndTrace();
  446. }
  447. new_parameter->IncreaseUsedGraphCount();
  448. graph_inputs->push_back(new_parameter);
  449. valid_inputs->push_back(true);
  450. return new_parameter;
  451. }
  452. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
  453. MS_EXCEPTION_IF_NULL(anf);
  454. MS_EXCEPTION_IF_NULL(graph);
  455. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  456. auto parameters = CreateParameterFromTuple(anf, graph);
  457. if (parameters.empty()) {
  458. MS_LOG(INFO) << "Empty parameter from cnode";
  459. return nullptr;
  460. }
  461. if (parameters.size() == 1) {
  462. return parameters[0];
  463. }
  464. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  465. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  466. auto make_tuple = graph->NewCNode(make_tuple_input);
  467. MS_EXCEPTION_IF_NULL(make_tuple);
  468. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  469. return make_tuple;
  470. }
  471. void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) {
  472. MS_EXCEPTION_IF_NULL(cnode);
  473. MS_EXCEPTION_IF_NULL(cnode_inputs);
  474. auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  475. if (prim != nullptr) {
  476. // push attr to inputs[0] of new cnode
  477. cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  478. } else {
  479. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  480. MS_EXCEPTION_IF_NULL(fg);
  481. auto new_fg = BasicClone(fg);
  482. cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
  483. }
  484. }
  485. void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
  486. std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  487. MS_EXCEPTION_IF_NULL(cnode);
  488. MS_EXCEPTION_IF_NULL(graph);
  489. MS_EXCEPTION_IF_NULL(other_graph_cnode);
  490. MS_EXCEPTION_IF_NULL(cnode_inputs);
  491. auto origin_inputs = cnode->inputs();
  492. bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
  493. origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>();
  494. bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3;
  495. // if has multiple depends,only select first depend as parameter
  496. for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
  497. auto anf = origin_inputs[input_idx];
  498. MS_EXCEPTION_IF_NULL(anf);
  499. // anf has been created before
  500. if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
  501. cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
  502. continue;
  503. } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
  504. cnode_inputs->push_back((*other_graph_cnode)[anf]);
  505. continue;
  506. } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
  507. // if input is a value node,
  508. auto new_value_node = CreateNewValueNode(anf, graph);
  509. if (new_value_node != nullptr) {
  510. cnode_inputs->emplace_back(new_value_node);
  511. }
  512. continue;
  513. } else if (anf->isa<Parameter>()) {
  514. auto new_parameter = CreateNewParameterFromParameter(anf, graph);
  515. cnode_inputs->push_back(new_parameter);
  516. if (GetGraphIdByNode(anf) == kInvalidGraphId) {
  517. graph->FrontBackendlMapAdd(anf, new_parameter);
  518. } else {
  519. (*other_graph_cnode)[anf] = new_parameter;
  520. }
  521. continue;
  522. } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
  523. cnode_inputs->push_back(origin_inputs[kRealInputIndexInDepend]);
  524. continue;
  525. } else if (optimize_control_depend) {
  526. cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
  527. } else {
  528. // the input node is a cnode from other graph
  529. auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
  530. if (parameter_from_cnode == nullptr) {
  531. parameter_from_cnode = NewValueNode(MakeValue(SizeToInt(input_idx)));
  532. }
  533. cnode_inputs->push_back(parameter_from_cnode);
  534. (*other_graph_cnode)[anf] = parameter_from_cnode;
  535. }
  536. }
  537. }
  538. CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
  539. std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  540. MS_EXCEPTION_IF_NULL(cnode);
  541. MS_EXCEPTION_IF_NULL(graph);
  542. MS_EXCEPTION_IF_NULL(other_graph_cnode);
  543. // get primitive of old node
  544. std::vector<AnfNodePtr> cnode_inputs;
  545. GetCNodeInfo(cnode, &cnode_inputs);
  546. GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
  547. TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  548. auto new_cnode = graph->NewCNode(cnode_inputs);
  549. TraceManager::EndTrace();
  550. return new_cnode;
  551. }
  552. CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) {
  553. MS_EXCEPTION_IF_NULL(node_input);
  554. MS_EXCEPTION_IF_NULL(graph);
  555. // switch input generalizes partial
  556. std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  557. if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
  558. auto partial_node = graph->GetBackendAnfByFrontAnf(node_input);
  559. return partial_node->cast<CNodePtr>();
  560. } else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
  561. partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  562. } else {
  563. KernelGraphPtr kernel_graph = NewKernelGraph();
  564. MS_EXCEPTION_IF_NULL(kernel_graph);
  565. auto parameter = CreateNewParameterFromCNode(graph->GetBackendAnfByFrontAnf(node_input), kernel_graph.get());
  566. auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
  567. auto return_node = kernel_graph->NewCNode({primitive, parameter});
  568. kernel_graph->set_return(return_node);
  569. partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
  570. partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  571. }
  572. auto partial_node = graph->NewCNode(partial_inputs);
  573. return partial_node;
  574. }
  575. std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) {
  576. MS_EXCEPTION_IF_NULL(cnode);
  577. MS_EXCEPTION_IF_NULL(graph);
  578. std::vector<AnfNodePtr> cnode_inputs = {
  579. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  580. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  581. MS_EXCEPTION_IF_NULL(attr_input);
  582. auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  583. auto switch_cnode = cnode_input->cast<CNodePtr>();
  584. MS_EXCEPTION_IF_NULL(switch_cnode);
  585. if (cnode->inputs().size() < 2) {
  586. cnode_inputs = switch_cnode->inputs();
  587. return cnode_inputs;
  588. }
  589. std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
  590. switch_cnode->input(kFirstDataInputIndex)};
  591. for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
  592. auto node = switch_cnode->input(index);
  593. // there is real input in call, should put it to true and false branch in switch
  594. if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
  595. auto partial_node = node->cast<CNodePtr>();
  596. MS_EXCEPTION_IF_NULL(partial_node);
  597. std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
  598. partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
  599. auto new_partial = graph->NewCNode(partial_inputs);
  600. switch_inputs.emplace_back(new_partial);
  601. }
  602. }
  603. if (switch_inputs.size() < kSwitchInputSize) {
  604. MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << "less than " << kSwitchInputSize;
  605. }
  606. auto switch_node = graph->NewCNode(switch_inputs);
  607. cnode_inputs.emplace_back(switch_node);
  608. return cnode_inputs;
  609. }
  610. std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  611. MS_EXCEPTION_IF_NULL(cnode);
  612. MS_EXCEPTION_IF_NULL(graph);
  613. // create primitive of cnode:call(partial or switch)
  614. std::vector<AnfNodePtr> cnode_inputs = {
  615. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  616. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  617. MS_EXCEPTION_IF_NULL(attr_input);
  618. auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  619. if (cnode_input == nullptr) {
  620. MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
  621. << ", but input[0] has not been created.";
  622. }
  623. // if the node is partial, insert the inputs of partial to the call
  624. if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
  625. auto partial_node = attr_input->cast<CNodePtr>();
  626. MS_EXCEPTION_IF_NULL(partial_node);
  627. auto partial_inputs = partial_node->inputs();
  628. std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
  629. std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
  630. MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
  631. return graph->GetBackendAnfByFrontAnf(node);
  632. });
  633. return cnode_inputs;
  634. } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
  635. return CreateCallSwitchInputs(cnode, graph);
  636. }
  637. MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
  638. }
  639. std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
  640. MS_EXCEPTION_IF_NULL(cnode);
  641. MS_EXCEPTION_IF_NULL(graph);
  642. std::vector<AnfNodePtr> cnode_inputs;
  643. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  644. MS_EXCEPTION_IF_NULL(attr_input);
  645. if (AnfAlgo::IsGraphKernel(cnode)) {
  646. auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
  647. MS_EXCEPTION_IF_NULL(fg);
  648. auto new_fg = BasicClone(fg);
  649. cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  650. } else {
  651. // create primitive of cnode:call
  652. cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  653. // create a ValueNode<KernelGraph> as input of cnode:call
  654. if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
  655. cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
  656. } else {
  657. auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
  658. if (new_value_node != nullptr) {
  659. cnode_inputs.emplace_back(new_value_node);
  660. }
  661. }
  662. }
  663. return cnode_inputs;
  664. }
  665. void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs) {
  666. MS_EXCEPTION_IF_NULL(cnode);
  667. MS_EXCEPTION_IF_NULL(graph);
  668. if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
  669. cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
  670. for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
  671. auto node_input = cnode->input(index);
  672. auto switch_input = CreateSwitchInput(node_input, graph);
  673. cnode_inputs->emplace_back(switch_input);
  674. }
  675. } else {
  676. for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
  677. auto anf = cnode->input(input_idx);
  678. MS_EXCEPTION_IF_NULL(anf);
  679. // anf has been created before
  680. if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
  681. cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
  682. continue;
  683. } else if (IsValueNode<None>(anf)) {
  684. continue;
  685. }
  686. MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  687. }
  688. }
  689. }
// Clones a front-end cnode into `graph`, dispatching on its attribute input:
//  - ValueNode<FuncGraph>: the cnode is a graph or a call -> CreateValueNode;
//  - CNode: the cnode is a call of partial/switch/switch_layer
//    -> CreateSwitchOrPartialNode;
//  - otherwise: an ordinary primitive cnode (the primitive is copied).
// Data inputs are filled in afterwards by CreateCNodeInputs.
CNodePtr SessionBasic::CreateNewCNode(CNodePtr cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (IsValueNode<FuncGraph>(attr_input)) {
    // cnode is a graph or a call
    cnode_inputs = CreateValueNode(cnode, graph);
  } else if (attr_input->isa<CNode>()) {
    // cnode ia a call (partial/switch/switch_layer)
    // 1. take the args of call to the partial node, as the real_args to call switch's or switch_layer's child graph
    // 2. the call in frontend is map to the partial/switch/switch_layer in backend and haven't been created
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
  } else {
    // get primitive of old node
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  // handle inputs of cnode except primitive
  CreateCNodeInputs(cnode, graph, &cnode_inputs);
  // Preserve the front-end node's debug info on the clone.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  // if the cnode is call switch, remove call: the inner switch cnode itself
  // replaces the wrapping call node as the result.
  if (new_cnode->inputs().size() > 1) {
    auto first_input = new_cnode->input(kFirstDataInputIndex);
    if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
        AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
      new_cnode = first_input->cast<CNodePtr>();
    }
  }
  return new_cnode;
}
  726. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  727. MS_EXCEPTION_IF_NULL(anf);
  728. MS_EXCEPTION_IF_NULL(graph);
  729. auto value_node = anf->cast<ValueNodePtr>();
  730. MS_EXCEPTION_IF_NULL(value_node);
  731. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  732. MS_EXCEPTION_IF_NULL(sub_func_graph);
  733. if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
  734. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  735. }
  736. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  737. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  738. new_value_node->set_abstract(value_node->abstract());
  739. // create new kernel_info of new value_node
  740. auto kernel_info = std::make_shared<device::KernelInfo>();
  741. kernel_info->SetFeatureMapFlag(false);
  742. new_value_node->set_kernel_info(kernel_info);
  743. // create kernel_build_info for new value node
  744. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  745. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  746. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  747. graph->FrontBackendlMapAdd(anf, new_value_node);
  748. return new_value_node;
  749. }
  750. ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  751. MS_EXCEPTION_IF_NULL(anf);
  752. MS_EXCEPTION_IF_NULL(graph);
  753. if (!anf->isa<Parameter>()) {
  754. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  755. }
  756. auto param_value = GetParamDefaultValue(anf);
  757. ParameterPtr new_parameter = nullptr;
  758. if (python_paras == nullptr) {
  759. python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
  760. }
  761. auto iter = python_paras->find(param_value);
  762. if (iter != python_paras->end()) {
  763. new_parameter = iter->second;
  764. } else {
  765. TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
  766. new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  767. if (param_value != nullptr) {
  768. (*python_paras)[param_value] = new_parameter;
  769. }
  770. TraceManager::EndTrace();
  771. }
  772. new_parameter->IncreaseUsedGraphCount();
  773. return new_parameter;
  774. }
// Builds a new kernel graph from a flat, topo-ordered list of front-end cnodes
// `lst`, wires `outputs` together as the graph output, then runs default
// execution-order scheduling and the common backend optimization pipeline.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  // Cache for backend nodes created for cnodes that belong to other graphs.
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  // Attach a manager so node users/edges stay consistent during optimization.
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
  811. void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
  812. MS_EXCEPTION_IF_NULL(node);
  813. MS_EXCEPTION_IF_NULL(graph);
  814. auto cnode = node->cast<CNodePtr>();
  815. MS_EXCEPTION_IF_NULL(cnode);
  816. // create a new cnode object
  817. auto new_cnode = CreateNewCNode(cnode, graph.get());
  818. MS_EXCEPTION_IF_NULL(new_cnode);
  819. new_cnode->set_abstract(cnode->abstract());
  820. new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
  821. new_cnode->set_scope(cnode->scope());
  822. graph->FrontBackendlMapAdd(node, new_cnode);
  823. if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
  824. graph->set_return(new_cnode);
  825. }
  826. }
// Recursively converts `func_graph` (and every child func graph it references)
// into kernel graphs. Each finished graph is appended to `all_out_graph`; the
// converted graph for `func_graph` itself is returned.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                                std::vector<KernelGraphPtr> *all_out_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // Register the mapping before visiting nodes so recursive references
  // (back-edges to this graph) can find it.
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  // Set when a referenced child graph was already converted, i.e. this graph
  // jumps back to an ancestor graph.
  bool is_trace_back = false;
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(node, new_parameter);
      continue;
    } else if (node->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(node)) {
        // if input is a common value node,
        (void)CreateNewValueNode(node, graph.get());
      } else {
        // if input is a ValueNode<FuncGraph>: convert the child graph first
        // (unless it is already being/been converted), then reference it.
        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
          is_trace_back = true;
        } else {
          (void)ConstructKernelGraph(child_graph, all_out_graph);
        }
        (void)CreateValueNodeKernelGraph(node, graph.get());
      }
      continue;
    } else {
      CreateCNodeKernelGraph(node, graph);
    }
  }
  // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
  graph->set_output_null(is_trace_back);
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  all_out_graph->push_back(graph);
  return graph;
}
  876. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  877. MS_EXCEPTION_IF_NULL(graph);
  878. auto graph_inputs = graph->MutableInputs();
  879. MS_EXCEPTION_IF_NULL(graph_inputs);
  880. graph_inputs->clear();
  881. for (auto &parameter : parameters) {
  882. MS_EXCEPTION_IF_NULL(parameter);
  883. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  884. if (backend_parameter == nullptr) {
  885. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  886. auto new_parameter = CreateNewParameter(parameter, graph);
  887. graph_inputs->push_back(new_parameter);
  888. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  889. continue;
  890. }
  891. MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
  892. graph_inputs->push_back(backend_parameter);
  893. }
  894. }
  895. namespace {
// Decides whether `tensor` must be copied host->device before the run.
// NOTE: has a side effect — when the tensor is bound to a different device
// address than the parameter, its data is first synced back to host
// (data_sync(false)) so the subsequent host->device copy uses fresh data.
bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    // Pynative-infer mode: sync unless the tensor already holds this exact address.
    return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
  }
  if (tensor->NeedSyncHostToDevice()) {
    return true;
  }
  auto tensor_address = tensor->device_address();
  if (tensor_address != device_address) {
    // Pull current device data to host before it gets overwritten on device.
    tensor->data_sync(false);
    return true;
  }
  return false;
}
  913. } // namespace
  914. // run graph steps
// Copies host input tensors into the device memory of the graph's parameter
// nodes, skipping tensors that TensorNeedSync reports as already in sync.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  // Default control-tensor count; 3 mirrors the constant subtracted in the
  // size check below — presumably the number of loop-control tensors when the
  // graph has none of its own. TODO(review): confirm and name this constant.
  size_t input_ctrl_size = 3;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  // Flatten graph inputs (a tuple parameter expands to several output nodes).
  std::vector<AnfNodePtr> input_nodes;
  for (const auto &input_node : kernel_graph->inputs()) {
    auto params = AnfAlgo::GetAllOutput(input_node);
    std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
  }
  if ((inputs.size() + input_ctrl_size) - 3 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already own a device address can be loaded.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0) && TensorNeedSync(input_node, tensor)) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
      // In pynative mode (or for weight parameters) bind the tensor to the
      // device address so subsequent runs can skip the copy.
      if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode ||
          AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
        tensor->set_device_address(device_address);
      }
      MS_EXCEPTION_IF_NULL(device_address);
      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                            tensor->data_c())) {
        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
      }
    }
    tensor->set_sync_status(kNoNeedSync);
  }
}
  955. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  956. const std::vector<tensor::TensorPtr> &input_tensors) const {
  957. MS_EXCEPTION_IF_NULL(kernel_graph);
  958. MS_EXCEPTION_IF_NULL(outputs);
  959. std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
  960. auto anf_outputs = kernel_graph->outputs();
  961. for (auto &item : anf_outputs) {
  962. MS_EXCEPTION_IF_NULL(item);
  963. MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
  964. outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, &tensor_to_node));
  965. }
  966. for (auto &item : tensor_to_node) {
  967. auto &tensor = item.first;
  968. auto &node = item.second.first;
  969. auto &output_index = item.second.second;
  970. auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  971. tensor->set_device_address(address);
  972. tensor->SetNeedWait(false);
  973. }
  974. }
  975. std::vector<tensor::TensorPtr> SessionBasic::GetNeedLockInputTensors(const GraphId &graph_id,
  976. const std::vector<tensor::TensorPtr> &inputs) {
  977. auto graph = GetGraph(graph_id);
  978. MS_EXCEPTION_IF_NULL(graph);
  979. bool has_optimizer = false;
  980. for (const auto &cnode : graph->execution_order()) {
  981. MS_EXCEPTION_IF_NULL(cnode);
  982. if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(cnode)) != kOptOperatorSet.end()) {
  983. has_optimizer = true;
  984. break;
  985. }
  986. }
  987. if (!has_optimizer) {
  988. return {};
  989. }
  990. std::vector<tensor::TensorPtr> result;
  991. for (auto &tensor : inputs) {
  992. if (!tensor->NeedWait()) {
  993. result.emplace_back(tensor);
  994. }
  995. }
  996. return result;
  997. }
  998. void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
  999. VectorRef *outputs,
  1000. std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
  1001. auto kernel_graph = GetGraph(graph_id);
  1002. MS_EXCEPTION_IF_NULL(kernel_graph);
  1003. MS_EXCEPTION_IF_NULL(outputs);
  1004. MS_EXCEPTION_IF_NULL(tensor_to_node);
  1005. auto anf_outputs = kernel_graph->outputs();
  1006. for (auto &item : anf_outputs) {
  1007. MS_EXCEPTION_IF_NULL(item);
  1008. MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
  1009. outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node));
  1010. }
  1011. }
// Register the callback that Summary() invokes with the collected
// name -> host-tensor map after a run. The callback must be non-null.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
  1016. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
  1017. void SessionBasic::RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs) {
  1018. auto node_list = TopoSort(func_graph->get_return());
  1019. size_t tensor_index = 0;
  1020. for (const auto &node : node_list) {
  1021. MS_EXCEPTION_IF_NULL(node);
  1022. if (node->isa<CNode>()) {
  1023. AbstractBasePtrList input_abstracts;
  1024. for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
  1025. auto input_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), index);
  1026. MS_EXCEPTION_IF_NULL(input_node);
  1027. auto abstract = input_node->abstract();
  1028. MS_EXCEPTION_IF_NULL(abstract);
  1029. input_abstracts.emplace_back(abstract);
  1030. }
  1031. auto prim = AnfAlgo::GetCNodePrimitive(node);
  1032. if (prim->isa<PrimitiveC>()) {
  1033. auto prim_c = prim->cast<std::shared_ptr<PrimitiveC>>();
  1034. MS_EXCEPTION_IF_NULL(prim_c);
  1035. auto abstract = prim_c->Infer(input_abstracts);
  1036. node->set_abstract(abstract);
  1037. } else {
  1038. node->set_abstract(
  1039. std::make_shared<tensor::Tensor>(kNumberTypeFloat32, std::vector<int>{32, 64, 218, 218})->ToAbstract());
  1040. }
  1041. } else if (node->isa<Parameter>()) {
  1042. if (tensor_index > inputs.size()) {
  1043. MS_EXCEPTION(IndexError) << "Index " << tensor_index << "is out of " << inputs.size() << "tensor's size";
  1044. }
  1045. node->set_abstract(inputs[tensor_index++]->ToAbstract());
  1046. } else {
  1047. auto value_node = node->cast<ValueNodePtr>();
  1048. MS_EXCEPTION_IF_NULL(value_node);
  1049. auto value = value_node->value();
  1050. MS_EXCEPTION_IF_NULL(value);
  1051. value_node->set_abstract(value->ToAbstract());
  1052. }
  1053. }
  1054. }
  1055. void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
  1056. MS_LOG(DEBUG) << "Update summary Start";
  1057. MS_EXCEPTION_IF_NULL(graph);
  1058. if (!graph->summary_node_exist()) {
  1059. return;
  1060. }
  1061. auto summary = graph->summary_nodes();
  1062. auto apply_list = TopoSort(graph->get_return());
  1063. for (auto &n : apply_list) {
  1064. MS_EXCEPTION_IF_NULL(n);
  1065. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  1066. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  1067. auto cnode = n->cast<CNodePtr>();
  1068. MS_EXCEPTION_IF_NULL(cnode);
  1069. if (cnode->inputs().size() <= kSummaryGetItem) {
  1070. MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
  1071. }
  1072. auto node = cnode->input(kSummaryGetItem);
  1073. MS_EXCEPTION_IF_NULL(node);
  1074. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
  1075. MS_EXCEPTION_IF_NULL(item_with_index.first);
  1076. if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
  1077. MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
  1078. }
  1079. summary[n->fullname_with_scope()] = item_with_index;
  1080. }
  1081. }
  1082. graph->set_summary_nodes(summary);
  1083. MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
  1084. }
  1085. void SessionBasic::Summary(KernelGraph *graph) {
  1086. if (summary_callback_ == nullptr) {
  1087. return;
  1088. }
  1089. MS_EXCEPTION_IF_NULL(graph);
  1090. bool exist_summary = graph->summary_node_exist();
  1091. if (!exist_summary) {
  1092. return;
  1093. }
  1094. SetSummaryNodes(graph);
  1095. auto summary_outputs = graph->summary_nodes();
  1096. std::map<std::string, tensor::TensorPtr> params_list;
  1097. // fetch outputs apply kernel in session & run callback functions
  1098. for (auto &output_item : summary_outputs) {
  1099. auto node = output_item.second.first;
  1100. size_t index = IntToSize(output_item.second.second);
  1101. auto address = AnfAlgo::GetOutputAddr(node, index);
  1102. auto shape = AnfAlgo::GetOutputInferShape(node, index);
  1103. TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
  1104. std::vector<int> temp_shape;
  1105. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  1106. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  1107. MS_EXCEPTION_IF_NULL(address);
  1108. if (!address->GetPtr()) {
  1109. continue;
  1110. }
  1111. if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
  1112. tensor->data_type(), tensor->data_c())) {
  1113. MS_LOG(ERROR) << "Failed to sync output from device to host.";
  1114. }
  1115. tensor->set_sync_status(kNoNeedSync);
  1116. params_list[output_item.first] = tensor;
  1117. }
  1118. // call callback function here
  1119. summary_callback_(0, params_list);
  1120. }
  1121. namespace {
  1122. bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
  1123. if (node == nullptr) {
  1124. return false;
  1125. }
  1126. auto cnode = node->cast<CNodePtr>();
  1127. if (cnode == nullptr) {
  1128. return false;
  1129. }
  1130. auto prim = cnode->input(kAnfPrimitiveIndex);
  1131. if (prim == nullptr || !IsValueNode<Primitive>(prim)) {
  1132. return false;
  1133. }
  1134. return true;
  1135. }
  1136. void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
  1137. const FuncGraphManagerPtr &front_func_graph_manager,
  1138. const std::shared_ptr<KernelGraph> &backend_graph) {
  1139. // When init parameter from cnode of other graphs, the cnode will not be real kernel except for tuple_getitem.
  1140. if (!AnfAlgo::IsRealKernel(front_node) && !AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
  1141. return;
  1142. }
  1143. auto node_users = front_func_graph_manager->node_users();
  1144. auto users = node_users[front_node];
  1145. auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
  1146. auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
  1147. auto front_real_kernel = front_real_kernel_pair.first;
  1148. std::string kernel_target = GetCNodeTarget(front_real_kernel);
  1149. bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
  1150. bool unique_target = true;
  1151. if (internal_output && opt::IsNopNode(front_real_kernel)) {
  1152. auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
  1153. auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
  1154. if (pre_node_target != kernel_target) {
  1155. unique_target = false;
  1156. }
  1157. }
  1158. if (internal_output) {
  1159. for (auto user : users) {
  1160. if (!CNodeFirstInputIsPrimitive(user.first)) {
  1161. internal_output = false;
  1162. break;
  1163. }
  1164. if (!AnfAlgo::IsRealKernel(user.first)) {
  1165. internal_output = false;
  1166. break;
  1167. }
  1168. if (kernel_target != GetCNodeTarget(user.first)) {
  1169. unique_target = false;
  1170. }
  1171. }
  1172. }
  1173. if (internal_output) {
  1174. MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << " To "
  1175. << backend_real_kernel_pair.first->DebugString() << ", unique_target: " << unique_target;
  1176. backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second,
  1177. unique_target);
  1178. }
  1179. }
  1180. } // namespace
  1181. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  1182. MS_EXCEPTION_IF_NULL(graph);
  1183. std::vector<AnfNodePtr> output_args;
  1184. for (const auto &output : outputs) {
  1185. MS_EXCEPTION_IF_NULL(output);
  1186. MS_LOG(INFO) << "Output:" << output->DebugString();
  1187. }
  1188. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  1189. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  1190. if (backend_anf != nullptr) {
  1191. auto context_ptr = MsContext::GetInstance();
  1192. MS_EXCEPTION_IF_NULL(context_ptr);
  1193. if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
  1194. return backend_anf;
  1195. }
  1196. MS_EXCEPTION_IF_NULL(out);
  1197. auto out_func_graph = out->func_graph();
  1198. MS_EXCEPTION_IF_NULL(out_func_graph);
  1199. auto out_func_graph_manager = out_func_graph->manager();
  1200. if (out_func_graph_manager == nullptr) {
  1201. return backend_anf;
  1202. }
  1203. HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
  1204. return backend_anf;
  1205. }
  1206. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  1207. };
  1208. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  1209. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  1210. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  1211. return graph->NewCNode(output_args);
  1212. }
  1213. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  1214. MS_LOG(INFO) << "Start!";
  1215. std::vector<AnfNodePtr> make_tuple_inputs;
  1216. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  1217. MS_EXCEPTION_IF_NULL(graph);
  1218. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  1219. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  1220. auto idx = NewValueNode(SizeToInt(output_index));
  1221. MS_EXCEPTION_IF_NULL(idx);
  1222. auto imm = std::make_shared<Int32Imm>(output_index);
  1223. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  1224. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  1225. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  1226. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  1227. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  1228. make_tuple_inputs.push_back(getitem);
  1229. }
  1230. } else {
  1231. make_tuple_inputs.push_back(cnode);
  1232. }
  1233. // create output
  1234. auto g_output = graph->NewCNode(make_tuple_inputs);
  1235. graph->set_output(g_output);
  1236. MS_LOG(INFO) << "Finish!";
  1237. }
  1238. std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
  1239. const std::vector<tensor::TensorPtr> &input_tensors,
  1240. const std::vector<int> &tensors_mask) {
  1241. auto graph = std::make_shared<KernelGraph>();
  1242. std::vector<AnfNodePtr> inputs;
  1243. // set input[0]
  1244. PrimitivePtr op_prim = op_run_info.primitive;
  1245. MS_EXCEPTION_IF_NULL(op_prim);
  1246. inputs.push_back(std::make_shared<ValueNode>(op_prim));
  1247. // set input parameter
  1248. MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  1249. if (input_tensors.size() != tensors_mask.size()) {
  1250. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
  1251. << tensors_mask.size();
  1252. }
  1253. for (size_t i = 0; i < input_tensors.size(); ++i) {
  1254. if (tensors_mask[i] == kValueNodeTensorMask) {
  1255. auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
  1256. inputs.push_back(value_node);
  1257. continue;
  1258. }
  1259. auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
  1260. inputs.push_back(parameter);
  1261. auto mutable_inputs = graph->MutableInputs();
  1262. MS_EXCEPTION_IF_NULL(mutable_inputs);
  1263. mutable_inputs->push_back(parameter);
  1264. }
  1265. // set execution order
  1266. auto cnode = graph->NewCNode(inputs);
  1267. MS_EXCEPTION_IF_NULL(cnode);
  1268. // set abstract,which include inferred shapes and types
  1269. cnode->set_abstract(op_run_info.abstract);
  1270. // set execution order
  1271. std::vector<CNodePtr> exe_order = {cnode};
  1272. graph->set_execution_order(exe_order);
  1273. // set output
  1274. CreateOutputNode(cnode, graph);
  1275. return graph;
  1276. }
  1277. KernelGraphPtr SessionBasic::NewKernelGraph() {
  1278. auto graph = std::make_shared<KernelGraph>();
  1279. graph->set_graph_id(graph_sum_);
  1280. graphs_[graph_sum_++] = graph;
  1281. return graph;
  1282. }
  1283. AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) {
  1284. MS_EXCEPTION_IF_NULL(push_node);
  1285. for (auto &node : node_list) {
  1286. if (node != nullptr && node->isa<CNode>()) {
  1287. for (auto input : node->cast<CNodePtr>()->inputs()) {
  1288. if (push_node == AnfAlgo::VisitKernel(input, 0).first) {
  1289. if (AnfAlgo::GetCNodeName(node) != kPullOpName) {
  1290. MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
  1291. }
  1292. return node;
  1293. }
  1294. }
  1295. }
  1296. }
  1297. return nullptr;
  1298. }
// Compile the front-end node list (with its output nodes) via the executor;
// returns the id assigned to the compiled kernel graph.
GraphId SessionBasic::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), lst, outputs);
}
// Compile a whole func graph via the executor; returns the new graph id.
GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), func_graph);
}
// Delegate building (kernel selection/compilation) of graph `graph_id` to the executor.
void SessionBasic::BuildGraph(GraphId graph_id) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->BuildGraph(shared_from_this(), graph_id);
}
// Delegate building of the single op identified by `graph_info` (with its
// input tensors and value/parameter mask) to the executor.
void SessionBasic::BuildOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
                           const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->BuildOp(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
}
// Delegate single-op execution to the executor; results land in *outputs.
void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
                         const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
}
// Delegate execution of the compiled graph to the executor; results land in *outputs.
void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
}
// Delegate asynchronous execution of the compiled graph to the executor.
void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                 VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
}
  1330. bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
  1331. return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; });
  1332. }
  1333. bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {
  1334. MS_EXCEPTION_IF_NULL(anf_node_ptr);
  1335. auto base_shape = anf_node_ptr->Shape();
  1336. if (base_shape == nullptr) {
  1337. MS_LOG(INFO) << "Invalid bash shape ptr, node:" << anf_node_ptr->fullname_with_scope();
  1338. return false;
  1339. }
  1340. if (base_shape->isa<abstract::Shape>()) {
  1341. if (IsDynamicShape(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
  1342. return true;
  1343. }
  1344. } else if (base_shape->isa<abstract::TupleShape>()) {
  1345. auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
  1346. MS_EXCEPTION_IF_NULL(tuple_shape);
  1347. for (size_t i = 0; i < tuple_shape->size(); ++i) {
  1348. auto b_shp = (*tuple_shape)[i];
  1349. if (!b_shp->isa<abstract::Shape>()) {
  1350. continue;
  1351. }
  1352. if (IsDynamicShape(NOT_NULL(b_shp->cast<abstract::ShapePtr>()))) {
  1353. return true;
  1354. }
  1355. }
  1356. }
  1357. return false;
  1358. }
// Return true when any real input of the node has a dynamic (non-positive)
// dimension in its inferred shape.
bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {
  MS_EXCEPTION_IF_NULL(anf_node_ptr);
  auto input_num = AnfAlgo::GetInputTensorNum(anf_node_ptr);
  for (size_t i = 0; i < input_num; ++i) {
    // Resolve the real producing node and output index behind input i.
    auto input_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i);
    auto input = input_with_index.first;
    auto index = input_with_index.second;
    MS_EXCEPTION_IF_NULL(input);
    auto base_shape = input->Shape();
    if (base_shape == nullptr) {
      MS_LOG(INFO) << "Invalid shape ptr, node:" << input->fullname_with_scope();
      continue;
    }
    if (base_shape->isa<abstract::Shape>()) {
      if (IsDynamicShape(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
        return true;
      }
    } else if (base_shape->isa<abstract::TupleShape>()) {
      // Producer has multiple outputs; inspect only the consumed element.
      auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
      MS_EXCEPTION_IF_NULL(tuple_shape);
      if (index >= tuple_shape->size()) {
        MS_LOG(INFO) << "Node:" << anf_node_ptr->fullname_with_scope() << "Invalid index:" << index
                     << " and tuple_shape size:" << tuple_shape->size();
        continue;
      }
      auto b_shp = (*tuple_shape)[index];
      if (!b_shp->isa<abstract::Shape>()) {
        continue;
      }
      if (IsDynamicShape(NOT_NULL(b_shp->cast<abstract::ShapePtr>()))) {
        return true;
      }
    }
  }
  return false;
}
  1395. void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph) {
  1396. for (const auto &cnode : root_graph->execution_order()) {
  1397. auto output_dynamic = IsNodeOutputDynamicShape(NOT_NULL(cnode));
  1398. auto input_dynamic = IsNodeInputDynamicShape(NOT_NULL(cnode));
  1399. if (output_dynamic || input_dynamic) {
  1400. AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
  1401. MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
  1402. }
  1403. if (output_dynamic) {
  1404. AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
  1405. MS_LOG(INFO) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
  1406. }
  1407. if (input_dynamic) {
  1408. AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cnode);
  1409. MS_LOG(INFO) << "Set Input Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
  1410. }
  1411. }
  1412. }
  1413. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Assign parameter-server keys to EmbeddingLookup and Push/Pull kernels so
// the worker can address their parameters on the server side. No-op unless
// this process runs in the worker role.
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
  if (!ps::Util::IsRoleOfWorker()) {
    MS_LOG(INFO) << "Not parameter server mode.";
    return;
  }
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      // Assign key for forward kernel EmbeddingLookup.
      // The key will be assigned to the embedding table and the Push kernel as well.
      if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
        size_t embedding_table_idx = 0;
        auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
        size_t key = ps::worker.SetParamKey(embedding_table->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
      } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
        // Every Push kernel must pair with a Pull kernel; both share one key.
        auto pull_node = FindPullNode(node, node_list);
        if (!pull_node) {
          MS_LOG(EXCEPTION) << "Assigning parameter key failed: can't find Pull node of the Push node.";
        }
        // Second input of Pull node is the trainable parameter.
        size_t parameter_index = 1;
        auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
        size_t key = ps::worker.SetParamKey(parameter_node->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);
        // Bind the key to the optimizer declared on the Push kernel.
        std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
        ps::worker.SetKeyOptimId(key, optimizer_name);
      }
    }
  }
}
// Push the initial values of the graph's parameters (those that already have
// a device address) to the parameter server via ps::worker. No-op unless this
// process runs in the worker role.
void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
                                       const std::vector<tensor::TensorPtr> &inputs_const) {
  if (!ps::Util::IsRoleOfWorker()) {
    return;
  }
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    // LoadCtrlInputTensor appends the control tensors to `inputs` and returns
    // how many there are.
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  // NOTE(review): the "- 1" appears to offset the default input_ctrl_size of 1
  // used when no control tensors exist — confirm against LoadCtrlInputTensor.
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already own a device address are initialized.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      ps::worker.InitPSParamAndOptim(pk_node->fullname_with_scope(), tensor);
    }
  }
}
  1476. #endif
  1477. } // namespace session
  1478. } // namespace mindspore