You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 52 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/jit/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value.h"
  24. #include "backend/kernel_compiler/common_utils.h"
  25. #include "frontend/operator/ops.h"
  26. #include "common/trans.h"
  27. #include "utils/config_manager.h"
  28. #include "backend/session/anf_runtime_algorithm.h"
  29. #include "backend/kernel_compiler/oplib/oplib.h"
  30. #include "backend/optimizer/common/common_backend_optimization.h"
  31. #include "backend/optimizer/pass/const_input_to_attr_registry.h"
  32. #include "backend/optimizer/common/helper.h"
  33. #include "common/utils.h"
  34. #include "ir/dtype.h"
  35. #include "ir/anf.h"
  36. #include "ir/func_graph_cloner.h"
  37. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  38. #include "frontend/parallel/ps/worker.h"
  39. #include "frontend/parallel/ps/common.h"
  40. #include "frontend/parallel/ps/util.h"
  41. #endif
namespace mindspore {
namespace session {
// Cache mapping a front-end python parameter's default value to its backend
// Parameter node, so the same weight is reused across graphs instead of being
// re-created (see CreateNewParameterFromParameter).
static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
// Drop the cache; called when the python side tears down its parameters.
void ClearPythonParasMap() { python_paras = nullptr; }
namespace {
// Expected input index used when unpacking summary nodes.
const int kSummaryGetItem = 2;
  48. ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
  49. if (node == nullptr) {
  50. return nullptr;
  51. }
  52. auto parameter = node->cast<ParameterPtr>();
  53. if (parameter == nullptr || !parameter->has_default()) {
  54. return nullptr;
  55. }
  56. return parameter->default_param();
  57. }
  58. BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
  59. const std::vector<tensor::TensorPtr> &input_tensors) {
  60. MS_EXCEPTION_IF_NULL(node);
  61. MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  62. // if node is a value node, no need sync addr from device to host
  63. if (!AnfAlgo::OutputAddrExist(node, output_index)) {
  64. if (node->isa<ValueNode>()) {
  65. auto value_node = node->cast<ValueNodePtr>();
  66. MS_EXCEPTION_IF_NULL(value_node);
  67. return value_node->value();
  68. }
  69. if (node->isa<Parameter>()) {
  70. for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
  71. if (input_idx >= input_tensors.size()) {
  72. MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
  73. }
  74. if (graph.inputs()[input_idx] == node) {
  75. return input_tensors[input_idx];
  76. }
  77. }
  78. MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr";
  79. }
  80. }
  81. // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  82. auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  83. MS_EXCEPTION_IF_NULL(address);
  84. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  85. TypeId type_id = kNumberTypeFloat32;
  86. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  87. std::vector<int> temp_shape;
  88. if (graph.IsInternalOutput(node, output_index)) {
  89. temp_shape.emplace_back(1);
  90. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  91. tensor->set_device_address(address);
  92. tensor->set_dirty(false);
  93. return tensor;
  94. }
  95. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  96. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  97. // if in paynative mode,data only copyed to host when user want to print data
  98. auto ms_context = MsContext::GetInstance();
  99. MS_EXCEPTION_IF_NULL(ms_context);
  100. if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
  101. tensor->set_device_address(address);
  102. tensor->set_dirty(false);
  103. } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
  104. LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
  105. MS_LOG(INFO) << "Output sync device to host error!!!";
  106. tensor->set_dirty(false);
  107. }
  108. return tensor;
  109. }
