You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 39 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value_py.h"
  24. #include "operator/ops.h"
  25. #include "common/trans.h"
  26. #include "utils/context/ms_context.h"
  27. #include "utils/config_manager.h"
  28. #include "session/anf_runtime_algorithm.h"
  29. #include "kernel/oplib/oplib.h"
  30. #include "pre_activate/common/common_backend_optimization.h"
  31. #include "pre_activate/pass/const_input_to_attr_registry.h"
  32. #include "pre_activate/common/helper.h"
  33. #include "common/utils.h"
  34. #include "ir/dtype.h"
  35. namespace mindspore {
  36. namespace session {
  37. static std::shared_ptr<std::map<tensor::TensorPtr, ParameterPtr>> python_paras_;
  38. void ClearPythonParasMap() { python_paras_ = nullptr; }
  39. namespace {
  40. const int kSummaryGetItem = 2;
  41. tensor::TensorPtr GetParamDefaultInputTensor(const AnfNodePtr &node) {
  42. if (node == nullptr) {
  43. return nullptr;
  44. }
  45. auto parameter = node->cast<ParameterPtr>();
  46. if (parameter == nullptr || !parameter->has_default()) {
  47. return nullptr;
  48. }
  49. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  50. auto py_param = param_value->value();
  51. if (!py::hasattr(py_param, "default_input")) {
  52. return nullptr;
  53. }
  54. auto py_p_input = py_param.attr("default_input");
  55. if (!py::hasattr(py_p_input, PYTHON_TENSOR_FLAG)) {
  56. return nullptr;
  57. }
  58. return py_p_input.cast<std::shared_ptr<tensor::Tensor>>();
  59. }
// Collect every Summary-family node (Scalar/Tensor/Image/Histogram) of `graph` into
// `summary`, keyed by the summary node's full scope name; the value is the real kernel
// (and its output index) that produces the summarized data.
void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(summary);
  summary->clear();
  auto apply_list = TopoSort(graph->get_return());
  for (auto &n : apply_list) {
    MS_EXCEPTION_IF_NULL(n);
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      auto cnode = n->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      // Need at least kSummaryGetItem + 1 inputs: input 0 is the primitive and
      // input kSummaryGetItem is the summarized value itself.
      if (cnode->inputs().size() <= kSummaryGetItem) {
        MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
      }
      auto node = cnode->input(kSummaryGetItem);
      MS_EXCEPTION_IF_NULL(node);
      // Resolve through wrapper nodes (e.g. TupleGetItem) to the producing kernel.
      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
      if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
        MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
      }
      (*summary)[n->fullname_with_scope()] = item_with_index;
    }
  }
  MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
}
  86. bool ExistSummaryNode(const KernelGraph *graph) {
  87. auto ret = graph->get_return();
  88. MS_EXCEPTION_IF_NULL(ret);
  89. auto all_nodes = DeepLinkedGraphSearch(ret);
  90. for (auto &n : all_nodes) {
  91. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  92. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  93. return true;
  94. }
  95. }
  96. return false;
  97. }
  98. BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
  99. const std::vector<tensor::TensorPtr> &input_tensors) {
  100. MS_EXCEPTION_IF_NULL(node);
  101. MS_LOG(INFO) << "create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  102. // if node is a value node, no need sync addr from device to host
  103. if (!AnfAlgo::OutputAddrExist(node, output_index)) {
  104. if (node->isa<ValueNode>()) {
  105. auto value_node = node->cast<ValueNodePtr>();
  106. MS_EXCEPTION_IF_NULL(value_node);
  107. return value_node->value();
  108. }
  109. if (node->isa<Parameter>()) {
  110. for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
  111. if (input_idx > input_tensors.size()) {
  112. MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
  113. }
  114. if (graph.inputs()[input_idx] == node) {
  115. return input_tensors[input_idx];
  116. }
  117. }
  118. MS_LOG(EXCEPTION) << "parameter : " << node->DebugString() << "has no output addr";
  119. }
  120. }
  121. // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  122. auto address = AnfAlgo::GetOutputAddr(node, output_index);
  123. MS_EXCEPTION_IF_NULL(address);
  124. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  125. TypeId type_id = kNumberTypeFloat32;
  126. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  127. std::vector<int> temp_shape;
  128. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  129. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  130. // if in paynative mode,data only copyed to host when user want to print data
  131. auto ms_context = MsContext::GetInstance();
  132. MS_EXCEPTION_IF_NULL(ms_context);
  133. if (ms_context->execution_mode() == kPynativeMode) {
  134. tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
  135. tensor->set_dirty(false);
  136. } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
  137. LongToSize(tensor->data().nbytes()), tensor->data_type(),
  138. tensor->data_c(true))) {
  139. MS_LOG(INFO) << "output sync device to host error!!!";
  140. tensor->set_dirty(false);
  141. }
  142. return tensor;
  143. }
  144. BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  145. const std::vector<tensor::TensorPtr> &input_tensors) {
  146. MS_EXCEPTION_IF_NULL(anf);
  147. MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]";
  148. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  149. MS_EXCEPTION_IF_NULL(item_with_index.first);
  150. // special handle for maketuple
  151. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  152. auto cnode = item_with_index.first->cast<CNodePtr>();
  153. MS_EXCEPTION_IF_NULL(cnode);
  154. VectorRef ret;
  155. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  156. auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
  157. ret.push_back(out);
  158. }
  159. return ret;
  160. }
  161. // if is graph return nothing ,the function should return a null anylist
  162. size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  163. if (size == 0) {
  164. return VectorRef();
  165. }
  166. return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
  167. }
  168. BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  169. const std::vector<tensor::TensorPtr> &input_tensors) {
  170. MS_EXCEPTION_IF_NULL(anf);
  171. if (!AnfAlgo::IsRealKernel(anf)) {
  172. MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] should be a executable kernel";
  173. }
  174. if (anf->isa<ValueNode>()) {
  175. return CreateOneTensor(anf, 0, graph, input_tensors);
  176. }
  177. VectorRef ret;
  178. if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
  179. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
  180. auto out = CreateOneTensor(anf, i, graph, input_tensors);
  181. ret.emplace_back(out);
  182. }
  183. }
  184. return ret;
  185. }
  186. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  187. auto value_node = anf->cast<ValueNodePtr>();
  188. MS_EXCEPTION_IF_NULL(value_node);
  189. auto value = value_node->value();
  190. MS_EXCEPTION_IF_NULL(value);
  191. if (value->isa<None>()) {
  192. return nullptr;
  193. }
  194. auto new_value_node = graph->NewValueNode(value_node);
  195. graph->FrontBackendlMapAdd(anf, new_value_node);
  196. graph->AddValueNodeToGraph(new_value_node);
  197. return new_value_node;
  198. }
