You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

session_basic.cc 41 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value_py.h"
  24. #include "kernel/common_utils.h"
  25. #include "operator/ops.h"
  26. #include "common/trans.h"
  27. #include "utils/context/ms_context.h"
  28. #include "utils/config_manager.h"
  29. #include "session/anf_runtime_algorithm.h"
  30. #include "kernel/oplib/oplib.h"
  31. #include "pre_activate/common/common_backend_optimization.h"
  32. #include "pre_activate/pass/const_input_to_attr_registry.h"
  33. #include "pre_activate/common/helper.h"
  34. #include "common/utils.h"
  35. #include "ir/dtype.h"
  36. #include "ir/anf.h"
  37. #include "ir/func_graph_cloner.h"
  38. namespace mindspore {
  39. namespace session {
// Cache mapping a front-end python parameter object (PyObject*) to the backend
// Parameter created for it, so repeated graph constructions reuse one node per
// python weight. Lazily allocated by the Create*Parameter helpers below.
static std::shared_ptr<std::map<PyObject *, ParameterPtr>> python_paras_;
// Drops the whole cache; cached backend ParameterPtr entries become
// unreferenced by this map and can be released.
void ClearPythonParasMap() { python_paras_ = nullptr; }
  42. namespace {
// NOTE(review): meaning inferred from the name only — apparently the operand
// index used when unpacking summary nodes; the use site is outside this view,
// confirm there.
const int kSummaryGetItem = 2;
// Returns the raw PyObject* backing a Parameter node's python default value,
// or nullptr when `node` is null, is not a Parameter, or has no default.
// Aborts (MS_EXCEPTION) if the default is present but is not a ParamValuePy.
// The returned pointer is borrowed — presumably it stays valid only while the
// parameter keeps its default; confirm before holding it long-term.
PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
  auto parameter = node->cast<ParameterPtr>();
  if (parameter == nullptr || !parameter->has_default()) {
    return nullptr;
  }
  auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  MS_EXCEPTION_IF_NULL(param_value);
  auto py_param = param_value->value();
  return py_param.ptr();
}
  57. BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
  58. const std::vector<tensor::TensorPtr> &input_tensors) {
  59. MS_EXCEPTION_IF_NULL(node);
  60. MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  61. // if node is a value node, no need sync addr from device to host
  62. if (!AnfAlgo::OutputAddrExist(node, output_index)) {
  63. if (node->isa<ValueNode>()) {
  64. auto value_node = node->cast<ValueNodePtr>();
  65. MS_EXCEPTION_IF_NULL(value_node);
  66. return value_node->value();
  67. }
  68. if (node->isa<Parameter>()) {
  69. for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
  70. if (input_idx >= input_tensors.size()) {
  71. MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
  72. }
  73. if (graph.inputs()[input_idx] == node) {
  74. return input_tensors[input_idx];
  75. }
  76. }
  77. MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
  78. }
  79. }
  80. // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  81. DeviceAddressPtr address;
  82. auto is_all_nop_node = opt::IsAllNopNode(&graph);
  83. if (is_all_nop_node) {
  84. // The graph does not remove the nop node.
  85. address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
  86. } else {
  87. // The graph removes the nop node.
  88. address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
  89. }
  90. MS_EXCEPTION_IF_NULL(address);
  91. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  92. TypeId type_id = kNumberTypeFloat32;
  93. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  94. std::vector<int> temp_shape;
  95. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  96. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  97. // if in paynative mode,data only copyed to host when user want to print data
  98. auto ms_context = MsContext::GetInstance();
  99. MS_EXCEPTION_IF_NULL(ms_context);
  100. if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
  101. tensor->set_device_address(address);
  102. tensor->set_dirty(false);
  103. } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
  104. LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
  105. MS_LOG(INFO) << "Output sync device to host error!!!";
  106. tensor->set_dirty(false);
  107. }
  108. return tensor;
  109. }
