You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 39 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value_py.h"
  24. #include "operator/ops.h"
  25. #include "common/trans.h"
  26. #include "utils/context/ms_context.h"
  27. #include "utils/config_manager.h"
  28. #include "session/anf_runtime_algorithm.h"
  29. #include "kernel/oplib/oplib.h"
  30. #include "pre_activate/common/common_backend_optimization.h"
  31. #include "pre_activate/pass/const_input_to_attr_registry.h"
  32. #include "pre_activate/common/helper.h"
  33. #include "common/utils.h"
  34. #include "ir/dtype.h"
  35. #include "ir/anf.h"
  36. namespace mindspore {
  37. namespace session {
// Cache mapping a python-side parameter object (PyObject*) to the backend
// Parameter created for it, so the same python weight reuses one backend node.
static std::shared_ptr<std::map<PyObject *, ParameterPtr>> python_paras_;
// Release the python-parameter cache; subsequent lookups lazily recreate it.
void ClearPythonParasMap() { python_paras_ = nullptr; }
  40. namespace {
// Input index used when unpacking summary nodes (presumably the value input of
// tuple_getitem) — usage is outside this chunk; confirm against summary handling.
const int kSummaryGetItem = 2;
  42. PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
  43. if (node == nullptr) {
  44. return nullptr;
  45. }
  46. auto parameter = node->cast<ParameterPtr>();
  47. if (parameter == nullptr || !parameter->has_default()) {
  48. return nullptr;
  49. }
  50. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  51. auto py_param = param_value->value();
  52. return py_param.ptr();
  53. }
  54. BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
  55. const std::vector<tensor::TensorPtr> &input_tensors) {
  56. MS_EXCEPTION_IF_NULL(node);
  57. MS_LOG(INFO) << "create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  58. // if node is a value node, no need sync addr from device to host
  59. if (!AnfAlgo::OutputAddrExist(node, output_index)) {
  60. if (node->isa<ValueNode>()) {
  61. auto value_node = node->cast<ValueNodePtr>();
  62. MS_EXCEPTION_IF_NULL(value_node);
  63. return value_node->value();
  64. }
  65. if (node->isa<Parameter>()) {
  66. for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
  67. if (input_idx > input_tensors.size()) {
  68. MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
  69. }
  70. if (graph.inputs()[input_idx] == node) {
  71. return input_tensors[input_idx];
  72. }
  73. }
  74. MS_LOG(EXCEPTION) << "parameter : " << node->DebugString() << "has no output addr";
  75. }
  76. }
  77. // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  78. auto address = AnfAlgo::GetOutputAddr(node, output_index);
  79. MS_EXCEPTION_IF_NULL(address);
  80. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  81. TypeId type_id = kNumberTypeFloat32;
  82. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  83. std::vector<int> temp_shape;
  84. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  85. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  86. // if in paynative mode,data only copyed to host when user want to print data
  87. auto ms_context = MsContext::GetInstance();
  88. MS_EXCEPTION_IF_NULL(ms_context);
  89. if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
  90. tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
  91. tensor->set_dirty(false);
  92. } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
  93. LongToSize(tensor->data().nbytes()), tensor->data_type(),
  94. tensor->data_c(true))) {
  95. MS_LOG(INFO) << "output sync device to host error!!!";
  96. tensor->set_dirty(false);
  97. }
  98. return tensor;
  99. }
  100. BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  101. const std::vector<tensor::TensorPtr> &input_tensors) {
  102. MS_EXCEPTION_IF_NULL(anf);
  103. MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]";
  104. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  105. MS_EXCEPTION_IF_NULL(item_with_index.first);
  106. MS_LOG(INFO) << "create tensor for output after visit:" << item_with_index.first->DebugString();
  107. // special handle for maketuple
  108. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  109. auto cnode = item_with_index.first->cast<CNodePtr>();
  110. MS_EXCEPTION_IF_NULL(cnode);
  111. VectorRef ret;
  112. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  113. auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
  114. ret.push_back(out);
  115. }
  116. return ret;
  117. }
  118. // if is graph return nothing ,the function should return a null anylist
  119. size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  120. if (size == 0) {
  121. return VectorRef();
  122. }
  123. return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
  124. }
  125. BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  126. const std::vector<tensor::TensorPtr> &input_tensors) {
  127. MS_EXCEPTION_IF_NULL(anf);
  128. if (!AnfAlgo::IsRealKernel(anf)) {
  129. MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] should be a executable kernel";
  130. }
  131. if (anf->isa<ValueNode>()) {
  132. return CreateOneTensor(anf, 0, graph, input_tensors);
  133. }
  134. VectorRef ret;
  135. if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
  136. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
  137. auto out = CreateOneTensor(anf, i, graph, input_tensors);
  138. ret.emplace_back(out);
  139. }
  140. }
  141. return ret;
  142. }
  143. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  144. auto value_node = anf->cast<ValueNodePtr>();
  145. MS_EXCEPTION_IF_NULL(value_node);
  146. auto value = value_node->value();
  147. MS_EXCEPTION_IF_NULL(value);
  148. if (value->isa<None>()) {
  149. return nullptr;
  150. }
  151. auto new_value_node = graph->NewValueNode(value_node);
  152. graph->FrontBackendlMapAdd(anf, new_value_node);
  153. graph->AddValueNodeToGraph(new_value_node);
  154. return new_value_node;
  155. }
  156. std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
  157. MS_EXCEPTION_IF_NULL(node);
  158. MS_EXCEPTION_IF_NULL(graph);
  159. std::vector<AnfNodePtr> parameters;
  160. std::vector<AnfNodePtr> pre_graph_out = {node};
  161. // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  162. if (!AnfAlgo::IsRealKernel(node)) {
  163. pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  164. }
  165. auto valid_inputs = graph->MutableValidInputs();
  166. MS_EXCEPTION_IF_NULL(valid_inputs);
  167. auto graph_inputs = graph->MutableInputs();
  168. MS_EXCEPTION_IF_NULL(graph_inputs);
  169. auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
  170. auto parameter = graph->NewParameter();
  171. MS_EXCEPTION_IF_NULL(parameter);
  172. parameter->set_abstract(abstract);
  173. auto new_parameter = graph->NewParameter(parameter);
  174. parameters.push_back(new_parameter);
  175. valid_inputs->push_back(valid_input);
  176. graph_inputs->push_back(new_parameter);
  177. };
  178. for (const auto &out_node : pre_graph_out) {
  179. MS_EXCEPTION_IF_NULL(out_node);
  180. auto abstract = out_node->abstract();
  181. MS_EXCEPTION_IF_NULL(abstract);
  182. // create multiple parameters if is a tuple output real kernel
  183. if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
  184. auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  185. MS_EXCEPTION_IF_NULL(tuple_abstract);
  186. MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
  187. for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
  188. create_parameter((*tuple_abstract)[output_idx]);
  189. }
  190. continue;
  191. }
  192. // create single parameter if is a abstract real kernel
  193. create_parameter(out_node->abstract());
  194. }
  195. return parameters;
  196. }