// Create backend Parameter nodes standing in for `node`, a front-end node whose value
// comes from outside this graph. A tuple-output node yields one parameter per element.
// Every new parameter is appended to the graph's input list with `valid_input` as its
// valid-input flag; the created parameters are returned in order.
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> parameters;
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  if (!AnfAlgo::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  }
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  // Helper: build one parameter carrying `abstract` and register it as a graph input.
  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
    auto parameter = graph->NewParameter();
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(abstract);
    // NOTE(review): a second NewParameter clone is made from the first — presumably to
    // attach graph-specific state; confirm against KernelGraph::NewParameter.
    auto new_parameter = graph->NewParameter(parameter);
    parameters.push_back(new_parameter);
    valid_inputs->push_back(valid_input);
    graph_inputs->push_back(new_parameter);
  };
  for (const auto &out_node : pre_graph_out) {
    MS_EXCEPTION_IF_NULL(out_node);
    auto abstract = out_node->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    // create multiple parameters if is a tuple output real kernel
    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_abstract);
      MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
        create_parameter((*tuple_abstract)[output_idx]);
      }
      continue;
    }
    // create single parameter if is a abstract real kernel
    create_parameter(out_node->abstract());
  }
  return parameters;
}
// Append the graph's control-input tensor(s) to `inputs` and return how many control
// tensors the graph declares (0 when it has none). The first control tensor is the
// loop counter, which is reset to zero here.
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  MS_LOG(INFO) << "Load kInputCtrlTensors";
  auto inputs_params = graph->input_ctrl_tensors();
  if (inputs_params == nullptr) {
    // Graph declares no control inputs.
    return 0;
  }
  if (inputs_params->empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  }
  auto tensor = (*inputs_params)[0];
  MS_EXCEPTION_IF_NULL(tensor);
  // set loop_count to zero
  auto *val = static_cast<int32_t *>(tensor->data_c(true));
  MS_EXCEPTION_IF_NULL(val);
  *val = 0;
  // Mark host data modified so the zeroed counter is synced to device.
  tensor->set_dirty(true);
  MS_EXCEPTION_IF_NULL(inputs);
  inputs->push_back(tensor);
  // NOTE(review): only the first control tensor is pushed, but the declared count is
  // returned — confirm callers account for this asymmetry.
  return inputs_params->size();
}
  260. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  261. MS_EXCEPTION_IF_NULL(graph);
  262. MS_EXCEPTION_IF_NULL(input_tensor);
  263. auto value_node = std::make_shared<ValueNode>(input_tensor);
  264. // construct abstract of value node
  265. auto type_of_tensor = input_tensor->Dtype();
  266. auto shape_of_tensor = input_tensor->shape();
  267. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  268. value_node->set_abstract(abstract);
  269. // add value node to graph
  270. auto input_value_node = graph->NewValueNode(value_node);
  271. graph->AddValueNodeToGraph(input_value_node);
  272. return input_value_node;
  273. }