// Recursively build host-side outputs for `anf`: a MakeTuple becomes a
// VectorRef of its elements, a node with zero outputs becomes an empty
// VectorRef, and any other real node is forwarded to CreateOneTensor.
BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                              const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
  // Skip wrapper nodes (e.g. Depend) to reach the real producing kernel.
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // Special handling for MakeTuple: recurse into each element (input 0 is the primitive).
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto out = CreateTensorForOutput(cnode->input(i), graph, input_tensors);
      ret.push_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty anylist.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
}
  135. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  136. MS_EXCEPTION_IF_NULL(anf);
  137. MS_EXCEPTION_IF_NULL(graph);
  138. auto value_node = anf->cast<ValueNodePtr>();
  139. MS_EXCEPTION_IF_NULL(value_node);
  140. auto value = value_node->value();
  141. MS_EXCEPTION_IF_NULL(value);
  142. if (value->isa<None>()) {
  143. return nullptr;
  144. }
  145. auto new_value_node = graph->NewValueNode(value_node);
  146. graph->FrontBackendlMapAdd(anf, new_value_node);
  147. graph->AddValueNodeToGraph(new_value_node);
  148. return new_value_node;
  149. }
  150. size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  151. MS_EXCEPTION_IF_NULL(graph);
  152. MS_LOG(INFO) << "Load kInputCtrlTensors";
  153. auto inputs_params = graph->input_ctrl_tensors();
  154. if (inputs_params == nullptr) {
  155. return 0;
  156. }
  157. if (inputs_params->size() < 2) {
  158. MS_LOG(EXCEPTION) << "Illegal inputs_params size";
  159. }
  160. auto tensor = (*inputs_params)[0];
  161. MS_EXCEPTION_IF_NULL(tensor);
  162. auto *val = static_cast<int32_t *>(tensor->data_c());
  163. MS_EXCEPTION_IF_NULL(val);
  164. *val = 0;
  165. tensor->set_dirty(true);
  166. // set loop_count to zero
  167. MS_EXCEPTION_IF_NULL(inputs);
  168. inputs->push_back(tensor);
  169. auto epoch_tensor = (*inputs_params)[1];
  170. MS_EXCEPTION_IF_NULL(epoch_tensor);
  171. auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
  172. MS_EXCEPTION_IF_NULL(epoch_val);
  173. *epoch_val = graph->current_epoch();
  174. epoch_tensor->set_dirty(true);
  175. inputs->push_back(epoch_tensor);
  176. MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;
  177. graph->set_current_epoch(graph->current_epoch() + 1);
  178. return inputs_params->size();
  179. }
  180. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  181. MS_EXCEPTION_IF_NULL(graph);
  182. MS_EXCEPTION_IF_NULL(input_tensor);
  183. auto value_node = std::make_shared<ValueNode>(input_tensor);
  184. MS_EXCEPTION_IF_NULL(value_node);
  185. // construct abstract of value node
  186. auto type_of_tensor = input_tensor->Dtype();
  187. auto shape_of_tensor = input_tensor->shape();
  188. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  189. value_node->set_abstract(abstract);
  190. // add value node to graph
  191. auto input_value_node = graph->NewValueNode(value_node);
  192. graph->AddValueNodeToGraph(input_value_node);
  193. return input_value_node;
  194. }
  195. ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
  196. int tensor_mask) {
  197. MS_EXCEPTION_IF_NULL(graph);
  198. auto param = graph->NewParameter();
  199. MS_EXCEPTION_IF_NULL(param);
  200. if (tensor_mask == kParameterWeightTensorMask) {
  201. param->set_default_param(input_tensor);
  202. }
  203. // set the kernel info of parameter
  204. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  205. MS_EXCEPTION_IF_NULL(input_tensor);
  206. auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  207. if (device_address == nullptr) {
  208. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  209. TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
  210. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  211. } else {
  212. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
  213. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
  214. }
  215. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  216. // construct abstract of parameter
  217. auto type_of_tensor = input_tensor->Dtype();
  218. auto shape_of_tensor = input_tensor->shape();
  219. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  220. param->set_abstract(abstract);
  221. return param;
  222. }
  223. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  224. MS_LOG(INFO) << "Graph outputs:";
  225. const size_t max_deep = 10;
  226. if (recurse_level > max_deep) {
  227. MS_LOG(INFO) << "Recurse too deep";
  228. return;
  229. }
  230. std::string tab_str;
  231. for (size_t i = 0; i < recurse_level; i++) {
  232. tab_str = tab_str.append(" ");
  233. }
  234. if (any.is<AnyList>()) {
  235. (void)tab_str.append("{");
  236. MS_LOG(INFO) << tab_str;
  237. auto any_list = any.cast<AnyList>();
  238. for (auto &it : any_list) {
  239. DumpGraphOutput(it, recurse_level + 1);
  240. }
  241. (void)tab_str.append("}");
  242. MS_LOG(INFO) << tab_str;
  243. }
  244. (void)tab_str.append(any.ToString());
  245. MS_LOG(INFO) << tab_str;
  246. }
  247. bool ExistSummaryNode(const KernelGraph *graph) {
  248. MS_EXCEPTION_IF_NULL(graph);
  249. auto ret = graph->get_return();
  250. MS_EXCEPTION_IF_NULL(ret);
  251. auto all_nodes = DeepLinkedGraphSearch(ret);
  252. for (auto &n : all_nodes) {
  253. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  254. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  255. return true;
  256. }
  257. }
  258. return false;
  259. }
}  // namespace
// Monotonically increasing id handed out to newly constructed kernel graphs.
GraphId SessionBasic::graph_sum_ = 0;
  262. KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
  263. auto it = graphs_.find(graph_id);
  264. if (it == graphs_.end()) {
  265. MS_LOG(WARNING) << "Can't find graph " << graph_id;
  266. return nullptr;
  267. }
  268. return it->second;
  269. }