// Builds the output value for `anf`: first resolves it to a real output node
// (VisitKernelWithReturnType), then expands a MakeTuple into a VectorRef of
// its elements (recursively), returns an empty VectorRef when the node has no
// outputs, and otherwise materializes a single tensor via CreateOneTensor.
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                             const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // Special handling for MakeTuple: one result per element, skipping input 0
  // (the primitive itself).
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
      ret.push_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function returns an empty list.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
}
  135. BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  136. const std::vector<tensor::TensorPtr> &input_tensors) {
  137. MS_EXCEPTION_IF_NULL(anf);
  138. if (!AnfAlgo::IsRealKernel(anf)) {
  139. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
  140. }
  141. if (anf->isa<ValueNode>()) {
  142. return CreateOneTensor(anf, 0, graph, input_tensors);
  143. }
  144. VectorRef ret;
  145. if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
  146. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
  147. auto out = CreateOneTensor(anf, i, graph, input_tensors);
  148. ret.emplace_back(out);
  149. }
  150. }
  151. return ret;
  152. }
// Clones front-end value node `anf` into `graph`: creates a backend value
// node, records the front->backend mapping, and registers it with the graph.
// Returns nullptr (with no registration) when the value is None; callers
// treat that as "skip this input".
ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  auto value_node = anf->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<None>()) {
    return nullptr;
  }
  auto new_value_node = graph->NewValueNode(value_node);
  graph->FrontBackendlMapAdd(anf, new_value_node);
  graph->AddValueNodeToGraph(new_value_node);
  return new_value_node;
}
// Creates backend Parameters standing in for `node` (typically a cnode from
// another graph). A tuple-output real kernel contributes one Parameter per
// element of its tuple abstract; every created Parameter is appended to the
// graph's inputs together with the `valid_input` flag, in lockstep.
std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> parameters;
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, its input0 is a cnode too, so it doesn't have a
  // primitive; expand it into all of its real outputs instead.
  if (!AnfAlgo::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  }
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  // Makes one fresh Parameter with the given abstract and records it in
  // `parameters`, the graph inputs, and the valid-input flags together.
  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
    auto parameter = graph->NewParameter();
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(abstract);
    // NOTE(review): a temporary parameter is built and then passed through
    // NewParameter(parameter) again; presumably the second call performs
    // graph-specific bookkeeping — confirm against KernelGraph::NewParameter.
    auto new_parameter = graph->NewParameter(parameter);
    parameters.push_back(new_parameter);
    valid_inputs->push_back(valid_input);
    graph_inputs->push_back(new_parameter);
  };
  for (const auto &out_node : pre_graph_out) {
    MS_EXCEPTION_IF_NULL(out_node);
    auto abstract = out_node->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    // Create multiple parameters if this is a tuple-output real kernel.
    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_abstract);
      MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
        create_parameter((*tuple_abstract)[output_idx]);
      }
      continue;
    }
    // Create a single parameter for a single-output real kernel.
    create_parameter(out_node->abstract());
  }
  return parameters;
}
// Loads the graph's control-input (loop count) tensor: zeroes it, marks it
// dirty so the zero is re-synced to device, and appends it to `inputs`.
// Returns 0 when the graph declares no control tensors; raises on an empty
// (but non-null) control-tensor list.
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Load kInputCtrlTensors";
  auto inputs_params = graph->input_ctrl_tensors();
  if (inputs_params == nullptr) {
    return 0;
  }
  if (inputs_params->empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  }
  auto tensor = (*inputs_params)[0];
  MS_EXCEPTION_IF_NULL(tensor);
  auto *val = static_cast<int32_t *>(tensor->data_c());
  MS_EXCEPTION_IF_NULL(val);
  // Set loop_count to zero.
  *val = 0;
  tensor->set_dirty(true);
  MS_EXCEPTION_IF_NULL(inputs);
  inputs->push_back(tensor);
  // NOTE(review): only the first control tensor is pushed onto `inputs`, yet
  // the full declared count is returned — confirm callers expect the count
  // rather than the number of tensors actually appended.
  return inputs_params->size();
}
// Wraps `input_tensor` in a backend ValueNode for single-op (run-op) graphs,
// giving it a tensor abstract built from the tensor's dtype/shape and
// registering it with `graph`.
ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto value_node = std::make_shared<ValueNode>(input_tensor);
  MS_EXCEPTION_IF_NULL(value_node);
  // Construct the abstract of the value node from the tensor's type and shape.
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  value_node->set_abstract(abstract);
  // Add the value node to the graph.
  auto input_value_node = graph->NewValueNode(value_node);
  graph->AddValueNodeToGraph(input_value_node);
  return input_value_node;
}
// Builds a backend Parameter for one input of a run-op graph. Weight inputs
// (tensor_mask == kParameterWeightTensorMask) get an empty python default so
// AnfAlgo::IsParameterWeight recognizes them. The kernel-build info mirrors
// the tensor's existing device address when present; otherwise a default
// format is used.
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
                                     int tensor_mask) {
  MS_EXCEPTION_IF_NULL(graph);
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (tensor_mask == kParameterWeightTensorMask) {
    py::object obj;
    auto param_value_new = std::make_shared<ParamValuePy>(obj);
    param->set_default_param(param_value_new);
  }
  // Set the kernel info of the parameter.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  if (input_tensor->device_address().get() == nullptr) {
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    // Weights get kTypeUnknown here — presumably the concrete device type is
    // decided later during kernel selection; confirm.
    TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    // Reuse the format/type already chosen for the tensor's device address.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // Construct the abstract of the parameter from the tensor's type and shape.
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  param->set_abstract(abstract);
  return param;
}
  274. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  275. MS_LOG(INFO) << "Graph outputs:";
  276. const size_t max_deep = 10;
  277. if (recurse_level > max_deep) {
  278. MS_LOG(INFO) << "Recurse too deep";
  279. return;
  280. }
  281. std::string tab_str;
  282. for (size_t i = 0; i < recurse_level; i++) {
  283. tab_str = tab_str.append(" ");
  284. }
  285. if (any.is<AnyList>()) {
  286. (void)tab_str.append("{");
  287. MS_LOG(INFO) << tab_str;
  288. auto any_list = any.cast<AnyList>();
  289. for (auto &it : any_list) {
  290. DumpGraphOutput(it, recurse_level + 1);
  291. }
  292. (void)tab_str.append("}");
  293. MS_LOG(INFO) << tab_str;
  294. }
  295. (void)tab_str.append(any.ToString());
  296. MS_LOG(INFO) << tab_str;
  297. }
  298. bool ExistSummaryNode(const KernelGraph *graph) {
  299. MS_EXCEPTION_IF_NULL(graph);
  300. auto ret = graph->get_return();
  301. MS_EXCEPTION_IF_NULL(ret);
  302. auto all_nodes = DeepLinkedGraphSearch(ret);
  303. for (auto &n : all_nodes) {
  304. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  305. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  306. return true;
  307. }
  308. }
  309. return false;
  310. }
  311. } // namespace