// Build a backend Parameter for a single-op (run_op) graph from `input_tensor`.
// `tensor_mask` marks weight tensors, which get an empty python default attached so
// AnfAlgo::IsParameterWeight() recognizes them.
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
                                     int tensor_mask) {
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (tensor_mask == kParameterWeightTensorMask) {
    // Attach an (empty) python default value: this is what flags the parameter as a
    // weight for the IsParameterWeight() check below.
    py::object obj;
    auto param_value_new = std::make_shared<ParamValuePy>(obj);
    param->set_default_param(param_value_new);
  }
  // set the kernel info of parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  if (input_tensor->device_address().get() == nullptr) {
    // Host-only tensor: default format; weights get kTypeUnknown as device type —
    // presumably so a later kernel-selection pass decides it (verify).
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    // Tensor already resides on device: mirror its existing format and device type.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // construct abstract of parameter
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  param->set_abstract(abstract);
  return param;
}
  302. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  303. MS_LOG(INFO) << "graph outputs:";
  304. const size_t max_deep = 10;
  305. if (recurse_level > max_deep) {
  306. MS_LOG(INFO) << "recurse too deep";
  307. return;
  308. }
  309. std::string tab_str;
  310. for (size_t i = 0; i < recurse_level; i++) {
  311. tab_str = tab_str.append(" ");
  312. }
  313. if (any.is<AnyList>()) {
  314. (void)tab_str.append("{");
  315. MS_LOG(INFO) << tab_str;
  316. auto any_list = any.cast<AnyList>();
  317. for (auto &it : any_list) {
  318. DumpGraphOutput(it, recurse_level + 1);
  319. }
  320. (void)tab_str.append("}");
  321. MS_LOG(INFO) << tab_str;
  322. }
  323. (void)tab_str.append(any.ToString());
  324. MS_LOG(INFO) << tab_str;
  325. }
  326. } // namespace
// Count of kernel graphs created so far — presumably incremented by NewKernelGraph to
// assign unique graph ids; confirm against SessionBasic::NewKernelGraph.
GraphId SessionBasic::graph_sum_ = 0;
// Create (or reuse) a backend Parameter for the front-end parameter `anf` and register
// it as an input of `graph`. Backend parameters are cached in python_paras_ keyed by
// the python default tensor, so a python Parameter shared between graphs maps to one
// backend node.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
  }
  // May be nullptr when the parameter carries no python default tensor.
  auto m_tensor = GetParamDefaultInputTensor(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<tensor::TensorPtr, ParameterPtr>>();
  }
  // Reuse the cached backend parameter only when `anf` already belongs to some graph;
  // otherwise create a fresh one and (when a default tensor exists) cache it.
  if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) != kInvalidGraphId) {
    new_parameter = (*python_paras_)[m_tensor];
  } else {
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    if (m_tensor != nullptr) {
      (*python_paras_)[m_tensor] = new_parameter;
    }
  }
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
  356. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  357. MS_EXCEPTION_IF_NULL(anf);
  358. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  359. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  360. if (parameters.empty()) {
  361. MS_LOG(EXCEPTION) << "No parameter exist!!";
  362. }
  363. if (parameters.size() == 1) {
  364. return parameters[0];
  365. }
  366. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  367. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  368. auto make_tuple = graph->NewCNode(make_tuple_input);
  369. MS_EXCEPTION_IF_NULL(make_tuple);
  370. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  371. return make_tuple;
  372. }