// Reset the graph's first control tensor (the loop counter) to zero, mark it
// dirty so it is re-synced to device, and append it to `inputs`. Returns the
// total number of control tensors the graph declares (0 when it has none).
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  MS_LOG(INFO) << "Load kInputCtrlTensors";
  // NOTE(review): `graph` is dereferenced without a null check — callers must pass a valid graph.
  auto inputs_params = graph->input_ctrl_tensors();
  if (inputs_params == nullptr) {
    return 0;
  }
  if (inputs_params->empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  }
  auto tensor = (*inputs_params)[0];
  MS_EXCEPTION_IF_NULL(tensor);
  auto *val = static_cast<int32_t *>(tensor->data_c(true));
  MS_EXCEPTION_IF_NULL(val);
  *val = 0;
  tensor->set_dirty(true);
  // set loop_count to zero
  MS_EXCEPTION_IF_NULL(inputs);
  inputs->push_back(tensor);
  // NOTE(review): only the first control tensor is pushed, but the full count is
  // returned; the caller uses this count when validating input sizes — confirm intended.
  return inputs_params->size();
}
  217. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  218. MS_EXCEPTION_IF_NULL(graph);
  219. MS_EXCEPTION_IF_NULL(input_tensor);
  220. auto value_node = std::make_shared<ValueNode>(input_tensor);
  221. // construct abstract of value node
  222. auto type_of_tensor = input_tensor->Dtype();
  223. auto shape_of_tensor = input_tensor->shape();
  224. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  225. value_node->set_abstract(abstract);
  226. // add value node to graph
  227. auto input_value_node = graph->NewValueNode(value_node);
  228. graph->AddValueNodeToGraph(input_value_node);
  229. return input_value_node;
  230. }
  231. ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
  232. int tensor_mask) {
  233. auto param = graph->NewParameter();
  234. MS_EXCEPTION_IF_NULL(param);
  235. if (tensor_mask == kParameterWeightTensorMask) {
  236. py::object obj;
  237. auto param_value_new = std::make_shared<ParamValuePy>(obj);
  238. param->set_default_param(param_value_new);
  239. }
  240. // set the kernel info of parameter
  241. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  242. MS_EXCEPTION_IF_NULL(input_tensor);
  243. if (input_tensor->device_address().get() == nullptr) {
  244. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  245. TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
  246. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  247. } else {
  248. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
  249. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  250. }
  251. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  252. // construct abstract of parameter
  253. auto type_of_tensor = input_tensor->Dtype();
  254. auto shape_of_tensor = input_tensor->shape();
  255. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  256. param->set_abstract(abstract);
  257. return param;
  258. }
  259. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  260. MS_LOG(INFO) << "graph outputs:";
  261. const size_t max_deep = 10;
  262. if (recurse_level > max_deep) {
  263. MS_LOG(INFO) << "recurse too deep";
  264. return;
  265. }
  266. std::string tab_str;
  267. for (size_t i = 0; i < recurse_level; i++) {
  268. tab_str = tab_str.append(" ");
  269. }
  270. if (any.is<AnyList>()) {
  271. (void)tab_str.append("{");
  272. MS_LOG(INFO) << tab_str;
  273. auto any_list = any.cast<AnyList>();
  274. for (auto &it : any_list) {
  275. DumpGraphOutput(it, recurse_level + 1);
  276. }
  277. (void)tab_str.append("}");
  278. MS_LOG(INFO) << tab_str;
  279. }
  280. (void)tab_str.append(any.ToString());
  281. MS_LOG(INFO) << tab_str;
  282. }
  283. bool ExistSummaryNode(const KernelGraph *graph) {
  284. auto ret = graph->get_return();
  285. MS_EXCEPTION_IF_NULL(ret);
  286. auto all_nodes = DeepLinkedGraphSearch(ret);
  287. for (auto &n : all_nodes) {
  288. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  289. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  290. return true;
  291. }
  292. }
  293. return false;
  294. }
  295. } // namespace
  296. GraphId SessionBasic::graph_sum_ = 0;