// When `out_node` is an internal output of a previously compiled graph, copy
// that output's device address and kernel build info onto `parameter` so the
// new graph reads the former graph's result in place instead of copying it.
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  // A TupleGetItem selects one output of the producing node; default to index 0.
  size_t output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
  }
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node, ref_real_node_index)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    // Nop nodes may legitimately lack an output address; anything else must have one.
    if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    // Give the parameter fresh kernel info mirroring the producer's output.
    auto d_kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    parameter->set_kernel_info(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    // Alias the producer's device memory as this parameter's output 0.
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
    AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
  }
}
  316. std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input,
  317. KernelGraph *graph) {
  318. MS_EXCEPTION_IF_NULL(node);
  319. MS_EXCEPTION_IF_NULL(graph);
  320. std::vector<AnfNodePtr> parameters;
  321. std::vector<AnfNodePtr> pre_graph_out = {node};
  322. // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  323. if (!AnfAlgo::IsRealKernel(node)) {
  324. pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  325. }
  326. auto valid_inputs = graph->MutableValidInputs();
  327. MS_EXCEPTION_IF_NULL(valid_inputs);
  328. auto graph_inputs = graph->MutableInputs();
  329. MS_EXCEPTION_IF_NULL(graph_inputs);
  330. auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
  331. auto parameter = graph->NewParameter();
  332. MS_EXCEPTION_IF_NULL(parameter);
  333. parameter->set_abstract(abstract);
  334. auto new_parameter = graph->NewParameter(parameter);
  335. parameters.push_back(new_parameter);
  336. valid_inputs->push_back(valid_input);
  337. graph_inputs->push_back(new_parameter);
  338. };
  339. for (const auto &out_node : pre_graph_out) {
  340. MS_EXCEPTION_IF_NULL(out_node);
  341. auto abstract = out_node->abstract();
  342. MS_EXCEPTION_IF_NULL(abstract);
  343. // create multiple parameters if is a tuple output real kernel
  344. if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
  345. auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  346. MS_EXCEPTION_IF_NULL(tuple_abstract);
  347. MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
  348. for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
  349. create_parameter((*tuple_abstract)[output_idx]);
  350. }
  351. continue;
  352. }
  353. // create single parameter if is a abstract real kernel
  354. create_parameter(out_node->abstract());
  355. InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
  356. }
  357. return parameters;
  358. }