// Clone front-end cnode `cnode` into `graph`, translating each input to its backend
// counterpart. Inputs that come from another graph become parameters; when that
// happens *from_other_graph is set and the mapping is recorded in other_graph_cnode
// so later uses reuse the same parameter.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // get primitive of old node
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  // push attr to inputs[0] of new cnode; the primitive is deep-copied so attribute
  // edits on the backend node do not leak back to the front-end graph.
  std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
  // if has multiple depends,only select first depend as parameter
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->inputs()[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // Already turned into a parameter earlier in this translation.
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // if input is a value node,
      auto new_value_node = CreateNewValueNode(anf, graph);
      // nullptr means the value was None and is simply dropped from the input list.
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      // A parameter not yet owned by any graph is mapped front->backend here; one that
      // belongs to another graph is tracked in other_graph_cnode instead.
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (anf->isa<CNode>()) {
      *from_other_graph = true;
      // the input node is a cnode from other graph
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Build the backend cnode under a copied debug trace so diagnostics point at the
  // original front-end node.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Clone front-end cnode `cnode` into `graph` for the whole-func-graph translation path.
// Handles three input[0] shapes: a ValueNode<FuncGraph> (direct call -> backend "call"
// with a ValueNode<KernelGraph> operand), a CNode (call through switch), or a plain
// Primitive (ordinary op, deep-copied).
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (IsValueNode<FuncGraph>(attr_input)) {
    // create primitive of cnode:call
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // create a ValueNode<KernelGraph> as input of cnode:call
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // create primitive of cnode:call(switch)
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // The callee cnode must already have a backend counterpart, and it must be a
    // Switch — the only cnode form accepted as a call target here.
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
      if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
        MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
      }
      cnode_inputs.emplace_back(cnode_input);
    } else {
      MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
                        << ", but input[0] has not been created.";
    }
  } else {
    // get primitive of old node
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  // Translate the remaining (data) inputs.
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->inputs()[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (anf->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(anf)) {
        // if input is a common value node,
        auto new_value_node = CreateNewValueNode(anf, graph);
        // nullptr means the value was None and is dropped from the input list.
        if (new_value_node != nullptr) {
          cnode_inputs.emplace_back(new_value_node);
        }
      } else {
        // if input is a ValueNode<FuncGraph>
        auto new_value_node = CreateValueNodeKernelGraph(anf, graph);
        if (new_value_node != nullptr) {
          cnode_inputs.emplace_back(new_value_node);
        }
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameter(anf, graph);
      cnode_inputs.push_back(new_parameter);
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Build the backend cnode under a copied debug trace so diagnostics point at the
  // original front-end node.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Turn the front-end ValueNode<FuncGraph> `anf` into a backend ValueNode holding the
// already-translated KernelGraph. The sub func graph must have been translated first
// (see front_backend_graph_map_), otherwise this raises.
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  auto value_node = anf->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  MS_EXCEPTION_IF_NULL(sub_func_graph);
  if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
    MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  }
  auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  new_value_node->set_abstract(value_node->abstract());
  // create new kernel_info of new value_node
  auto kernel_info = std::make_shared<device::KernelInfo>();
  // A graph handle is not feature-map data.
  kernel_info->SetFeatureMapFlag(false);
  new_value_node->set_kernel_info(kernel_info);
  // create kernel_build_info for new value node
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  // Attribute the node to `graph` and record the front->backend mapping.
  AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  graph->FrontBackendlMapAdd(anf, new_value_node);
  return new_value_node;
}
  523. ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  524. MS_EXCEPTION_IF_NULL(anf);
  525. if (!anf->isa<Parameter>()) {
  526. MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
  527. }
  528. auto graph_inputs = graph->MutableInputs();
  529. MS_EXCEPTION_IF_NULL(graph_inputs);
  530. TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
  531. auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  532. TraceManager::EndTrace();
  533. graph_inputs->push_back(new_parameter);
  534. graph->FrontBackendlMapAdd(anf, new_parameter);
  535. return new_parameter;
  536. }