// Create (or reuse) a backend parameter for a front-end Parameter node and
// register it as a graph input with the given validity flag. Backend parameters
// are cached per python weight object in `python_paras_` so the same weight is
// shared across graphs.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
  }
  // m_tensor is the raw python object of the parameter's default value (may be null)
  auto m_tensor = GetParamDefaultInputTensor(anf);
  // NOTE(review): `graph` is dereferenced without a null check below.
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // if this python parameter already has a backend parameter, reuse it
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
  }
  auto iter = python_paras_->find(m_tensor);
  if (iter != python_paras_->end()) {
    new_parameter = iter->second;
  } else {
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // only cache parameters that actually have a python default (m_tensor != nullptr)
    if (m_tensor != nullptr) {
      (*python_paras_)[m_tensor] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
  328. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  329. MS_EXCEPTION_IF_NULL(anf);
  330. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  331. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  332. if (parameters.empty()) {
  333. MS_LOG(EXCEPTION) << "No parameter exist!!";
  334. }
  335. if (parameters.size() == 1) {
  336. return parameters[0];
  337. }
  338. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  339. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  340. auto make_tuple = graph->NewCNode(make_tuple_input);
  341. MS_EXCEPTION_IF_NULL(make_tuple);
  342. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  343. return make_tuple;
  344. }
// Clone a front-end cnode into `graph`, translating each input into a backend
// node: previously-created nodes and cross-graph nodes are reused; value nodes,
// parameters, and cnodes from other graphs are materialized on demand.
// `from_other_graph` is set when any input had to be imported from another graph;
// `other_graph_cnode` caches those imported nodes across calls.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // get primitive of old node
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  // push attr to inputs[0] of new cnode (a fresh copy of the primitive)
  std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
  // if has multiple depends, only select first depend as parameter
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->inputs()[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // case 1: anf has been created before — reuse the backend node
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // case 2: already imported from another graph in a previous call
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // case 3: plain value node — clone it in (None values are dropped)
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) {
      // case 4: single-output parameter — create/reuse a backend parameter
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        // the parameter belongs to another graph — remember it in the cross-graph cache
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (anf->isa<AnfNode>()) {
      // case 5: the input node is a cnode from another graph — import it as parameter(s)
      *from_other_graph = true;
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // keep the original node's debug info on the clone
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Clone a front-end cnode into `graph` (whole-func_graph construction path).
// input[0] decides the form of the new node: a ValueNode<FuncGraph> becomes a
// `call` of the corresponding KernelGraph; a CNode input[0] becomes a
// `call` of a previously-created switch node; otherwise the node's primitive is
// copied. Remaining inputs must already exist in the graph (None is skipped).
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (IsValueNode<FuncGraph>(attr_input)) {
    // create primitive of cnode:call
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // create a ValueNode<KernelGraph> as input of cnode:call
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      // the sub-graph's KernelGraph value node hasn't been created yet
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // create primitive of cnode:call(switch)
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
      // only a switch node may appear as a cnode in input[0]
      if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
        MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
      }
      cnode_inputs.emplace_back(cnode_input);
    } else {
      MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
                        << ", but input[0] has not been created.";
    }
  } else {
    // get primitive of old node
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode (a fresh copy of the primitive)
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->input(input_idx);
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before (parameters/value nodes are handled by the caller's pass)
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (IsValueNode<None>(anf)) {
      // None inputs are dropped
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // keep the original node's debug info on the clone
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  455. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  456. MS_EXCEPTION_IF_NULL(anf);
  457. auto value_node = anf->cast<ValueNodePtr>();
  458. MS_EXCEPTION_IF_NULL(value_node);
  459. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  460. MS_EXCEPTION_IF_NULL(sub_func_graph);
  461. if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
  462. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  463. }
  464. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  465. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  466. new_value_node->set_abstract(value_node->abstract());
  467. // create new kernel_info of new value_node
  468. auto kernel_info = std::make_shared<device::KernelInfo>();
  469. kernel_info->SetFeatureMapFlag(false);
  470. new_value_node->set_kernel_info(kernel_info);
  471. // create kernel_build_info for new value node
  472. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  473. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  474. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  475. graph->FrontBackendlMapAdd(anf, new_value_node);
  476. return new_value_node;
  477. }
// Create (or reuse via the `python_paras_` cache) a backend parameter for a
// front-end Parameter node. Unlike CreateNewParameterFromParameter, this does
// NOT append the parameter to the graph's input lists — callers do that.
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
  }
  // raw python object of the parameter's default value (may be null)
  auto m_tensor = GetParamDefaultInputTensor(anf);
  ParameterPtr new_parameter = nullptr;
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
  }
  auto iter = python_paras_->find(m_tensor);
  if (iter != python_paras_->end()) {
    new_parameter = iter->second;
  } else {
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // only cache parameters that actually have a python default
    if (m_tensor != nullptr) {
      (*python_paras_)[m_tensor] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  return new_parameter;
}
// Build a KernelGraph from a flat list of front-end cnodes plus the desired
// outputs: clone every cnode, wire a make_tuple output, attach a manager, fix
// the execution order, and run the common backend optimizations.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    bool from_other_graph = false;
    // only the first depend imported from another graph counts as a valid input
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  // record whether summary ops exist so later passes can handle them
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
// Recursively build a KernelGraph from a whole front-end FuncGraph: parameters
// become backend parameters, common value nodes are cloned, ValueNode<FuncGraph>
// triggers recursive construction of the sub-graph (or marks a back-edge when it
// was already built), and cnodes are cloned via CreateNewCNode.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  // register before recursing so cycles (back edges) are detected below
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  bool is_trace_back = false;
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(node, new_parameter);
      continue;
    } else if (node->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(node)) {
        // if input is a common value node,
        (void)CreateNewValueNode(node, graph.get());
      } else {
        // if input is a ValueNode<FuncGraph>
        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
          // the child graph was already (or is currently being) constructed: a back edge
          is_trace_back = true;
        } else {
          (void)ConstructKernelGraph(child_graph);
        }
        (void)CreateValueNodeKernelGraph(node, graph.get());
      }
      continue;
    } else {
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      // create a new cnode object
      auto new_cnode = CreateNewCNode(cnode, graph.get());
      MS_EXCEPTION_IF_NULL(new_cnode);
      new_cnode->set_abstract(cnode->abstract());
      new_cnode->set_scope(cnode->scope());
      graph->FrontBackendlMapAdd(node, new_cnode);
      if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
        graph->set_return(new_cnode);
      }
    }
  }
  // if a graph jumps back unconditionally, its return op will never execute, so output is null
  graph->set_output_null(is_trace_back);
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
  602. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  603. MS_EXCEPTION_IF_NULL(graph);
  604. auto graph_inputs = graph->MutableInputs();
  605. MS_EXCEPTION_IF_NULL(graph_inputs);
  606. graph_inputs->clear();
  607. for (auto &parameter : parameters) {
  608. MS_EXCEPTION_IF_NULL(parameter);
  609. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  610. if (backend_parameter == nullptr) {
  611. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  612. auto new_parameter = CreateNewParameter(parameter, graph);
  613. graph_inputs->push_back(new_parameter);
  614. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  615. continue;
  616. }
  617. MS_LOG(INFO) << "graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
  618. graph_inputs->push_back(backend_parameter);
  619. }
  620. }
  621. // run graph steps
