You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 38 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value_py.h"
  24. #include "operator/ops.h"
  25. #include "common/trans.h"
  26. #include "utils/context/ms_context.h"
  27. #include "utils/config_manager.h"
  28. #include "session/anf_runtime_algorithm.h"
  29. #include "kernel/oplib/oplib.h"
  30. #include "pre_activate/common/common_backend_optimization.h"
  31. #include "pre_activate/pass/const_input_to_attr_registry.h"
  32. #include "pre_activate/common/helper.h"
  33. #include "common/utils.h"
  34. #include "ir/dtype.h"
  35. #include "ir/anf.h"
  36. namespace mindspore {
  37. namespace session {
  38. static std::shared_ptr<std::map<PyObject *, ParameterPtr>> python_paras_;
  39. void ClearPythonParasMap() { python_paras_ = nullptr; }
  40. namespace {
// Input index on a Summary cnode from which the summarized value is read
// (inputs()[0] is the primitive; see GetSummaryNodes below).
const int kSummaryGetItem = 2;
  42. PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
  43. if (node == nullptr) {
  44. return nullptr;
  45. }
  46. auto parameter = node->cast<ParameterPtr>();
  47. if (parameter == nullptr || !parameter->has_default()) {
  48. return nullptr;
  49. }
  50. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  51. auto py_param = param_value->value();
  52. return py_param.ptr();
  53. }
  54. void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
  55. MS_LOG(DEBUG) << "Update summary Start";
  56. MS_EXCEPTION_IF_NULL(graph);
  57. MS_EXCEPTION_IF_NULL(summary);
  58. summary->clear();
  59. auto apply_list = TopoSort(graph->get_return());
  60. for (auto &n : apply_list) {
  61. MS_EXCEPTION_IF_NULL(n);
  62. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  63. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  64. auto cnode = n->cast<CNodePtr>();
  65. MS_EXCEPTION_IF_NULL(cnode);
  66. if (cnode->inputs().size() <= kSummaryGetItem) {
  67. MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
  68. }
  69. auto node = cnode->input(kSummaryGetItem);
  70. MS_EXCEPTION_IF_NULL(node);
  71. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
  72. if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
  73. MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
  74. }
  75. (*summary)[n->fullname_with_scope()] = item_with_index;
  76. }
  77. }
  78. MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
  79. }
  80. bool ExistSummaryNode(const KernelGraph *graph) {
  81. auto ret = graph->get_return();
  82. MS_EXCEPTION_IF_NULL(ret);
  83. auto all_nodes = DeepLinkedGraphSearch(ret);
  84. for (auto &n : all_nodes) {
  85. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  86. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  87. return true;
  88. }
  89. }
  90. return false;
  91. }
  92. BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
  93. const std::vector<tensor::TensorPtr> &input_tensors) {
  94. MS_EXCEPTION_IF_NULL(node);
  95. MS_LOG(INFO) << "create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  96. // if node is a value node, no need sync addr from device to host
  97. if (!AnfAlgo::OutputAddrExist(node, output_index)) {
  98. if (node->isa<ValueNode>()) {
  99. auto value_node = node->cast<ValueNodePtr>();
  100. MS_EXCEPTION_IF_NULL(value_node);
  101. return value_node->value();
  102. }
  103. if (node->isa<Parameter>()) {
  104. for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
  105. if (input_idx > input_tensors.size()) {
  106. MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
  107. }
  108. if (graph.inputs()[input_idx] == node) {
  109. return input_tensors[input_idx];
  110. }
  111. }
  112. MS_LOG(EXCEPTION) << "parameter : " << node->DebugString() << "has no output addr";
  113. }
  114. }
  115. // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  116. auto address = AnfAlgo::GetOutputAddr(node, output_index);
  117. MS_EXCEPTION_IF_NULL(address);
  118. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  119. TypeId type_id = kNumberTypeFloat32;
  120. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  121. std::vector<int> temp_shape;
  122. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  123. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  124. // if in paynative mode,data only copyed to host when user want to print data
  125. auto ms_context = MsContext::GetInstance();
  126. MS_EXCEPTION_IF_NULL(ms_context);
  127. if (ms_context->execution_mode() == kPynativeMode) {
  128. tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
  129. tensor->set_dirty(false);
  130. } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
  131. LongToSize(tensor->data().nbytes()), tensor->data_type(),
  132. tensor->data_c(true))) {
  133. MS_LOG(INFO) << "output sync device to host error!!!";
  134. tensor->set_dirty(false);
  135. }
  136. return tensor;
  137. }
  138. BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  139. const std::vector<tensor::TensorPtr> &input_tensors) {
  140. MS_EXCEPTION_IF_NULL(anf);
  141. MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]";
  142. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  143. MS_EXCEPTION_IF_NULL(item_with_index.first);
  144. // special handle for maketuple
  145. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  146. auto cnode = item_with_index.first->cast<CNodePtr>();
  147. MS_EXCEPTION_IF_NULL(cnode);
  148. VectorRef ret;
  149. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  150. auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
  151. ret.push_back(out);
  152. }
  153. return ret;
  154. }
  155. // if is graph return nothing ,the function should return a null anylist
  156. size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  157. if (size == 0) {
  158. return VectorRef();
  159. }
  160. return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
  161. }
  162. BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  163. const std::vector<tensor::TensorPtr> &input_tensors) {
  164. MS_EXCEPTION_IF_NULL(anf);
  165. if (!AnfAlgo::IsRealKernel(anf)) {
  166. MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] should be a executable kernel";
  167. }
  168. if (anf->isa<ValueNode>()) {
  169. return CreateOneTensor(anf, 0, graph, input_tensors);
  170. }
  171. VectorRef ret;
  172. if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
  173. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
  174. auto out = CreateOneTensor(anf, i, graph, input_tensors);
  175. ret.emplace_back(out);
  176. }
  177. }
  178. return ret;
  179. }
  180. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  181. auto value_node = anf->cast<ValueNodePtr>();
  182. MS_EXCEPTION_IF_NULL(value_node);
  183. auto value = value_node->value();
  184. MS_EXCEPTION_IF_NULL(value);
  185. if (value->isa<None>()) {
  186. return nullptr;
  187. }
  188. auto new_value_node = graph->NewValueNode(value_node);
  189. graph->FrontBackendlMapAdd(anf, new_value_node);
  190. graph->AddValueNodeToGraph(new_value_node);
  191. return new_value_node;
  192. }
  193. std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
  194. MS_EXCEPTION_IF_NULL(node);
  195. MS_EXCEPTION_IF_NULL(graph);
  196. std::vector<AnfNodePtr> parameters;
  197. std::vector<AnfNodePtr> pre_graph_out = {node};
  198. // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  199. if (!AnfAlgo::IsRealKernel(node)) {
  200. pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  201. }
  202. auto valid_inputs = graph->MutableValidInputs();
  203. MS_EXCEPTION_IF_NULL(valid_inputs);
  204. auto graph_inputs = graph->MutableInputs();
  205. MS_EXCEPTION_IF_NULL(graph_inputs);
  206. auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
  207. auto parameter = graph->NewParameter();
  208. MS_EXCEPTION_IF_NULL(parameter);
  209. parameter->set_abstract(abstract);
  210. auto new_parameter = graph->NewParameter(parameter);
  211. parameters.push_back(new_parameter);
  212. valid_inputs->push_back(valid_input);
  213. graph_inputs->push_back(new_parameter);
  214. };
  215. for (const auto &out_node : pre_graph_out) {
  216. MS_EXCEPTION_IF_NULL(out_node);
  217. auto abstract = out_node->abstract();
  218. MS_EXCEPTION_IF_NULL(abstract);
  219. // create multiple parameters if is a tuple output real kernel
  220. if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
  221. auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  222. MS_EXCEPTION_IF_NULL(tuple_abstract);
  223. MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
  224. for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
  225. create_parameter((*tuple_abstract)[output_idx]);
  226. }
  227. continue;
  228. }
  229. // create single parameter if is a abstract real kernel
  230. create_parameter(out_node->abstract());
  231. }
  232. return parameters;
  233. }
  234. size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  235. MS_LOG(INFO) << "Load kInputCtrlTensors";
  236. auto inputs_params = graph->input_ctrl_tensors();
  237. if (inputs_params == nullptr) {
  238. return 0;
  239. }
  240. if (inputs_params->empty()) {
  241. MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  242. }
  243. auto tensor = (*inputs_params)[0];
  244. MS_EXCEPTION_IF_NULL(tensor);
  245. auto *val = static_cast<int32_t *>(tensor->data_c(true));
  246. MS_EXCEPTION_IF_NULL(val);
  247. *val = 0;
  248. tensor->set_dirty(true);
  249. // set loop_count to zero
  250. MS_EXCEPTION_IF_NULL(inputs);
  251. inputs->push_back(tensor);
  252. return inputs_params->size();
  253. }
  254. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  255. MS_EXCEPTION_IF_NULL(graph);
  256. MS_EXCEPTION_IF_NULL(input_tensor);
  257. auto value_node = std::make_shared<ValueNode>(input_tensor);
  258. // construct abstract of value node
  259. auto type_of_tensor = input_tensor->Dtype();
  260. auto shape_of_tensor = input_tensor->shape();
  261. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  262. value_node->set_abstract(abstract);
  263. // add value node to graph
  264. auto input_value_node = graph->NewValueNode(value_node);
  265. graph->AddValueNodeToGraph(input_value_node);
  266. return input_value_node;
  267. }
  268. ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
  269. int tensor_mask) {
  270. auto param = graph->NewParameter();
  271. MS_EXCEPTION_IF_NULL(param);
  272. if (tensor_mask == kParameterWeightTensorMask) {
  273. py::object obj;
  274. auto param_value_new = std::make_shared<ParamValuePy>(obj);
  275. param->set_default_param(param_value_new);
  276. }
  277. // set the kernel info of parameter
  278. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  279. MS_EXCEPTION_IF_NULL(input_tensor);
  280. if (input_tensor->device_address().get() == nullptr) {
  281. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  282. TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
  283. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  284. } else {
  285. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
  286. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  287. }
  288. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  289. // construct abstract of parameter
  290. auto type_of_tensor = input_tensor->Dtype();
  291. auto shape_of_tensor = input_tensor->shape();
  292. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  293. param->set_abstract(abstract);
  294. return param;
  295. }
  296. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  297. MS_LOG(INFO) << "graph outputs:";
  298. const size_t max_deep = 10;
  299. if (recurse_level > max_deep) {
  300. MS_LOG(INFO) << "recurse too deep";
  301. return;
  302. }
  303. std::string tab_str;
  304. for (size_t i = 0; i < recurse_level; i++) {
  305. tab_str = tab_str.append(" ");
  306. }
  307. if (any.is<AnyList>()) {
  308. (void)tab_str.append("{");
  309. MS_LOG(INFO) << tab_str;
  310. auto any_list = any.cast<AnyList>();
  311. for (auto &it : any_list) {
  312. DumpGraphOutput(it, recurse_level + 1);
  313. }
  314. (void)tab_str.append("}");
  315. MS_LOG(INFO) << tab_str;
  316. }
  317. (void)tab_str.append(any.ToString());
  318. MS_LOG(INFO) << tab_str;
  319. }
  320. } // namespace