// Build a KernelGraph from a flat, topologically sorted list of front-end cnodes `lst`
// with the requested `outputs`. Inputs produced outside the list become parameters;
// the finished graph is attached to the manager, ordered, and optimized.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    bool from_other_graph = false;
    // only first depend from other graph can create
    bool valid_input = true;
    // After the first cross-graph Depend, subsequent Depend nodes are marked invalid
    // inputs so only that first one contributes a real parameter.
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = context_->manager();
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  // Fix execution order, then run the common backend optimization passes.
  graph->SetExecOrderByDefault();
  opt::BackendCommonOptimization(graph);
  return graph;
}
// Translate a whole front-end FuncGraph into a KernelGraph, recursing into called and
// partially-applied subgraphs. Already-translated graphs are returned from the cache
// (front_backend_graph_map_), which also breaks recursion on self-reference.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) {
    MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph.";
    return front_backend_graph_map_[func_graph];
  }
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  // Cache before translating the body so recursive references find this graph.
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      // Parameters/value nodes are handled lazily while translating their users.
      MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode";
      continue;
    } else {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      // recurse control ops: call, partial
      auto attr_input = cnode->input(kAnfPrimitiveIndex);
      MS_EXCEPTION_IF_NULL(attr_input);
      if (IsValueNode<FuncGraph>(attr_input)) {
        // recurse call subgraph
        auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input);
        ConstructKernelGraph(sub_func_graph);
      } else if (IsValueNode<Primitive>(attr_input)) {
        auto prim = GetCNodePrimitive(node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->name() == kPartialOpName) {
          // recurse partial subgraph
          auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex);
          MS_EXCEPTION_IF_NULL(func_graph_node);
          auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
          ConstructKernelGraph(sub_func_graph);
        } else if (prim->name() == kReturnOpName) {
          // The return node itself is not cloned; its operands become graph outputs.
          std::vector<AnfNodePtr> outputs;
          auto inputs = cnode->inputs();
          if (inputs.size() < 2) {
            MS_LOG(EXCEPTION) << "CNode[return] must have two inputs at least, actual inputs size is " << inputs.size();
          }
          (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(outputs));
          // add a make_tuple before return as graph output
          graph->set_output(ConstructOutput(outputs, graph));
          continue;
        }
      }
      // create a new cnode object
      auto new_cnode = CreateNewCNode(cnode, graph.get());
      MS_EXCEPTION_IF_NULL(new_cnode);
      new_cnode->set_abstract(cnode->abstract());
      new_cnode->set_scope(cnode->scope());
      graph->FrontBackendlMapAdd(node, new_cnode);
    }
  }
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = context_->manager();
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  return graph;
}
  643. // run graph steps
// Copy host input tensors to their bound device addresses before a graph run.
// Control tensors (when present) are appended to the input list first; a
// tensor is synced only when it is dirty or bound to a different address.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    // LoadCtrlInputTensor appends the control tensors to `inputs` and
    // reports how many there are.
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  // With the default input_ctrl_size of 1 this requires
  // inputs.size() == input_nodes.size().
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already have a device address can receive data.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        // PyNative infer mode: sync whenever the tensor has no address or a
        // different one from the parameter's.
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty()) {
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // Pull the latest data back to host before rebinding the tensor to
          // the parameter's device address.
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        tensor->set_device_address(device_address);
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c(false))) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    // Host data is now on device (or wasn't needed); clear the dirty flag.
    tensor->set_dirty(false);
  }
}
  693. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  694. const std::vector<tensor::TensorPtr> &input_tensors) const {
  695. MS_EXCEPTION_IF_NULL(kernel_graph);
  696. MS_EXCEPTION_IF_NULL(outputs);
  697. auto anf_outputs = kernel_graph->outputs();
  698. for (auto &item : anf_outputs) {
  699. MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
  700. MS_EXCEPTION_IF_NULL(item);
  701. if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
  702. outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
  703. continue;
  704. }
  705. outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  706. }
  707. }
// Register the callback that Summary() invokes with the fetched summary
// tensors; a null callback is rejected.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
  712. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
  713. MS_EXCEPTION_IF_NULL(node_list);
  714. std::vector<CNodePtr> all_opt_list;
  715. std::vector<CNodePtr> non_opt_list;
  716. for (const auto &node : *node_list) {
  717. MS_EXCEPTION_IF_NULL(node);
  718. if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
  719. all_opt_list.emplace_back(node);
  720. } else {
  721. non_opt_list.emplace_back(node);
  722. }
  723. }
  724. node_list->clear();
  725. (void)std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
  726. (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
  727. }