// Running id source for kernel graphs — presumably incremented per graph by
// NewKernelGraph, whose definition is outside this view; confirm there.
GraphId SessionBasic::graph_sum_ = 0;
// Mirrors front-end Parameter `anf` into `graph`. Parameters backed by the
// same python object (same default-input PyObject*) are shared through the
// python_paras_ cache, so one weight maps to a single backend Parameter.
// The result is always appended to the graph's inputs and valid-input flags,
// even when reused from the cache.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto m_tensor = GetParamDefaultInputTensor(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // If the python parameter already has a backend parameter, reuse it.
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
  }
  auto iter = python_paras_->find(m_tensor);
  if (iter != python_paras_->end()) {
    new_parameter = iter->second;
  } else {
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // Only cache parameters that actually have a python default; m_tensor is
    // nullptr for plain data inputs.
    if (m_tensor != nullptr) {
      (*python_paras_)[m_tensor] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
// Turns the (possibly multi-output) front-end node `anf` into graph inputs:
// one backend Parameter per output, via CreateParameterFromTuple. A single
// parameter is returned directly; multiple parameters are bundled into a
// MakeTuple cnode. Raises when no parameter could be created.
AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  if (parameters.empty()) {
    MS_LOG(EXCEPTION) << "No parameter exist!!";
  }
  if (parameters.size() == 1) {
    return parameters[0];
  }
  // More than one output: wrap all parameters in a MakeTuple.
  std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  auto make_tuple = graph->NewCNode(make_tuple_input);
  MS_EXCEPTION_IF_NULL(make_tuple);
  MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  return make_tuple;
}
// Clones front-end cnode `cnode` into `graph`, translating each input in turn:
// reuse an already-mapped backend node, clone value nodes, mirror Parameters,
// collapse a Depend whose real input is a ValueNode, and replace inputs that
// come from another graph with fresh Parameters (setting *from_other_graph).
// `other_graph_cnode` caches those cross-graph replacements across calls.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // Get the primitive of the old node.
  std::vector<AnfNodePtr> cnode_inputs;
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  if (prim != nullptr) {
    // Push a copy of the primitive (attrs included) as inputs[0] of the new cnode.
    cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  } else {
    // No primitive: input 0 holds a func graph (e.g. a call); clone it.
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  }
  auto origin_inputs = cnode->inputs();
  bool optimize_depend = false;
  // A 2-operand Depend whose real input is a ValueNode can be optimized: its
  // attach node is replaced by the real input below.
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
      origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
    optimize_depend = true;
  }
  // If there are multiple depends, only the first depend is selected as parameter.
  for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
    auto anf = origin_inputs[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // The anf has been created before.
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // Cross-graph input already replaced by a Parameter earlier.
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // If input is a value node, clone it; None yields nullptr and is skipped.
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) {
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        // Parameter belongs to another graph: record it only in the
        // cross-graph cache, not in this graph's front->backend map.
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
      // NOTE(review): the front-end value node is pushed directly here rather
      // than a backend clone — confirm this is intended for optimized Depends.
      cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
      continue;
    } else if (anf->isa<AnfNode>()) {
      // Catch-all (isa<AnfNode> is always true): the input is a cnode from
      // another graph; stand it in with fresh Parameters.
      *from_other_graph = true;
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Clones front-end cnode `cnode` into `graph` (whole-func-graph path).
// Input 0 is translated by kind: a graph kernel clones its composite func
// graph; a ValueNode<FuncGraph> becomes a `call` primitive over a
// ValueNode<KernelGraph>; a cnode in input 0 becomes a `call` over an
// already-created switch node; otherwise the plain primitive is copied.
// The remaining inputs must already be mapped (or be None, which is dropped).
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (AnfAlgo::IsGraphKernel(cnode)) {
    // Graph kernel: keep a cloned copy of its func graph as input 0.
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  } else if (IsValueNode<FuncGraph>(attr_input)) {
    // Create primitive of cnode: call.
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // Create a ValueNode<KernelGraph> as input of cnode: call.
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // Create primitive of cnode: call(switch).
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
      if (!AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
        MS_LOG(EXCEPTION) << "CNode input[0] must be switch.";
      }
      cnode_inputs.emplace_back(cnode_input);
    } else {
      MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
                        << ", but input[0] has not been created.";
    }
  } else {
    // Get the primitive of the old node and copy it (attrs included) to input 0.
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->input(input_idx);
    MS_EXCEPTION_IF_NULL(anf);
    // The anf has been created before (inputs were processed earlier,
    // presumably in topological order — see the TopoSort-driven caller).
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (IsValueNode<None>(anf)) {
      // None inputs are dropped.
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Replaces a front-end ValueNode<FuncGraph> with a backend
// ValueNode<KernelGraph> referring to the already-constructed sub kernel
// graph. Raises if the sub func graph has not been transformed yet. The new
// node gets a non-feature-map kernel info, a default kernel-build info, the
// owning graph's id, and a front->backend map entry.
ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  auto value_node = anf->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  MS_EXCEPTION_IF_NULL(sub_func_graph);
  if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
    MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  }
  auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  new_value_node->set_abstract(value_node->abstract());
  // Create new kernel_info for the new value_node (not a feature map).
  auto kernel_info = std::make_shared<device::KernelInfo>();
  kernel_info->SetFeatureMapFlag(false);
  new_value_node->set_kernel_info(kernel_info);
  // Create a default kernel_build_info for the new value node.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  graph->FrontBackendlMapAdd(anf, new_value_node);
  return new_value_node;
}
// Like CreateNewParameterFromParameter (shares backend Parameters across
// graphs via the python_paras_ cache), but does NOT append the result to the
// graph's inputs/valid-input lists — callers manage graph inputs themselves
// (see ConstructKernelGraph / AddParameterToGraphInputs).
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  auto m_tensor = GetParamDefaultInputTensor(anf);
  ParameterPtr new_parameter = nullptr;
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
  }
  // Reuse the backend parameter when the same python object was seen before.
  auto iter = python_paras_->find(m_tensor);
  if (iter != python_paras_->end()) {
    new_parameter = iter->second;
  } else {
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // Only cache parameters that have a python default (m_tensor may be null).
    if (m_tensor != nullptr) {
      (*python_paras_)[m_tensor] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  return new_parameter;
}
// Builds a KernelGraph from a flat list of front-end cnodes `lst`: clones each
// cnode into the new graph, wires `outputs` through a trailing MakeTuple,
// attaches a FuncGraphManager, fixes the execution order, flags summary
// nodes, and runs the backend common optimizations. Raises if any entry of
// `lst` is not a CNode.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // Create a new cnode object.
    bool from_other_graph = false;
    // Only the first depend coming from another graph keeps valid inputs;
    // later depends are marked invalid.
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // Record the map relation between the front-end anf and the backend node.
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // Add a make_tuple at the end of the graph as output.
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
// Clones one front-end cnode into `graph` (whole-func-graph construction
// path): copies abstract, full name, and scope; records the front->backend
// mapping; and installs the clone as the graph's return node when it is a
// Return primitive.
// NOTE(review): `node` and `graph` are shared_ptrs passed by (const) value;
// the header declaration fixes this signature, so it is left as-is.
void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Create a new cnode object.
  auto new_cnode = CreateNewCNode(cnode, graph.get());
  MS_EXCEPTION_IF_NULL(new_cnode);
  new_cnode->set_abstract(cnode->abstract());
  new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
  new_cnode->set_scope(cnode->scope());
  graph->FrontBackendlMapAdd(node, new_cnode);
  if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
    graph->set_return(new_cnode);
  }
}
// Recursively transforms `func_graph` (and every child func graph it calls)
// into KernelGraphs. Parameters and common value nodes are mirrored directly;
// a ValueNode<FuncGraph> triggers construction of the child graph unless it
// was built already — that case is a back edge and marks this graph's output
// as null. Every finished graph is appended to `all_out_graph`.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                                std::vector<KernelGraphPtr> *all_out_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // Register before walking the nodes so recursive calls detect back edges
  // into this graph.
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  bool is_trace_back = false;
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(node, new_parameter);
      continue;
    } else if (node->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(node)) {
        // If input is a common value node, clone it into the graph.
        (void)CreateNewValueNode(node, graph.get());
      } else {
        // If input is a ValueNode<FuncGraph>, build (or reuse) the child graph.
        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
          is_trace_back = true;
        } else {
          (void)ConstructKernelGraph(child_graph, all_out_graph);
        }
        (void)CreateValueNodeKernelGraph(node, graph.get());
      }
      continue;
    } else {
      CreateCNodeKernelGraph(node, graph);
    }
  }
  // If a graph jumps back unconditionally, the return op of this graph will
  // never be executed, so its output is null.
  graph->set_output_null(is_trace_back);
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  all_out_graph->push_back(graph);
  return graph;
}
  653. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  654. MS_EXCEPTION_IF_NULL(graph);
  655. auto graph_inputs = graph->MutableInputs();
  656. MS_EXCEPTION_IF_NULL(graph_inputs);
  657. graph_inputs->clear();
  658. for (auto &parameter : parameters) {
  659. MS_EXCEPTION_IF_NULL(parameter);
  660. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  661. if (backend_parameter == nullptr) {
  662. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  663. auto new_parameter = CreateNewParameter(parameter, graph);
  664. graph_inputs->push_back(new_parameter);
  665. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  666. continue;
  667. }
  668. MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
  669. graph_inputs->push_back(backend_parameter);
  670. }
  671. }
  672. // run graph steps