// Count of kernel graphs created so far; static, so shared across all
// sessions in the process (presumably the source of new graph ids — the
// NewKernelGraph implementation is not visible in this chunk).
GraphId SessionBasic::graph_sum_ = 0;
// Create (or reuse) a backend Parameter for a front-end Parameter node.
// If the parameter's python default object already produced a backend
// parameter (tracked in python_paras_) and the node is not yet bound to any
// graph, that existing backend parameter is reused. The resulting parameter
// is appended to the graph's inputs; `valid_input` is recorded alongside it.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
  }
  // m_tensor: raw PyObject* of the parameter's default value (may be nullptr)
  auto m_tensor = GetParamDefaultInputTensor(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
  }
  if (python_paras_->find(m_tensor) != python_paras_->end() && GetGraphIdByNode(anf) == kInvalidGraphId) {
    new_parameter = (*python_paras_)[m_tensor];
  } else {
    // copy the front-end debug info so traces point at the original node
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // only cache parameters that actually have a python default object
    if (m_tensor != nullptr) {
      (*python_paras_)[m_tensor] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
  352. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  353. MS_EXCEPTION_IF_NULL(anf);
  354. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  355. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  356. if (parameters.empty()) {
  357. MS_LOG(EXCEPTION) << "No parameter exist!!";
  358. }
  359. if (parameters.size() == 1) {
  360. return parameters[0];
  361. }
  362. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  363. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  364. auto make_tuple = graph->NewCNode(make_tuple_input);
  365. MS_EXCEPTION_IF_NULL(make_tuple);
  366. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  367. return make_tuple;
  368. }
// Clone a front-end cnode into `graph` (multi-graph construction path).
// Each input is resolved in priority order:
//   1. an already-mapped backend anf of this graph,
//   2. a node previously imported from another graph (other_graph_cnode),
//   3. a ValueNode that is not a FuncGraph (copied; None is skipped),
//   4. a Parameter (new backend parameter, mapped per-graph or cross-graph),
//   5. a CNode from another graph (imported as parameter(s); this sets
//      *from_other_graph to true).
// Any other input kind raises. The cloned cnode keeps the front node's debug info.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // get primitive of old node
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  // push attr to inputs[0] of new cnode (the primitive is deep-copied)
  std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
  // if has multiple depends,only select first depend as parameter
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->inputs()[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // reuse a node already imported from another graph
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // if input is a value node (a None value yields nullptr and is skipped)
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        // parameter belongs to this graph: record the front->backend mapping
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        // owned by another graph: remember it only in the cross-graph cache
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (anf->isa<CNode>()) {
      *from_other_graph = true;
      // the input node is a cnode from other graph
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // keep the front-end debug info on the cloned cnode
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Clone a front-end cnode into `graph` (whole-func-graph construction path).
// inputs[0] is handled specially:
//   - ValueNode<FuncGraph>: becomes a `call` primitive whose first argument is
//     the ValueNode<KernelGraph> of the (already constructed) sub graph,
//   - CNode: becomes a `call` whose first argument must already map to a
//     backend switch node, else this raises,
//   - otherwise: the cnode's primitive is deep-copied into inputs[0].
// Remaining inputs become mapped backend anfs, value nodes, or new parameters.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (IsValueNode<FuncGraph>(attr_input)) {
    // create primitive of cnode:call
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // create a ValueNode<KernelGraph> as input of cnode:call
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // create primitive of cnode:call(switch)
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
      if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
        MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
      }
      cnode_inputs.emplace_back(cnode_input);
    } else {
      MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
                        << ", but input[0] has not been created.";
    }
  } else {
    // get primitive of old node
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->inputs()[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (anf->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(anf)) {
        // if input is a common value node (a None value is skipped)
        auto new_value_node = CreateNewValueNode(anf, graph);
        if (new_value_node != nullptr) {
          cnode_inputs.emplace_back(new_value_node);
        }
      } else {
        // if input is a ValueNode<FuncGraph>
        auto new_value_node = CreateValueNodeKernelGraph(anf, graph);
        if (new_value_node != nullptr) {
          cnode_inputs.emplace_back(new_value_node);
        }
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameter(anf, graph);
      cnode_inputs.push_back(new_parameter);
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // keep the front-end debug info on the cloned cnode
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  496. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  497. MS_EXCEPTION_IF_NULL(anf);
  498. auto value_node = anf->cast<ValueNodePtr>();
  499. MS_EXCEPTION_IF_NULL(value_node);
  500. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  501. MS_EXCEPTION_IF_NULL(sub_func_graph);
  502. if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
  503. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  504. }
  505. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  506. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  507. new_value_node->set_abstract(value_node->abstract());
  508. // create new kernel_info of new value_node
  509. auto kernel_info = std::make_shared<device::KernelInfo>();
  510. kernel_info->SetFeatureMapFlag(false);
  511. new_value_node->set_kernel_info(kernel_info);
  512. // create kernel_build_info for new value node
  513. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  514. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  515. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  516. graph->FrontBackendlMapAdd(anf, new_value_node);
  517. return new_value_node;
  518. }
  519. ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  520. MS_EXCEPTION_IF_NULL(anf);
  521. if (!anf->isa<Parameter>()) {
  522. MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
  523. }
  524. auto graph_inputs = graph->MutableInputs();
  525. MS_EXCEPTION_IF_NULL(graph_inputs);
  526. TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
  527. auto new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  528. TraceManager::EndTrace();
  529. graph_inputs->push_back(new_parameter);
  530. graph->FrontBackendlMapAdd(anf, new_parameter);
  531. return new_parameter;
  532. }