// Mirror a front-end Parameter into `graph`. Parameters that share the same
// python default value reuse one backend parameter (cached in python_paras) so
// a given weight appears only once across graphs.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_value = GetParamDefaultValue(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // If the parameter's python value already has a backend parameter, reuse it.
  if (python_paras == nullptr) {
    python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
  }
  auto iter = python_paras->find(param_value);
  if (iter != python_paras->end()) {
    new_parameter = iter->second;
  } else {
    // Carry the front-end debug info onto the newly created backend parameter.
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // Only parameters with a default value are cached; param_value is nullptr
    // for parameters without one.
    if (param_value != nullptr) {
      (*python_paras)[param_value] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
  391. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  392. MS_EXCEPTION_IF_NULL(anf);
  393. MS_EXCEPTION_IF_NULL(graph);
  394. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  395. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  396. if (parameters.empty()) {
  397. MS_LOG(EXCEPTION) << "No parameter exist!!";
  398. }
  399. if (parameters.size() == 1) {
  400. return parameters[0];
  401. }
  402. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  403. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  404. auto make_tuple = graph->NewCNode(make_tuple_input);
  405. MS_EXCEPTION_IF_NULL(make_tuple);
  406. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  407. return make_tuple;
  408. }
// Clone front-end cnode `cnode` into `graph`, translating every input: reuse
// already-mapped backend nodes, materialize value nodes and parameters,
// shortcut Depend/ControlDepend, and replace inputs produced by other graphs
// with fresh parameters (setting *from_other_graph). `other_graph_cnode`
// caches those cross-graph replacements so each front node maps once.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // Get primitive of old node.
  std::vector<AnfNodePtr> cnode_inputs;
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  if (prim != nullptr) {
    // Push a copy of the primitive (attrs included) into inputs[0] of the new cnode.
    cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  } else {
    // No primitive: input[0] carries a func graph (e.g. a call); clone it.
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  }
  auto origin_inputs = cnode->inputs();
  bool optimize_depend = false;
  bool optimize_control_depend = false;
  // A Depend whose real input is a value node can be collapsed to that input.
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
      origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
    optimize_depend = true;
  }
  if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) {
    optimize_control_depend = true;
  }
  // If there are multiple depends, only the first depend is selected as parameter.
  for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
    auto anf = origin_inputs[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // The anf has been created before: reuse its backend counterpart.
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // Already replaced once as a cross-graph input: reuse that replacement.
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // The input is a value node: mirror it into the graph. None values yield
      // nullptr and are dropped from the input list.
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        // The parameter belongs to another graph: record it as cross-graph only.
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
      // Collapse the Depend: use the real input in place of the attach node.
      cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
      continue;
    } else if (optimize_control_depend) {
      cnode_inputs.push_back(NewValueNode(MakeValue(input_idx)));
    } else {
      *from_other_graph = true;
      // The input node is a cnode from another graph: stand it in with parameters.
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
    }
  }
  // Keep the front cnode's debug info on the backend clone.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  484. CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) {
  485. MS_EXCEPTION_IF_NULL(node_input);
  486. MS_EXCEPTION_IF_NULL(graph);
  487. // switch input generalizes partial
  488. if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) ||
  489. AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) {
  490. return node_input->cast<CNodePtr>();
  491. }
  492. if (node_input->isa<CNode>()) {
  493. MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call.";
  494. }
  495. std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  496. if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
  497. partial_inputs.emplace_back(node_input);
  498. auto partial_node = graph->NewCNode(partial_inputs);
  499. return partial_node;
  500. }
  501. KernelGraphPtr kernel_graph = NewKernelGraph();
  502. MS_EXCEPTION_IF_NULL(kernel_graph);
  503. kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input));
  504. partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
  505. auto partial_node = graph->NewCNode(partial_inputs);
  506. return partial_node;
  507. }
  508. CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) {
  509. MS_EXCEPTION_IF_NULL(anf_node);
  510. MS_EXCEPTION_IF_NULL(graph);
  511. auto node = anf_node->cast<CNodePtr>();
  512. MS_EXCEPTION_IF_NULL(node);
  513. if (node->inputs().size() < kSwitchInputSize) {
  514. MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize;
  515. }
  516. auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name()));
  517. std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)};
  518. for (size_t index = 2; index < node->inputs().size(); index++) {
  519. auto input = CreateSwitchInput(node->input(index), graph);
  520. switch_inputs.emplace_back(input);
  521. }
  522. auto switch_node = graph->NewCNode(switch_inputs);
  523. return switch_node;
  524. }