// Fetch all summary-node outputs from device to host tensors and hand them
// to the registered summary callback. No-op when no callback is registered
// or the graph contains no summary nodes.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = ExistSummaryNode(graph);
  if (!exist_summary) {
    return;
  }
  // Map: summary tag -> (producing node, output index).
  std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
  GetSummaryNodes(graph, &summary_outputs);
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    // Skip outputs whose device memory was never allocated.
    if (!address->GetPtr()) {
      continue;
    }
    // Best-effort copy: a failed sync is logged but does not abort the
    // remaining summaries.
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c(true))) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  764. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  765. MS_EXCEPTION_IF_NULL(graph);
  766. std::vector<AnfNodePtr> output_args;
  767. for (const auto &output : outputs) {
  768. MS_LOG(INFO) << "output:" << output->DebugString();
  769. }
  770. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  771. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  772. if (backend_anf != nullptr) {
  773. return backend_anf;
  774. }
  775. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  776. };
  777. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  778. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  779. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  780. return graph->NewCNode(output_args);
  781. }
  782. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  783. MS_LOG(INFO) << "Start!";
  784. std::vector<AnfNodePtr> make_tuple_inputs;
  785. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  786. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  787. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  788. auto idx = NewValueNode(SizeToInt(output_index));
  789. MS_EXCEPTION_IF_NULL(idx);
  790. auto imm = std::make_shared<Int32Imm>(output_index);
  791. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  792. MS_EXCEPTION_IF_NULL(graph);
  793. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  794. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  795. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  796. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  797. make_tuple_inputs.push_back(getitem);
  798. }
  799. } else {
  800. make_tuple_inputs.push_back(cnode);
  801. }
  802. // create output
  803. auto g_output = graph->NewCNode(make_tuple_inputs);
  804. graph->set_output(g_output);
  805. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  806. MS_EXCEPTION_IF_NULL(context_);
  807. FuncGraphManagerPtr manager = context_->manager();
  808. if (manager != nullptr) {
  809. manager->AddFuncGraph(graph);
  810. graph->set_manager(manager);
  811. }
  812. MS_LOG(INFO) << "Finish!";
  813. }
// Build a minimal KernelGraph containing exactly one cnode for a PyNative
// single-op execution: the primitive plus one value node or parameter per
// input tensor, followed by the output wrapper.
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                                  const std::vector<tensor::TensorPtr> &input_tensors,
                                                                  const std::vector<int> &tensors_mask) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  PrimitivePtr op_prim = op_run_info.py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  inputs.push_back(std::make_shared<ValueNode>(op_prim));
  // set input parameter
  MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    // The mask decides whether a tensor becomes a constant (value node) or a
    // graph input (parameter).
    if (tensors_mask[i] == kValueNodeTensorMask) {
      auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
      inputs.push_back(value_node);
      continue;
    }
    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
    inputs.push_back(parameter);
    graph->MutableInputs()->push_back(parameter);
  }
  // create the single-op cnode (was mislabeled "set execution order")
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  // set abstract,which include inferred shapes and types
  cnode->set_abstract(op_run_info.abstract);
  // set execution order
  std::vector<CNodePtr> exe_order = {cnode};
  graph->set_execution_order(exe_order);
  // set output
  CreateOutputNode(cnode, graph);
  return graph;
}
// Recursively convert a BaseRef result tree into Python objects: a VectorRef
// becomes a py::tuple (wrapped in a PyObjectRef), a tensor is passed through
// unchanged, anything else is an error.
BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use pyObjectRef
      if (utils::isa<tensor::TensorPtr>(output)) {
        // Leaf: a tensor goes into the tuple directly.
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        // Nested list: the recursive call produced a py::tuple; unwrap it.
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
  876. KernelGraphPtr SessionBasic::NewKernelGraph() {
  877. auto graph = std::make_shared<KernelGraph>();
  878. graph->set_graph_id(graph_sum_);
  879. graphs_[graph_sum_++] = graph;
  880. return graph;
  881. }
  882. } // namespace session
  883. } // namespace mindspore