// Copies host-side input tensors into the device memory that backs the kernel
// graph's parameter nodes before a run.
// - kernel_graph: graph whose inputs are fed (must be non-null).
// - inputs_const: front-end input tensors, in graph-input order.
// Throws when the tensor count does not match the graph inputs or when a
// host-to-device copy fails.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  // local copy so control tensors can be appended without touching the caller's vector
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    // appends the graph's control tensors to `inputs` and reports how many there are
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  // input_ctrl_size stays 1 when no control tensors exist, hence the -1 in the check
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // only parameter nodes that already own a device address can receive data
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        // pynative infer: sync whenever the tensor is not already bound to this exact address
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty()) {
          // host data changed since the last sync
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // tensor lives at a different device address: pull it to host first, then push below
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        // in pynative mode (or for weights) the tensor takes ownership of the graph's address
        if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
          tensor->set_device_address(device_address);
        }
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c(false))) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    // host and device are consistent for this tensor now
    tensor->set_dirty(false);
  }
}
  673. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  674. const std::vector<tensor::TensorPtr> &input_tensors) const {
  675. MS_EXCEPTION_IF_NULL(kernel_graph);
  676. MS_EXCEPTION_IF_NULL(outputs);
  677. if (!kernel_graph->child_graph_order().empty()) {
  678. // use the last child graph output as the root graph output
  679. UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
  680. return;
  681. }
  682. auto anf_outputs = kernel_graph->outputs();
  683. for (auto &item : anf_outputs) {
  684. MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
  685. MS_EXCEPTION_IF_NULL(item);
  686. if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
  687. outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
  688. continue;
  689. }
  690. outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  691. }
  692. }