// Build the input list of a Call cnode whose input[0] is itself a cnode. For a
// Partial, the partial's data inputs are appended (mapped to backend nodes);
// for a Switch, a normalized switch node is appended. Anything else raises.
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  // Create primitive of cnode: call(partial or switch).
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  if (cnode_input == nullptr) {
    MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
                      << ", but input[0] has not been created.";
  }
  // If the node is partial, insert the inputs of partial to the call.
  if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
    // NOTE(review): the type check runs on the backend node (cnode_input) but
    // the inputs are read from the front-end partial (attr_input) and mapped
    // through GetBackendAnfByFrontAnf; the null-check inside the lambda asserts
    // every such input already has a backend counterpart — confirm upstream
    // guarantees this ordering.
    auto partial_node = attr_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(partial_node);
    auto partial_inputs = partial_node->inputs();
    std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
                   std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
                     MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
                     return graph->GetBackendAnfByFrontAnf(node);
                   });
    return cnode_inputs;
  } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
    auto switch_node = HandleSwitchInputs(cnode_input, graph);
    cnode_inputs.emplace_back(switch_node);
    return cnode_inputs;
  }
  MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
}
// Clone front-end cnode `cnode` into `graph` on the control-flow path: handle
// graph kernels, calls to func graphs (lowered to Call(ValueNode<KernelGraph>)),
// cnode-valued input[0] (partial/switch), and plain primitives. All data inputs
// must already have backend counterparts or be None.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (AnfAlgo::IsGraphKernel(cnode)) {
    // Graph kernel: clone the fused func graph as input[0].
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  } else if (IsValueNode<FuncGraph>(attr_input)) {
    // Create primitive of cnode: call.
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // Create a ValueNode<KernelGraph> as input of cnode: call.
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // input[0] is itself a cnode: must be a partial or a switch.
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
  } else {
    // Get primitive of old node.
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // Push a copy of the primitive (attrs included) into inputs[0] of the new cnode.
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->input(input_idx);
    MS_EXCEPTION_IF_NULL(anf);
    // The anf has been created before: reuse it; None inputs are dropped.
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (IsValueNode<None>(anf)) {
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Keep the front cnode's debug info on the backend clone.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  605. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  606. MS_EXCEPTION_IF_NULL(anf);
  607. MS_EXCEPTION_IF_NULL(graph);
  608. auto value_node = anf->cast<ValueNodePtr>();
  609. MS_EXCEPTION_IF_NULL(value_node);
  610. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  611. MS_EXCEPTION_IF_NULL(sub_func_graph);
  612. if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
  613. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  614. }
  615. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  616. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  617. new_value_node->set_abstract(value_node->abstract());
  618. // create new kernel_info of new value_node
  619. auto kernel_info = std::make_shared<device::KernelInfo>();
  620. kernel_info->SetFeatureMapFlag(false);
  621. new_value_node->set_kernel_info(kernel_info);
  622. // create kernel_build_info for new value node
  623. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  624. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  625. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  626. graph->FrontBackendlMapAdd(anf, new_value_node);
  627. return new_value_node;
  628. }
// Create (or reuse) the backend Parameter for front-end parameter `anf`.
// Parameters that carry the same python default-value object are deduplicated
// through the session-level `python_paras` cache, so one weight maps to a
// single backend node even across graphs.
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  auto param_value = GetParamDefaultValue(anf);
  ParameterPtr new_parameter = nullptr;
  // lazily create the cache on first use
  if (python_paras == nullptr) {
    python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
  }
  auto iter = python_paras->find(param_value);
  if (iter != python_paras->end()) {
    // same default value seen before: reuse the existing backend parameter
    new_parameter = iter->second;
  } else {
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // only parameters with a default value (i.e. weights) are cached;
    // plain data inputs (param_value == nullptr) get a fresh node every time
    if (param_value != nullptr) {
      (*python_paras)[param_value] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  return new_parameter;
}
// Build a backend KernelGraph from a flat, topologically sorted list of
// front-end CNodes `lst`, then wire `outputs` into a trailing make_tuple,
// attach a manager, fix the execution order and run common backend passes.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  // counts Depend nodes whose real input lives in another graph; see below
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    bool from_other_graph = false;
    // only first depend from other graph can create
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
// Clone one front-end CNode into `graph`, copying its abstract/scope/name,
// registering the front/backend mapping, and installing it as the graph's
// return node when it is a Return primitive.
void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // create a new cnode object
  auto new_cnode = CreateNewCNode(cnode, graph.get());
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
  new_cnode->set_scope(cnode->scope());
  graph->FrontBackendlMapAdd(node, new_cnode);
  // a Return node terminates the graph
  if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
    graph->set_return(new_cnode);
  }
}
// Recursively transform a front-end FuncGraph (and every sub func_graph it
// references through ValueNode<FuncGraph>) into KernelGraphs. Every finished
// graph is appended to `all_out_graph`; the graph for `func_graph` itself is
// returned and also recorded in front_backend_graph_map_ up front so that
// back-edges (recursion) are detected instead of looping forever.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                                std::vector<KernelGraphPtr> *all_out_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // register before recursing so a back-edge to this graph is recognized
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  bool is_trace_back = false;
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(node, new_parameter);
      continue;
    } else if (node->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(node)) {
        // if input is a common value node,
        (void)CreateNewValueNode(node, graph.get());
      } else {
        // if input is a ValueNode<FuncGraph>
        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
          // already being (or been) transformed: this is an unconditional jump back
          is_trace_back = true;
        } else {
          (void)ConstructKernelGraph(child_graph, all_out_graph);
        }
        (void)CreateValueNodeKernelGraph(node, graph.get());
      }
      continue;
    } else {
      CreateCNodeKernelGraph(node, graph);
    }
  }
  // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
  graph->set_output_null(is_trace_back);
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  all_out_graph->push_back(graph);
  return graph;
}
  764. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  765. MS_EXCEPTION_IF_NULL(graph);
  766. auto graph_inputs = graph->MutableInputs();
  767. MS_EXCEPTION_IF_NULL(graph_inputs);
  768. graph_inputs->clear();
  769. for (auto &parameter : parameters) {
  770. MS_EXCEPTION_IF_NULL(parameter);
  771. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  772. if (backend_parameter == nullptr) {
  773. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  774. auto new_parameter = CreateNewParameter(parameter, graph);
  775. graph_inputs->push_back(new_parameter);
  776. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  777. continue;
  778. }
  779. MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
  780. graph_inputs->push_back(backend_parameter);
  781. }
  782. }
  783. namespace {
  784. bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor) {
  785. auto ms_context = MsContext::GetInstance();
  786. MS_EXCEPTION_IF_NULL(ms_context);
  787. auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
  788. if (ms_context->enable_pynative_infer()) {
  789. return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
  790. }
  791. if (tensor->is_dirty()) {
  792. return true;
  793. }
  794. if (tensor->device_address() != device_address) {
  795. (void)tensor->data_sync();
  796. return true;
  797. }
  798. return false;
  799. }
  800. } // namespace
  801. // run graph steps