// Build a KernelGraph from a flat, topologically ordered list of front-end
// cnodes plus the desired output nodes. Depend nodes fed from other graphs
// are counted so only the first such depend keeps valid_input == true. The
// graph output is set via ConstructOutput, a fresh manager is attached, the
// execution order is fixed, and common backend optimizations are run.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    bool from_other_graph = false;
    // only first depend from other graph can create a valid input
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  opt::BackendCommonOptimization(graph);
  return graph;
}
// Recursively build a KernelGraph from a front-end FuncGraph. Returns the
// cached graph if this func graph was transformed before. call / partial
// cnodes trigger construction of their sub func graphs first, then every
// cnode is cloned via CreateNewCNode(cnode, graph); the Return cnode becomes
// the graph return. Uses the session context's manager (no new manager).
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  if (front_backend_graph_map_.find(func_graph) != front_backend_graph_map_.end()) {
    MS_LOG(INFO) << "FuncGraph: " << func_graph->ToString() << " has been transformed to KernelGraph.";
    return front_backend_graph_map_[func_graph];
  }
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  // cache the mapping before walking nodes: the recursive calls below return
  // from this cache when they encounter an already-seen func graph
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      // parameters and value nodes are handled while cloning the cnodes that use them
      MS_LOG(DEBUG) << "Node " << node->DebugString() << " is not CNode";
      continue;
    } else {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      // recurse control ops: call, partial
      auto attr_input = cnode->input(kAnfPrimitiveIndex);
      MS_EXCEPTION_IF_NULL(attr_input);
      if (IsValueNode<FuncGraph>(attr_input)) {
        // recurse call subgraph
        auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(attr_input);
        ConstructKernelGraph(sub_func_graph);
      } else if (IsValueNode<Primitive>(attr_input)) {
        auto prim = GetCNodePrimitive(node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->name() == kPartialOpName) {
          // recurse partial subgraph
          auto func_graph_node = cnode->input(kAnfPartialFuncGraphIndex);
          MS_EXCEPTION_IF_NULL(func_graph_node);
          auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(func_graph_node);
          ConstructKernelGraph(sub_func_graph);
        }
      }
      // create a new cnode object
      auto new_cnode = CreateNewCNode(cnode, graph.get());
      MS_EXCEPTION_IF_NULL(new_cnode);
      new_cnode->set_abstract(cnode->abstract());
      new_cnode->set_scope(cnode->scope());
      graph->FrontBackendlMapAdd(node, new_cnode);
      if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
        graph->set_return(new_cnode);
      }
    }
  }
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = context_->manager();
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  return graph;
}
// run graph steps
// Copy host-side input tensors into the device memory bound to the graph's
// parameter nodes before a graph run.
// - kernel_graph: graph whose parameters receive the data (must be non-null).
// - inputs_const: host tensors, one per graph input (control tensors may be
//   appended internally when the graph owns input_ctrl_tensors).
// Throws via MS_LOG(EXCEPTION) on size mismatch or failed host->device copy.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  // Work on a local copy: control tensors may be appended below.
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    // Appends the graph's control tensors to `inputs` and returns their count.
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  // After the optional ctrl-tensor append, the tensor count must match the
  // graph's parameter count (the "+ input_ctrl_size - 1" form accounts for
  // input_ctrl_size defaulting to 1 when there are no control tensors).
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already own a device address can receive data.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        // PyNative-infer mode: upload whenever the tensor has no device
        // address yet or is bound to a different address than the parameter.
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty()) {
          // Host data changed since the last upload.
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // Tensor lives at another device address: pull its current data to
          // host first, then push it to the parameter's address below.
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        // Rebind the tensor to the parameter's device address and upload the
        // host bytes (shape padded to the runtime layout).
        tensor->set_device_address(device_address);
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c(false))) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    tensor->set_dirty(false);  // device copy is now considered up to date
  }
}
  682. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  683. const std::vector<tensor::TensorPtr> &input_tensors) const {
  684. MS_EXCEPTION_IF_NULL(kernel_graph);
  685. MS_EXCEPTION_IF_NULL(outputs);
  686. auto anf_outputs = kernel_graph->outputs();
  687. for (auto &item : anf_outputs) {
  688. MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
  689. MS_EXCEPTION_IF_NULL(item);
  690. if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
  691. outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
  692. continue;
  693. }
  694. outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  695. }
  696. }