// Stores the callback that Summary() later invokes with the collected summary tensors.
// The callback must be non-null.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
  697. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
  698. MS_EXCEPTION_IF_NULL(node_list);
  699. std::vector<CNodePtr> all_opt_list;
  700. std::vector<CNodePtr> non_opt_list;
  701. for (const auto &node : *node_list) {
  702. MS_EXCEPTION_IF_NULL(node);
  703. if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
  704. all_opt_list.emplace_back(node);
  705. } else {
  706. non_opt_list.emplace_back(node);
  707. }
  708. }
  709. node_list->clear();
  710. (void)std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
  711. (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
  712. }
// Collects all summary nodes (scalar/tensor/image/histogram) reachable from
// the graph's return node into `summary`, keyed by the summary node's full
// name and mapped to the real kernel (node, output-index) producing the value.
void SessionBasic::GetSummaryNodes(const KernelGraph *graph, NamedSummaryOutputs *summary) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(summary);
  // fast exit when the graph was built without any summary node
  if (!graph->summary_node_exist()) {
    return;
  }
  auto apply_list = TopoSort(graph->get_return());
  for (auto &n : apply_list) {
    MS_EXCEPTION_IF_NULL(n);
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      auto cnode = n->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      // kSummaryGetItem indexes the summarized value input of the summary cnode
      if (cnode->inputs().size() <= kSummaryGetItem) {
        MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
      }
      auto node = cnode->input(kSummaryGetItem);
      MS_EXCEPTION_IF_NULL(node);
      // peel off wrapper nodes to reach the kernel that actually produces the value
      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0);
      if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
        MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
      }
      (*summary)[n->fullname_with_scope()] = item_with_index;
    }
  }
  MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
}
// Reads each summary node's output back from device memory into a host tensor
// and passes the whole batch to the registered summary callback. No-op when no
// callback is registered or the graph holds no summary nodes.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  NamedSummaryOutputs summary_outputs;
  GetSummaryNodes(graph, &summary_outputs);
  // do not exist summary node
  if (summary_outputs.empty()) {
    return;
  }
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    // widen size_t dims into the int shape the Tensor ctor expects
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    // skip outputs whose device memory was never allocated
    if (!address->GetPtr()) {
      continue;
    }
    // best effort: a failed copy is logged but does not abort the summary batch
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c(true))) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  777. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  778. MS_EXCEPTION_IF_NULL(graph);
  779. std::vector<AnfNodePtr> output_args;
  780. for (const auto &output : outputs) {
  781. MS_LOG(INFO) << "output:" << output->DebugString();
  782. }
  783. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  784. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  785. if (backend_anf != nullptr) {
  786. return backend_anf;
  787. }
  788. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  789. };
  790. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  791. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  792. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  793. return graph->NewCNode(output_args);
  794. }
  795. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  796. MS_LOG(INFO) << "Start!";
  797. std::vector<AnfNodePtr> make_tuple_inputs;
  798. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  799. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  800. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  801. auto idx = NewValueNode(SizeToInt(output_index));
  802. MS_EXCEPTION_IF_NULL(idx);
  803. auto imm = std::make_shared<Int32Imm>(output_index);
  804. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  805. MS_EXCEPTION_IF_NULL(graph);
  806. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  807. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  808. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  809. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  810. make_tuple_inputs.push_back(getitem);
  811. }
  812. } else {
  813. make_tuple_inputs.push_back(cnode);
  814. }
  815. // create output
  816. auto g_output = graph->NewCNode(make_tuple_inputs);
  817. graph->set_output(g_output);
  818. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  819. MS_EXCEPTION_IF_NULL(context_);
  820. FuncGraphManagerPtr manager = context_->manager();
  821. if (manager != nullptr) {
  822. manager->AddFuncGraph(graph);
  823. graph->set_manager(manager);
  824. }
  825. MS_LOG(INFO) << "Finish!";
  826. }
  827. std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
  828. const std::vector<tensor::TensorPtr> &input_tensors,
  829. const std::vector<int> &tensors_mask) {
  830. auto graph = std::make_shared<KernelGraph>();
  831. std::vector<AnfNodePtr> inputs;
  832. // set input[0]
  833. PrimitivePtr op_prim = op_run_info.py_primitive;
  834. MS_EXCEPTION_IF_NULL(op_prim);
  835. inputs.push_back(std::make_shared<ValueNode>(op_prim));
  836. // set input parameter
  837. MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  838. if (input_tensors.size() != tensors_mask.size()) {
  839. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
  840. << tensors_mask.size();
  841. }
  842. for (size_t i = 0; i < input_tensors.size(); ++i) {
  843. if (tensors_mask[i] == kValueNodeTensorMask) {
  844. auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
  845. inputs.push_back(value_node);
  846. continue;
  847. }
  848. auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
  849. inputs.push_back(parameter);
  850. graph->MutableInputs()->push_back(parameter);
  851. }
  852. // set execution order
  853. auto cnode = graph->NewCNode(inputs);
  854. MS_EXCEPTION_IF_NULL(cnode);
  855. // set abstract,which include inferred shapes and types
  856. cnode->set_abstract(op_run_info.abstract);
  857. // set execution order
  858. std::vector<CNodePtr> exe_order = {cnode};
  859. graph->set_execution_order(exe_order);
  860. // set output
  861. CreateOutputNode(cnode, graph);
  862. return graph;
  863. }
// Recursively converts a BaseRef that may be a (nested) VectorRef of tensors
// into nested py::tuple objects; a plain tensor BaseRef is returned unchanged.
// Throws for any element that is neither a VectorRef nor a tensor.
BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use pyObjectRef
      if (utils::isa<tensor::TensorPtr>(output)) {
        // leaf: a plain tensor element
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        // a nested list already transformed into a py::tuple by the recursive call
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
  889. KernelGraphPtr SessionBasic::NewKernelGraph() {
  890. auto graph = std::make_shared<KernelGraph>();
  891. graph->set_graph_id(graph_sum_);
  892. graphs_[graph_sum_++] = graph;
  893. return graph;
  894. }
  895. } // namespace session
  896. } // namespace mindspore