// Synchronizes one step's host input tensors into the device memory bound to
// the kernel graph's parameter nodes.
// kernel_graph: graph whose inputs receive the data; may carry extra control tensors.
// inputs_const: host tensors supplied by the caller, in graph-input order.
// Throws when the tensor count does not match the graph inputs or a device copy fails.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  // Local copy so control-input tensors can be appended without mutating the caller's vector.
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  // NOTE(review): the "-1" pairs with input_ctrl_size defaulting to 1 when no control
  // tensors exist — confirm against LoadCtrlInputTensor's return contract.
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already own a device address can receive data.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        // PyNative inference: sync whenever the tensor is not already bound to
        // this exact device address.
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty()) {
          // Host data changed since the last sync.
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // Tensor is bound to a different device address: pull its data to
          // host first, then push it into this graph's address below.
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
          // Bind the tensor to the parameter's device address so later steps
          // can detect it is already resident.
          tensor->set_device_address(device_address);
        }
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c())) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    // Host and device are now consistent for this input.
    tensor->set_dirty(false);
  }
}
  724. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  725. const std::vector<tensor::TensorPtr> &input_tensors) const {
  726. MS_EXCEPTION_IF_NULL(kernel_graph);
  727. MS_EXCEPTION_IF_NULL(outputs);
  728. if (!kernel_graph->child_graph_order().empty()) {
  729. // use the last child graph output as the root graph output
  730. UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
  731. return;
  732. }
  733. auto anf_outputs = kernel_graph->outputs();
  734. for (auto &item : anf_outputs) {
  735. MS_EXCEPTION_IF_NULL(item);
  736. MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
  737. if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
  738. outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
  739. continue;
  740. }
  741. outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  742. }
  743. }
  744. void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  745. MS_EXCEPTION_IF_NULL(callback);
  746. summary_callback_ = callback;
  747. }
  748. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