// Copy the host input tensors into the device addresses bound to the graph's
// parameter nodes, skipping tensors that TensorNeedSync says are in sync.
// Clears each tensor's dirty flag afterwards.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  // NOTE(review): defaults to 2 and the arity check below subtracts 2 —
  // presumably accounting for implicit control tensors appended by
  // LoadCtrlInputTensor; confirm against that helper.
  size_t input_ctrl_size = 2;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  // flatten every graph input into its real output parameters
  std::vector<AnfNodePtr> input_nodes;
  for (const auto &input_node : kernel_graph->inputs()) {
    auto params = AnfAlgo::GetAllOutput(input_node);
    std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
  }
  if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    // only parameters that already own a device address are loaded
    if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
      // in pynative mode (or for weights) the tensor adopts the device address
      // so later runs can detect it is already resident
      if (ms_context->execution_mode() == kPynativeMode ||
          AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
        tensor->set_device_address(device_address);
      }
      MS_EXCEPTION_IF_NULL(device_address);
      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                            tensor->data_c())) {
        MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
      }
    }
    tensor->set_dirty(false);
  }
}
// Append one output tensor per graph output to `outputs`, created by
// CreateTensorForOutput from the graph's output anf nodes.
void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                                 const std::vector<tensor::TensorPtr> &input_tensors) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
    outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors));
  }
}
// Register the callback that Summary() invokes with the collected summary tensors.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
// Reorder the execution list in place; delegates entirely to AnfAlgo::ReorderExecList.
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
// Collect every summary op (scalar/tensor/image/histogram) in the graph and
// record, per summary name, the real kernel (node, output index) that produces
// the summarized value. No-op when the graph has no summary nodes.
void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    return;
  }
  auto summary = graph->summary_nodes();
  auto apply_list = TopoSort(graph->get_return());
  for (auto &n : apply_list) {
    MS_EXCEPTION_IF_NULL(n);
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      auto cnode = n->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (cnode->inputs().size() <= kSummaryGetItem) {
        MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
      }
      // input[kSummaryGetItem] is the summarized value; resolve it to the real
      // kernel output behind any tuple_getitem wrappers
      auto node = cnode->input(kSummaryGetItem);
      MS_EXCEPTION_IF_NULL(node);
      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
      MS_EXCEPTION_IF_NULL(item_with_index.first);
      if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
        MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
      }
      summary[n->fullname_with_scope()] = item_with_index;
    }
  }
  graph->set_summary_nodes(summary);
  MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}
  887. void SessionBasic::Summary(KernelGraph *graph) {
  888. if (summary_callback_ == nullptr) {
  889. return;
  890. }
  891. MS_EXCEPTION_IF_NULL(graph);
  892. bool exist_summary = graph->summary_node_exist();
  893. if (!exist_summary) {
  894. return;
  895. }
  896. SetSummaryNodes(graph);
  897. auto summary_outputs = graph->summary_nodes();
  898. std::map<std::string, tensor::TensorPtr> params_list;
  899. // fetch outputs apply kernel in session & run callback functions
  900. for (auto &output_item : summary_outputs) {
  901. auto node = output_item.second.first;
  902. size_t index = IntToSize(output_item.second.second);
  903. auto address = AnfAlgo::GetOutputAddr(node, index);
  904. auto shape = AnfAlgo::GetOutputInferShape(node, index);
  905. TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
  906. std::vector<int> temp_shape;
  907. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  908. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  909. MS_EXCEPTION_IF_NULL(address);
  910. if (!address->GetPtr()) {
  911. continue;
  912. }
  913. if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
  914. tensor->data_type(), tensor->data_c())) {
  915. MS_LOG(ERROR) << "Failed to sync output from device to host.";
  916. }
  917. tensor->set_dirty(false);
  918. params_list[output_item.first] = tensor;
  919. }
  920. // call callback function here
  921. summary_callback_(0, params_list);
  922. }