// Register the callback that Summary() invokes with the collected summary
// tensors; the callback must be non-null. Replaces any previously
// registered callback.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
  701. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
  702. MS_EXCEPTION_IF_NULL(node_list);
  703. std::vector<CNodePtr> all_opt_list;
  704. std::vector<CNodePtr> non_opt_list;
  705. for (const auto &node : *node_list) {
  706. MS_EXCEPTION_IF_NULL(node);
  707. if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
  708. all_opt_list.emplace_back(node);
  709. } else {
  710. non_opt_list.emplace_back(node);
  711. }
  712. }
  713. node_list->clear();
  714. (void)std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
  715. (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
  716. }
// Fetch the value of every summary node in the graph from device memory and
// hand the resulting name->tensor map to the registered summary callback.
// No-op when no callback is registered or the graph has no summary nodes.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;  // no callback registered - nothing to report
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = ExistSummaryNode(graph);
  if (!exist_summary) {
    return;
  }
  // tag -> (node, output index) for every summary input in the graph
  std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
  GetSummaryNodes(graph, &summary_outputs);
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    // host tensor that will receive the device-side summary value
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    if (!address->GetPtr()) {
      continue;  // output has no device memory yet - skip this summary entry
    }
    // Best-effort download: a failed sync is logged but the tensor is still
    // added to params_list (matching the original behavior).
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c(true))) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  753. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  754. MS_EXCEPTION_IF_NULL(graph);
  755. std::vector<AnfNodePtr> output_args;
  756. for (const auto &output : outputs) {
  757. MS_LOG(INFO) << "output:" << output->DebugString();
  758. }
  759. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  760. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  761. if (backend_anf != nullptr) {
  762. return backend_anf;
  763. }
  764. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  765. };
  766. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  767. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  768. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  769. return graph->NewCNode(output_args);
  770. }
  771. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  772. MS_LOG(INFO) << "Start!";
  773. std::vector<AnfNodePtr> make_tuple_inputs;
  774. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  775. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  776. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  777. auto idx = NewValueNode(SizeToInt(output_index));
  778. MS_EXCEPTION_IF_NULL(idx);
  779. auto imm = std::make_shared<Int32Imm>(output_index);
  780. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  781. MS_EXCEPTION_IF_NULL(graph);
  782. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  783. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  784. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  785. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  786. make_tuple_inputs.push_back(getitem);
  787. }
  788. } else {
  789. make_tuple_inputs.push_back(cnode);
  790. }
  791. // create output
  792. auto g_output = graph->NewCNode(make_tuple_inputs);
  793. graph->set_output(g_output);
  794. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  795. MS_EXCEPTION_IF_NULL(context_);
  796. FuncGraphManagerPtr manager = context_->manager();
  797. if (manager != nullptr) {
  798. manager->AddFuncGraph(graph);
  799. graph->set_manager(manager);
  800. }
  801. MS_LOG(INFO) << "Finish!";
  802. }
  803. std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
  804. const std::vector<tensor::TensorPtr> &input_tensors,
  805. const std::vector<int> &tensors_mask) {
  806. auto graph = std::make_shared<KernelGraph>();
  807. std::vector<AnfNodePtr> inputs;
  808. // set input[0]
  809. PrimitivePtr op_prim = op_run_info.py_primitive;
  810. MS_EXCEPTION_IF_NULL(op_prim);
  811. inputs.push_back(std::make_shared<ValueNode>(op_prim));
  812. // set input parameter
  813. MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  814. if (input_tensors.size() != tensors_mask.size()) {
  815. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
  816. << tensors_mask.size();
  817. }
  818. for (size_t i = 0; i < input_tensors.size(); ++i) {
  819. if (tensors_mask[i] == kValueNodeTensorMask) {
  820. auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
  821. inputs.push_back(value_node);
  822. continue;
  823. }
  824. auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
  825. inputs.push_back(parameter);
  826. graph->MutableInputs()->push_back(parameter);
  827. }
  828. // set execution order
  829. auto cnode = graph->NewCNode(inputs);
  830. MS_EXCEPTION_IF_NULL(cnode);
  831. // set abstract,which include inferred shapes and types
  832. cnode->set_abstract(op_run_info.abstract);
  833. // set execution order
  834. std::vector<CNodePtr> exe_order = {cnode};
  835. graph->set_execution_order(exe_order);
  836. // set output
  837. CreateOutputNode(cnode, graph);
  838. return graph;
  839. }
  840. BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  841. if (utils::isa<VectorRef>(base_ref)) {
  842. auto ref_list = utils::cast<VectorRef>(base_ref);
  843. py::tuple output_tensors(ref_list.size());
  844. for (size_t i = 0; i < ref_list.size(); ++i) {
  845. auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef
  846. if (utils::isa<tensor::TensorPtr>(output)) {
  847. auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
  848. MS_EXCEPTION_IF_NULL(tensor_ptr);
  849. output_tensors[i] = tensor_ptr;
  850. } else if (utils::isa<PyObjectRef>(output)) {
  851. py::object obj = utils::cast<PyObjectRef>(output).object_;
  852. py::tuple tensor_tuple = py::cast<py::tuple>(obj);
  853. output_tensors[i] = tensor_tuple;
  854. } else {
  855. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  856. }
  857. }
  858. return output_tensors; // turn tuple to py::object and store in PyObjectRef
  859. } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
  860. return base_ref;
  861. } else {
  862. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  863. }
  864. }
  865. KernelGraphPtr SessionBasic::NewKernelGraph() {
  866. auto graph = std::make_shared<KernelGraph>();
  867. graph->set_graph_id(graph_sum_);
  868. graphs_[graph_sum_++] = graph;
  869. return graph;
  870. }
  871. } // namespace session
  872. } // namespace mindspore