// Scans the graph for summary operators (ScalarSummary / TensorSummary /
// ImageSummary / HistogramSummary) and records, for each, the real kernel and
// output index producing the summarized value, keyed by the summary node's
// full scope name. The result is stored back via set_summary_nodes().
void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    // Fast exit: the graph was flagged at construction time as having no summary node.
    return;
  }
  auto summary = graph->summary_nodes();
  auto apply_list = TopoSort(graph->get_return());
  for (auto &n : apply_list) {
    MS_EXCEPTION_IF_NULL(n);
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      auto cnode = n->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (cnode->inputs().size() <= kSummaryGetItem) {
        MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
      }
      // Input kSummaryGetItem carries the summarized value; resolve it to the
      // real kernel (and output index) behind any virtual wrapper nodes.
      auto node = cnode->input(kSummaryGetItem);
      MS_EXCEPTION_IF_NULL(node);
      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
      MS_EXCEPTION_IF_NULL(item_with_index.first);
      if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
        MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
      }
      summary[n->fullname_with_scope()] = item_with_index;
    }
  }
  graph->set_summary_nodes(summary);
  MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}
// Reads back the device outputs of all summary nodes into host tensors and
// passes the {scope name -> tensor} map to the registered summary callback.
// No-op when no callback is registered or the graph has no summary node.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = graph->summary_node_exist();
  if (!exist_summary) {
    return;
  }
  GetSummaryNodes(graph);
  auto summary_outputs = graph->summary_nodes();
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    if (!address->GetPtr()) {
      // No device memory behind this output yet; skip this summary entry.
      continue;
    }
    // Best effort: a failed device->host sync is logged but does not abort the
    // remaining summaries (the tensor is still delivered, possibly stale).
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c())) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  815. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  816. MS_EXCEPTION_IF_NULL(graph);
  817. std::vector<AnfNodePtr> output_args;
  818. for (const auto &output : outputs) {
  819. MS_EXCEPTION_IF_NULL(output);
  820. MS_LOG(INFO) << "Output:" << output->DebugString();
  821. }
  822. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  823. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  824. if (backend_anf != nullptr) {
  825. return backend_anf;
  826. }
  827. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  828. };
  829. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  830. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  831. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  832. return graph->NewCNode(output_args);
  833. }
  834. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  835. MS_LOG(INFO) << "Start!";
  836. std::vector<AnfNodePtr> make_tuple_inputs;
  837. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  838. MS_EXCEPTION_IF_NULL(graph);
  839. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  840. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  841. auto idx = NewValueNode(SizeToInt(output_index));
  842. MS_EXCEPTION_IF_NULL(idx);
  843. auto imm = std::make_shared<Int32Imm>(output_index);
  844. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  845. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  846. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  847. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  848. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  849. make_tuple_inputs.push_back(getitem);
  850. }
  851. } else {
  852. make_tuple_inputs.push_back(cnode);
  853. }
  854. // create output
  855. auto g_output = graph->NewCNode(make_tuple_inputs);
  856. graph->set_output(g_output);
  857. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  858. MS_EXCEPTION_IF_NULL(context_);
  859. FuncGraphManagerPtr manager = context_->manager();
  860. if (manager != nullptr) {
  861. manager->AddFuncGraph(graph);
  862. graph->set_manager(manager);
  863. }
  864. MS_LOG(INFO) << "Finish!";
  865. }