// Wrap the graph's front-end output nodes into a single make_tuple CNode,
// mapping each front-end output to its backend counterpart. Along the way,
// outputs consumed only by real-kernel CNodes on the same target are
// registered as "internal outputs" on the graph.
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> output_args;
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    MS_LOG(INFO) << "Output:" << output->DebugString();
  }
  // Resolve one front-end output to its backend node; throws if unmapped.
  auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
    auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
    if (backend_anf != nullptr) {
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      // pynative mode never uses internal-output bookkeeping
      if (context_ptr->execution_mode() == kPynativeMode) {
        return backend_anf;
      }
      auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
      auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
      MS_EXCEPTION_IF_NULL(out);
      auto out_func_graph = out->func_graph();
      MS_EXCEPTION_IF_NULL(out_func_graph);
      auto out_func_graph_manager = out_func_graph->manager();
      if (out_func_graph_manager == nullptr) {
        // no manager means no user information: cannot classify as internal output
        return backend_anf;
      }
      auto node_users = out_func_graph_manager->node_users();
      auto users = node_users[out];
      bool internal_output = true;
      std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
      for (auto user : users) {
        // every user must be a real-kernel CNode, carry a primitive at input[0],
        // and run on the same target; otherwise this output is not internal
        auto cnode = user.first->cast<CNodePtr>();
        if (cnode == nullptr) {
          internal_output = false;
          break;
        }
        auto prim = cnode->input(kAnfPrimitiveIndex);
        if (prim == nullptr || !prim->isa<ValueNode>()) {
          internal_output = false;
          break;
        }
        if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
          internal_output = false;
          break;
        }
      }
      if (internal_output) {
        MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
        graph->AddInternalOutput(out, backend_real_kernel.first);
      }
      return backend_anf;
    }
    MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  };
  output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
                       [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  return graph->NewCNode(output_args);
}
  980. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  981. MS_LOG(INFO) << "Start!";
  982. std::vector<AnfNodePtr> make_tuple_inputs;
  983. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  984. MS_EXCEPTION_IF_NULL(graph);
  985. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  986. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  987. auto idx = NewValueNode(SizeToInt(output_index));
  988. MS_EXCEPTION_IF_NULL(idx);
  989. auto imm = std::make_shared<Int32Imm>(output_index);
  990. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  991. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  992. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  993. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  994. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  995. make_tuple_inputs.push_back(getitem);
  996. }
  997. } else {
  998. make_tuple_inputs.push_back(cnode);
  999. }
  1000. // create output
  1001. auto g_output = graph->NewCNode(make_tuple_inputs);
  1002. graph->set_output(g_output);
  1003. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  1004. MS_EXCEPTION_IF_NULL(context_);
  1005. FuncGraphManagerPtr manager = context_->manager();
  1006. if (manager != nullptr) {
  1007. manager->AddFuncGraph(graph);
  1008. graph->set_manager(manager);
  1009. }
  1010. MS_LOG(INFO) << "Finish!";
  1011. }
// Build a one-kernel KernelGraph for pynative single-op execution: input[0]
// is the op primitive, each input tensor becomes either a value node (when its
// mask says so) or a graph parameter, and the single CNode is both the whole
// execution order and (via CreateOutputNode) the graph output.
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                                  const std::vector<tensor::TensorPtr> &input_tensors,
                                                                  const std::vector<int> &tensors_mask) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  PrimitivePtr op_prim = op_run_info.py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  inputs.push_back(std::make_shared<ValueNode>(op_prim));
  // set input parameter
  MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    // masked tensors become constants baked into the graph
    if (tensors_mask[i] == kValueNodeTensorMask) {
      auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
      inputs.push_back(value_node);
      continue;
    }
    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
    inputs.push_back(parameter);
    auto mutable_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(mutable_inputs);
    mutable_inputs->push_back(parameter);
  }
  // build the single kernel CNode
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  // set abstract,which include inferred shapes and types
  cnode->set_abstract(op_run_info.abstract);
  // set execution order
  std::vector<CNodePtr> exe_order = {cnode};
  graph->set_execution_order(exe_order);
  // set output
  CreateOutputNode(cnode, graph);
  return graph;
}
// Recursively convert a BaseRef result tree into python objects: a VectorRef
// becomes a py::tuple (nested VectorRefs become nested tuples), a TensorPtr is
// passed through unchanged; anything else is an error.
BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use pyObjectRef
      if (utils::isa<tensor::TensorPtr>(output)) {
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        // nested VectorRef already converted: unwrap the py::tuple it carries
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
  1076. KernelGraphPtr SessionBasic::NewKernelGraph() {
  1077. auto graph = std::make_shared<KernelGraph>();
  1078. graph->set_graph_id(graph_sum_);
  1079. graphs_[graph_sum_++] = graph;
  1080. return graph;
  1081. }
  1082. AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) {
  1083. MS_EXCEPTION_IF_NULL(push_node);
  1084. for (auto &node : node_list) {
  1085. if (node != nullptr && node->isa<CNode>()) {
  1086. for (auto input : node->cast<CNodePtr>()->inputs()) {
  1087. if (push_node == AnfAlgo::VisitKernel(input, 0).first) {
  1088. if (AnfAlgo::GetCNodeName(node) != kPullOpName) {
  1089. MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
  1090. }
  1091. return node;
  1092. }
  1093. }
  1094. }
  1095. }
  1096. return nullptr;
  1097. }
  1098. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
// Assign parameter-server keys to PS-related kernels in the graph (worker role
// only): EmbeddingLookup nodes get a key derived from their embedding table,
// and each Push node shares a key with its paired Pull node's trainable
// parameter, plus the optimizer id for that key.
void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
  if (!parallel::ps::Util::IsRoleOfWorker()) {
    MS_LOG(INFO) << "Not parameter server mode.";
    return;
  }
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      // Assign key for forward kernel EmbeddingLookup.
      // The key will be assigned to embedding table and Push kernel as well.
      if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
        size_t embedding_table_idx = 0;
        auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
        size_t key = parallel::ps::Worker<float>::GetInstance().SetParamKey(embedding_table->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
      } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
        auto pull_node = FindPullNode(node, node_list);
        if (!pull_node) {
          MS_LOG(EXCEPTION) << "Assigning parameter key failed: can't find Pull node of the Push node.";
        }
        // Second input of Pull node is the trainable parameter.
        size_t parameter_index = 1;
        auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
        size_t key = parallel::ps::Worker<float>::GetInstance().SetParamKey(parameter_node->fullname_with_scope());
        // Push and Pull share the same key so the server pairs them up
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);
        std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
        parallel::ps::Worker<float>::GetInstance().SetKeyOptimId(key, optimizer_name);
      }
    }
  }
}
// Initialize parameter-server-side parameters and optimizers (worker role
// only): for every graph input parameter that owns a device address, push its
// host data to the PS worker. Tensors whose shape is exactly {1} are flagged
// to be initialized on the server instead of from host data.
void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
                                       const std::vector<tensor::TensorPtr> &inputs_const) {
  if (!parallel::ps::Util::IsRoleOfWorker()) {
    return;
  }
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  // NOTE(review): defaults to 1 here but LoadInputData uses 2 — presumably the
  // PS path accounts for one implicit control tensor; confirm against
  // LoadCtrlInputTensor.
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<int> shape_init_in_server = {1};
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      bool init_in_server = false;
      // a {1} shape marks a placeholder whose real data lives on the server
      if (tensor->shape_c() == shape_init_in_server) {
        MS_LOG(INFO) << "The param need to be initialized in server " << pk_node->fullname_with_scope();
        init_in_server = true;
      }
      mindspore::parallel::ps::Worker<float>::GetInstance().InitPSParamAndOptim(
        pk_node->fullname_with_scope(), tensor->data_c(), LongToSize(tensor->data().nbytes()), init_in_server);
    }
  }
  ps_init_ = true;
}
  1169. #endif
  1170. } // namespace session
  1171. } // namespace mindspore