// Builds a one-op kernel graph for single-op execution: input[0] is the op
// primitive, followed by a value node for each tensor masked as a constant
// and a parameter for every other input tensor. The single cnode becomes the
// entire execution order and its outputs form the graph output.
// op_run_info: primitive and inferred abstract of the op to run.
// input_tensors / tensors_mask: per-input tensor and its kind mask; sizes must match.
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                                  const std::vector<tensor::TensorPtr> &input_tensors,
                                                                  const std::vector<int> &tensors_mask) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  PrimitivePtr op_prim = op_run_info.py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  inputs.push_back(std::make_shared<ValueNode>(op_prim));
  // set input parameter
  MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    if (tensors_mask[i] == kValueNodeTensorMask) {
      // Constant input: embed it as a value node instead of a graph parameter.
      auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
      inputs.push_back(value_node);
      continue;
    }
    // Variable input: create a parameter and register it as a graph input.
    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
    inputs.push_back(parameter);
    auto mutable_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(mutable_inputs);
    mutable_inputs->push_back(parameter);
  }
  // create the single op cnode
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  // set abstract,which include inferred shapes and types
  cnode->set_abstract(op_run_info.abstract);
  // set execution order
  std::vector<CNodePtr> exe_order = {cnode};
  graph->set_execution_order(exe_order);
  // set output
  CreateOutputNode(cnode, graph);
  return graph;
}
// Recursively converts a BaseRef result into Python objects: a VectorRef
// becomes a py::tuple (nested VectorRefs become nested tuples), a TensorPtr
// is returned unchanged, and anything else raises.
BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use pyObjectRef
      if (utils::isa<tensor::TensorPtr>(output)) {
        // Leaf element: a tensor goes into the tuple directly.
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        // Nested list already converted by the recursive call: unwrap the
        // py::tuple it carries and place it as an element.
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
  930. KernelGraphPtr SessionBasic::NewKernelGraph() {
  931. auto graph = std::make_shared<KernelGraph>();
  932. graph->set_graph_id(graph_sum_);
  933. graphs_[graph_sum_++] = graph;
  934. return graph;
  935. }
  936. } // namespace session
  937. } // namespace mindspore