You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 45 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value_py.h"
  24. #include "kernel/common_utils.h"
  25. #include "operator/ops.h"
  26. #include "common/trans.h"
  27. #include "utils/context/ms_context.h"
  28. #include "utils/config_manager.h"
  29. #include "session/anf_runtime_algorithm.h"
  30. #include "kernel/oplib/oplib.h"
  31. #include "pre_activate/common/common_backend_optimization.h"
  32. #include "pre_activate/pass/const_input_to_attr_registry.h"
  33. #include "pre_activate/common/helper.h"
  34. #include "common/utils.h"
  35. #include "ir/dtype.h"
  36. #include "ir/anf.h"
  37. #include "ir/func_graph_cloner.h"
  38. namespace mindspore {
  39. namespace session {
  40. static std::shared_ptr<std::map<PyObject *, ParameterPtr>> python_paras_;
  41. void ClearPythonParasMap() { python_paras_ = nullptr; }
  42. namespace {
  43. const int kSummaryGetItem = 2;
  44. PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
  45. if (node == nullptr) {
  46. return nullptr;
  47. }
  48. auto parameter = node->cast<ParameterPtr>();
  49. if (parameter == nullptr || !parameter->has_default()) {
  50. return nullptr;
  51. }
  52. auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
  53. MS_EXCEPTION_IF_NULL(param_value);
  54. auto py_param = param_value->value();
  55. return py_param.ptr();
  56. }
// Materialize one output value for `node`'s `output_index`-th output.
// When no device address exists: a ValueNode returns its stored value and a
// Parameter returns the matching entry of `input_tensors`. Otherwise a host
// tensor is built from the inferred type/shape and either keeps the device
// address (PyNative mode / GPU target: lazy sync) or is synced to host now.
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
                        const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  // if node is a value node, no need sync addr from device to host
  if (!AnfAlgo::OutputAddrExist(node, output_index)) {
    if (node->isa<ValueNode>()) {
      auto value_node = node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      return value_node->value();
    }
    if (node->isa<Parameter>()) {
      // A Parameter without a device address maps back to the graph input tensor
      // at the same position in graph.inputs().
      for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
        if (input_idx >= input_tensors.size()) {
          MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
        }
        if (graph.inputs()[input_idx] == node) {
          return input_tensors[input_idx];
        }
      }
      MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
    }
  }
  // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  DeviceAddressPtr address;
  auto is_all_nop_node = opt::IsAllNopNode(&graph);
  if (is_all_nop_node) {
    // The graph does not remove the nop node.
    address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
  } else {
    // The graph removes the nop node.
    address = AnfAlgo::GetMutableOutputAddr(node, output_index, true);
  }
  MS_EXCEPTION_IF_NULL(address);
  auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  TypeId type_id = kNumberTypeFloat32;
  type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  std::vector<int> temp_shape;
  if (graph.IsInternalOutput(node)) {
    // Internal outputs keep the device address and a placeholder shape {1};
    // NOTE(review): presumably the consuming graph restores the real shape — confirm.
    temp_shape.emplace_back(1);
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    tensor->set_device_address(address);
    tensor->set_dirty(false);
    return tensor;
  }
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  // if in paynative mode,data only copyed to host when user want to print data
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
    // Lazy path: keep the device address on the tensor; data is synced on demand.
    tensor->set_device_address(address);
    tensor->set_dirty(false);
  } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
                                        LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
    // Best-effort: a failed sync is logged, not fatal; tensor is marked clean.
    MS_LOG(INFO) << "Output sync device to host error!!!";
    tensor->set_dirty(false);
  }
  return tensor;
}
  117. BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  118. const std::vector<tensor::TensorPtr> &input_tensors) {
  119. MS_EXCEPTION_IF_NULL(anf);
  120. MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
  121. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  122. MS_EXCEPTION_IF_NULL(item_with_index.first);
  123. MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  124. // special handle for maketuple
  125. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  126. auto cnode = item_with_index.first->cast<CNodePtr>();
  127. MS_EXCEPTION_IF_NULL(cnode);
  128. VectorRef ret;
  129. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  130. auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
  131. ret.push_back(out);
  132. }
  133. return ret;
  134. }
  135. // if is graph return nothing ,the function should return a null anylist
  136. size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  137. if (size == 0) {
  138. return VectorRef();
  139. }
  140. return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
  141. }
  142. BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  143. const std::vector<tensor::TensorPtr> &input_tensors) {
  144. MS_EXCEPTION_IF_NULL(anf);
  145. if (!AnfAlgo::IsRealKernel(anf)) {
  146. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
  147. }
  148. if (anf->isa<ValueNode>()) {
  149. return CreateOneTensor(anf, 0, graph, input_tensors);
  150. }
  151. VectorRef ret;
  152. if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
  153. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
  154. auto out = CreateOneTensor(anf, i, graph, input_tensors);
  155. ret.emplace_back(out);
  156. }
  157. }
  158. return ret;
  159. }
  160. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  161. MS_EXCEPTION_IF_NULL(anf);
  162. MS_EXCEPTION_IF_NULL(graph);
  163. auto value_node = anf->cast<ValueNodePtr>();
  164. MS_EXCEPTION_IF_NULL(value_node);
  165. auto value = value_node->value();
  166. MS_EXCEPTION_IF_NULL(value);
  167. if (value->isa<None>()) {
  168. return nullptr;
  169. }
  170. auto new_value_node = graph->NewValueNode(value_node);
  171. graph->FrontBackendlMapAdd(anf, new_value_node);
  172. graph->AddValueNodeToGraph(new_value_node);
  173. return new_value_node;
  174. }
  175. size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  176. MS_EXCEPTION_IF_NULL(graph);
  177. MS_LOG(INFO) << "Load kInputCtrlTensors";
  178. auto inputs_params = graph->input_ctrl_tensors();
  179. if (inputs_params == nullptr) {
  180. return 0;
  181. }
  182. if (inputs_params->empty()) {
  183. MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  184. }
  185. auto tensor = (*inputs_params)[0];
  186. MS_EXCEPTION_IF_NULL(tensor);
  187. auto *val = static_cast<int32_t *>(tensor->data_c());
  188. MS_EXCEPTION_IF_NULL(val);
  189. *val = 0;
  190. tensor->set_dirty(true);
  191. // set loop_count to zero
  192. MS_EXCEPTION_IF_NULL(inputs);
  193. inputs->push_back(tensor);
  194. return inputs_params->size();
  195. }
  196. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  197. MS_EXCEPTION_IF_NULL(graph);
  198. MS_EXCEPTION_IF_NULL(input_tensor);
  199. auto value_node = std::make_shared<ValueNode>(input_tensor);
  200. MS_EXCEPTION_IF_NULL(value_node);
  201. // construct abstract of value node
  202. auto type_of_tensor = input_tensor->Dtype();
  203. auto shape_of_tensor = input_tensor->shape();
  204. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  205. value_node->set_abstract(abstract);
  206. // add value node to graph
  207. auto input_value_node = graph->NewValueNode(value_node);
  208. graph->AddValueNodeToGraph(input_value_node);
  209. return input_value_node;
  210. }
  211. ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
  212. int tensor_mask) {
  213. MS_EXCEPTION_IF_NULL(graph);
  214. auto param = graph->NewParameter();
  215. MS_EXCEPTION_IF_NULL(param);
  216. if (tensor_mask == kParameterWeightTensorMask) {
  217. py::object obj;
  218. auto param_value_new = std::make_shared<ParamValuePy>(obj);
  219. param->set_default_param(param_value_new);
  220. }
  221. // set the kernel info of parameter
  222. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  223. MS_EXCEPTION_IF_NULL(input_tensor);
  224. if (input_tensor->device_address().get() == nullptr) {
  225. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  226. TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
  227. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  228. } else {
  229. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
  230. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  231. }
  232. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  233. // construct abstract of parameter
  234. auto type_of_tensor = input_tensor->Dtype();
  235. auto shape_of_tensor = input_tensor->shape();
  236. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  237. param->set_abstract(abstract);
  238. return param;
  239. }
  240. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  241. MS_LOG(INFO) << "Graph outputs:";
  242. const size_t max_deep = 10;
  243. if (recurse_level > max_deep) {
  244. MS_LOG(INFO) << "Recurse too deep";
  245. return;
  246. }
  247. std::string tab_str;
  248. for (size_t i = 0; i < recurse_level; i++) {
  249. tab_str = tab_str.append(" ");
  250. }
  251. if (any.is<AnyList>()) {
  252. (void)tab_str.append("{");
  253. MS_LOG(INFO) << tab_str;
  254. auto any_list = any.cast<AnyList>();
  255. for (auto &it : any_list) {
  256. DumpGraphOutput(it, recurse_level + 1);
  257. }
  258. (void)tab_str.append("}");
  259. MS_LOG(INFO) << tab_str;
  260. }
  261. (void)tab_str.append(any.ToString());
  262. MS_LOG(INFO) << tab_str;
  263. }
  264. bool ExistSummaryNode(const KernelGraph *graph) {
  265. MS_EXCEPTION_IF_NULL(graph);
  266. auto ret = graph->get_return();
  267. MS_EXCEPTION_IF_NULL(ret);
  268. auto all_nodes = DeepLinkedGraphSearch(ret);
  269. for (auto &n : all_nodes) {
  270. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  271. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  272. return true;
  273. }
  274. }
  275. return false;
  276. }
  277. } // namespace
  278. GraphId SessionBasic::graph_sum_ = 0;
  279. KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) {
  280. auto it = graphs_.find(graph_id);
  281. if (it == graphs_.end()) {
  282. MS_LOG(WARNING) << "Can't find graph " << graph_id;
  283. return nullptr;
  284. }
  285. return it->second;
  286. }
// If `out_node` is an output of a previously-built graph, bind that graph's
// internal-output kernel info (format/type) and device address onto
// `parameter`, so the data can be consumed in place rather than copied
// between graphs. Silently returns when no such linkage applies.
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  // Only forward kernel info when the producer is a CNode that is both an
  // internal output and the final output kernel of the previous graph.
  if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
      node_graph->IsFinalOutputKernel(ref_real_node)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    if (address == nullptr) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    // Mirror the producer's output format/type on the parameter's kernel info,
    // then share the producer's device address as the parameter's output 0.
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
    auto d_kernel_info = parameter->kernel_info();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
  }
}
// Create backend Parameters for every real output of `node` (a tuple-output
// real kernel yields one Parameter per tuple element). Each new Parameter is
// appended to the graph's inputs and valid-input flags, in output order.
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input,
                                                               KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> parameters;
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  if (!AnfAlgo::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  }
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  // Helper: build one Parameter carrying `abstract` and register it as a
  // graph input with the caller's valid_input flag.
  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
    auto parameter = graph->NewParameter();
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(abstract);
    auto new_parameter = graph->NewParameter(parameter);
    parameters.push_back(new_parameter);
    valid_inputs->push_back(valid_input);
    graph_inputs->push_back(new_parameter);
  };
  for (const auto &out_node : pre_graph_out) {
    MS_EXCEPTION_IF_NULL(out_node);
    auto abstract = out_node->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    // create multiple parameters if is a tuple output real kernel
    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_abstract);
      MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
        create_parameter((*tuple_abstract)[output_idx]);
      }
      continue;
    }
    // create single parameter if is a abstract real kernel
    create_parameter(out_node->abstract());
    // NOTE(review): only the single-output path is linked to a previous
    // graph's internal output — tuple elements are not; confirm intended.
    InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
  }
  return parameters;
}
// Convert a front-end Parameter into a backend Parameter for `graph`,
// reusing a cached backend Parameter when the same python default object was
// seen before. The resulting Parameter is appended to the graph's inputs and
// valid-input flags.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  // m_tensor is the raw python object of the default input (nullptr if none).
  auto m_tensor = GetParamDefaultInputTensor(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
  }
  auto iter = python_paras_->find(m_tensor);
  if (iter != python_paras_->end()) {
    new_parameter = iter->second;
  } else {
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // Only parameters with a real python default are cached; nullptr keys are
    // never inserted, so a nullptr lookup above can never match.
    if (m_tensor != nullptr) {
      (*python_paras_)[m_tensor] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
  404. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  405. MS_EXCEPTION_IF_NULL(anf);
  406. MS_EXCEPTION_IF_NULL(graph);
  407. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  408. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  409. if (parameters.empty()) {
  410. MS_LOG(EXCEPTION) << "No parameter exist!!";
  411. }
  412. if (parameters.size() == 1) {
  413. return parameters[0];
  414. }
  415. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  416. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  417. auto make_tuple = graph->NewCNode(make_tuple_input);
  418. MS_EXCEPTION_IF_NULL(make_tuple);
  419. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  420. return make_tuple;
  421. }
// Clone a front-end CNode into `graph`, translating each input into a backend
// node: reused backend nodes, new ValueNodes, new Parameters, or (for inputs
// produced by another graph) Parameters recorded in `other_graph_cnode`.
// Sets *from_other_graph when any input came from another graph.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // get primitive of old node
  std::vector<AnfNodePtr> cnode_inputs;
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  if (prim != nullptr) {
    // push attr to inputs[0] of new cnode
    cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  } else {
    // No primitive: input[0] is a FuncGraph (graph-call form); clone it.
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  }
  auto origin_inputs = cnode->inputs();
  // A Depend whose real input is a ValueNode can be short-circuited: its
  // attach input is replaced by the real input below.
  bool optimize_depend = false;
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
      origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
    optimize_depend = true;
  }
  // if has multiple depends,only select first depend as parameter
  for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
    auto anf = origin_inputs[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // Input already materialized as a parameter for another graph's output.
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // if input is a value node,
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        // A nullptr return means the value was None and is dropped.
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) {
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      // Record the mapping on this graph only if the front node is not
      // already owned by another graph.
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
      cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
      continue;
    } else if (anf->isa<AnfNode>()) {
      *from_other_graph = true;
      // the input node is a cnode from other graph
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Preserve the original node's debug trace on the new cnode.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Build the input list of a backend `call` CNode whose callee (front input[0])
// is a partial or switch CNode. For partial, the partial's own inputs are
// spliced (via their backend counterparts) into the call; for switch, the
// backend switch node is appended as the single callee argument.
static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  // create primitive of cnode:call(partial or switch)
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  if (cnode_input == nullptr) {
    MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
                      << ", but input[0] has not been created.";
  }
  // if the node is partial, insert the inputs of partial to the call
  if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
    // NOTE(review): the primitive check is on the backend node (cnode_input)
    // but the cast/inputs are taken from the front node (attr_input) — this
    // assumes both are partial CNodes with matching inputs; confirm.
    auto partial_node = attr_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(partial_node);
    auto partial_inputs = partial_node->inputs();
    std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
                   std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
                     MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
                     return graph->GetBackendAnfByFrontAnf(node);
                   });
    return cnode_inputs;
  } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
    cnode_inputs.emplace_back(cnode_input);
    return cnode_inputs;
  }
  MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
}
// Clone a front-end CNode into `graph` assuming all of its data inputs have
// already been created (unlike the 5-arg overload). Handles four callee
// forms: graph-kernel (cloned FuncGraph), call of a FuncGraph ValueNode,
// call of a partial/switch CNode, and a plain primitive.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (AnfAlgo::IsGraphKernel(cnode)) {
    // Graph kernel: clone the fused sub-graph as input[0].
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  } else if (IsValueNode<FuncGraph>(attr_input)) {
    // create primitive of cnode:call
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // create a ValueNode<KernelGraph> as input of cnode:call
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // Callee is itself a CNode: must be partial or switch.
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
  } else {
    // get primitive of old node
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->input(input_idx);
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (IsValueNode<None>(anf)) {
      // None inputs are dropped.
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Preserve the original node's debug trace on the new cnode.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  572. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  573. MS_EXCEPTION_IF_NULL(anf);
  574. MS_EXCEPTION_IF_NULL(graph);
  575. auto value_node = anf->cast<ValueNodePtr>();
  576. MS_EXCEPTION_IF_NULL(value_node);
  577. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  578. MS_EXCEPTION_IF_NULL(sub_func_graph);
  579. if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
  580. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  581. }
  582. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  583. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  584. new_value_node->set_abstract(value_node->abstract());
  585. // create new kernel_info of new value_node
  586. auto kernel_info = std::make_shared<device::KernelInfo>();
  587. kernel_info->SetFeatureMapFlag(false);
  588. new_value_node->set_kernel_info(kernel_info);
  589. // create kernel_build_info for new value node
  590. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  591. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  592. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  593. graph->FrontBackendlMapAdd(anf, new_value_node);
  594. return new_value_node;
  595. }
  596. ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  597. MS_EXCEPTION_IF_NULL(anf);
  598. MS_EXCEPTION_IF_NULL(graph);
  599. if (!anf->isa<Parameter>()) {
  600. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  601. }
  602. auto m_tensor = GetParamDefaultInputTensor(anf);
  603. ParameterPtr new_parameter = nullptr;
  604. if (python_paras_ == nullptr) {
  605. python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
  606. }
  607. auto iter = python_paras_->find(m_tensor);
  608. if (iter != python_paras_->end()) {
  609. new_parameter = iter->second;
  610. } else {
  611. TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
  612. new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  613. if (m_tensor != nullptr) {
  614. (*python_paras_)[m_tensor] = new_parameter;
  615. }
  616. TraceManager::EndTrace();
  617. }
  618. return new_parameter;
  619. }
// Build a backend KernelGraph from a flat list of front-end cnodes `lst`, wiring
// `outputs` into a trailing MakeTuple output, then schedule execution order and run
// the common backend optimization pass. Every node in `lst` must be a CNode.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  // Cache of cnodes that were already cloned from some other graph (shared across
  // all CreateNewCNode calls for this graph).
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    bool from_other_graph = false;
    // only first depend from other graph can create
    // (subsequent Depend nodes that came from another graph get valid_input=false).
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  // Attach a fresh manager so value nodes / users can be queried during optimization.
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
  666. void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
  667. MS_EXCEPTION_IF_NULL(node);
  668. MS_EXCEPTION_IF_NULL(graph);
  669. auto cnode = node->cast<CNodePtr>();
  670. MS_EXCEPTION_IF_NULL(cnode);
  671. // create a new cnode object
  672. auto new_cnode = CreateNewCNode(cnode, graph.get());
  673. MS_EXCEPTION_IF_NULL(new_cnode);
  674. new_cnode->set_abstract(cnode->abstract());
  675. new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
  676. new_cnode->set_scope(cnode->scope());
  677. graph->FrontBackendlMapAdd(node, new_cnode);
  678. if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
  679. graph->set_return(new_cnode);
  680. }
  681. }
// Recursively transform the front-end `func_graph` (and every child FuncGraph it
// references via ValueNode<FuncGraph>) into KernelGraphs. Each finished graph is
// appended to `all_out_graph`; front_backend_graph_map_ prevents re-transforming a
// graph and detects back-edges (jumps to an already-built graph).
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                                std::vector<KernelGraphPtr> *all_out_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // Register before recursing so child graphs that reference us see the mapping.
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  bool is_trace_back = false;
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(node, new_parameter);
      continue;
    } else if (node->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(node)) {
        // if input is a common value node,
        (void)CreateNewValueNode(node, graph.get());
      } else {
        // if input is a ValueNode<FuncGraph>
        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
          // Already (being) transformed: this is a jump back to an existing graph.
          is_trace_back = true;
        } else {
          (void)ConstructKernelGraph(child_graph, all_out_graph);
        }
        (void)CreateValueNodeKernelGraph(node, graph.get());
      }
      continue;
    } else {
      CreateCNodeKernelGraph(node, graph);
    }
  }
  // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
  graph->set_output_null(is_trace_back);
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  all_out_graph->push_back(graph);
  return graph;
}
  731. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  732. MS_EXCEPTION_IF_NULL(graph);
  733. auto graph_inputs = graph->MutableInputs();
  734. MS_EXCEPTION_IF_NULL(graph_inputs);
  735. graph_inputs->clear();
  736. for (auto &parameter : parameters) {
  737. MS_EXCEPTION_IF_NULL(parameter);
  738. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  739. if (backend_parameter == nullptr) {
  740. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  741. auto new_parameter = CreateNewParameter(parameter, graph);
  742. graph_inputs->push_back(new_parameter);
  743. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  744. continue;
  745. }
  746. MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
  747. graph_inputs->push_back(backend_parameter);
  748. }
  749. }
// run graph steps
// Copy host input tensors into the device memory backing the graph's Parameter
// nodes. Control tensors (if any) are appended to the input list first. A tensor is
// synced to device only when needed: in pynative-infer mode when it has no/different
// device address, otherwise when it is dirty or bound to a different device address.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    // Appends the control tensors to `inputs` and returns how many were added
    // (plus one, judging by the size check below — TODO confirm against LoadCtrlInputTensor).
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already have an allocated output address can receive data.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty()) {
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // Tensor lives on a different device address: pull its data to host first
          // so the host copy pushed below is up to date.
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
          // Bind the tensor to the parameter's device address so later runs can skip the copy.
          tensor->set_device_address(device_address);
        }
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c())) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    tensor->set_dirty(false);
  }
}
// Materialize the graph's output tensors into `outputs`. When the graph has child
// graphs, the last child graph's outputs are the root graph's outputs, so recurse
// into it. Tuple-typed real kernels produce a nested tuple; everything else a tensor.
void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                                 const std::vector<tensor::TensorPtr> &input_tensors) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  if (!kernel_graph->child_graph_order().empty()) {
    // use the last child graph output as the root graph output
    UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
    return;
  }
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
    if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
      outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
      continue;
    }
    outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  }
}
  822. void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  823. MS_EXCEPTION_IF_NULL(callback);
  824. summary_callback_ = callback;
  825. }
  826. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
// Scan the graph for summary primitives (Scalar/Tensor/Image/Histogram) and record,
// per summary node name, the real kernel (and output index) that feeds it. The
// result is stored back on the graph via set_summary_nodes.
void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    return;
  }
  auto summary = graph->summary_nodes();
  auto apply_list = TopoSort(graph->get_return());
  for (auto &n : apply_list) {
    MS_EXCEPTION_IF_NULL(n);
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      auto cnode = n->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (cnode->inputs().size() <= kSummaryGetItem) {
        MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
      }
      // The summarized value is the input at index kSummaryGetItem.
      auto node = cnode->input(kSummaryGetItem);
      MS_EXCEPTION_IF_NULL(node);
      // Resolve through TupleGetItem etc. to the producing real kernel.
      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
      MS_EXCEPTION_IF_NULL(item_with_index.first);
      if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
        MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
      }
      summary[n->fullname_with_scope()] = item_with_index;
    }
  }
  graph->set_summary_nodes(summary);
  MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}
// Pull every summary node's output from device to host and hand the resulting
// {name -> tensor} map to the registered summary callback. No-op when no callback
// is registered or the graph contains no summary nodes.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = graph->summary_node_exist();
  if (!exist_summary) {
    return;
  }
  GetSummaryNodes(graph);
  auto summary_outputs = graph->summary_nodes();
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    // Skip outputs whose device memory was never allocated.
    if (!address->GetPtr()) {
      continue;
    }
    // NOTE(review): on sync failure only an error is logged and the (stale) tensor is
    // still published to the callback below — confirm this best-effort is intended.
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c())) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  893. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  894. MS_EXCEPTION_IF_NULL(graph);
  895. std::vector<AnfNodePtr> output_args;
  896. for (const auto &output : outputs) {
  897. MS_EXCEPTION_IF_NULL(output);
  898. MS_LOG(INFO) << "Output:" << output->DebugString();
  899. }
  900. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  901. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  902. if (backend_anf != nullptr) {
  903. auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
  904. auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
  905. MS_EXCEPTION_IF_NULL(out);
  906. auto out_func_graph = out->func_graph();
  907. MS_EXCEPTION_IF_NULL(out_func_graph);
  908. auto out_func_graph_manager = out_func_graph->manager();
  909. if (out_func_graph_manager == nullptr) {
  910. return backend_anf;
  911. }
  912. auto node_users = out_func_graph_manager->node_users();
  913. auto users = node_users[out];
  914. bool internal_output = true;
  915. std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
  916. for (auto user : users) {
  917. if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
  918. internal_output = false;
  919. break;
  920. }
  921. }
  922. if (internal_output) {
  923. MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
  924. graph->AddInternalOutput(out, backend_real_kernel.first);
  925. }
  926. return backend_anf;
  927. }
  928. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  929. };
  930. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  931. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  932. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  933. return graph->NewCNode(output_args);
  934. }
  935. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  936. MS_LOG(INFO) << "Start!";
  937. std::vector<AnfNodePtr> make_tuple_inputs;
  938. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  939. MS_EXCEPTION_IF_NULL(graph);
  940. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  941. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  942. auto idx = NewValueNode(SizeToInt(output_index));
  943. MS_EXCEPTION_IF_NULL(idx);
  944. auto imm = std::make_shared<Int32Imm>(output_index);
  945. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  946. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  947. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  948. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  949. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  950. make_tuple_inputs.push_back(getitem);
  951. }
  952. } else {
  953. make_tuple_inputs.push_back(cnode);
  954. }
  955. // create output
  956. auto g_output = graph->NewCNode(make_tuple_inputs);
  957. graph->set_output(g_output);
  958. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  959. MS_EXCEPTION_IF_NULL(context_);
  960. FuncGraphManagerPtr manager = context_->manager();
  961. if (manager != nullptr) {
  962. manager->AddFuncGraph(graph);
  963. graph->set_manager(manager);
  964. }
  965. MS_LOG(INFO) << "Finish!";
  966. }
// Build a one-op KernelGraph for pynative execution: the primitive plus one
// ValueNode or Parameter per input tensor (chosen by tensors_mask), a single cnode
// carrying the inferred abstract, and a MakeTuple output wrapping that cnode.
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                                  const std::vector<tensor::TensorPtr> &input_tensors,
                                                                  const std::vector<int> &tensors_mask) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  PrimitivePtr op_prim = op_run_info.py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  inputs.push_back(std::make_shared<ValueNode>(op_prim));
  // set input parameter
  MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    // Masked inputs become constant ValueNodes; the rest become graph Parameters.
    if (tensors_mask[i] == kValueNodeTensorMask) {
      auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
      inputs.push_back(value_node);
      continue;
    }
    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
    inputs.push_back(parameter);
    auto mutable_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(mutable_inputs);
    mutable_inputs->push_back(parameter);
  }
  // create the single cnode of the graph
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  // set abstract,which include inferred shapes and types
  cnode->set_abstract(op_run_info.abstract);
  // set execution order
  std::vector<CNodePtr> exe_order = {cnode};
  graph->set_execution_order(exe_order);
  // set output
  CreateOutputNode(cnode, graph);
  return graph;
}
// Recursively convert a BaseRef result tree into python objects: a VectorRef becomes
// a py::tuple (wrapped in a PyObjectRef), a TensorPtr is passed through unchanged;
// anything else is an error.
BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (utils::isa<VectorRef>(base_ref)) {
    auto ref_list = utils::cast<VectorRef>(base_ref);
    py::tuple output_tensors(ref_list.size());
    for (size_t i = 0; i < ref_list.size(); ++i) {
      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use pyObjectRef
      if (utils::isa<tensor::TensorPtr>(output)) {
        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
        MS_EXCEPTION_IF_NULL(tensor_ptr);
        output_tensors[i] = tensor_ptr;
      } else if (utils::isa<PyObjectRef>(output)) {
        // A nested VectorRef already converted to a py::tuple by the recursive call.
        py::object obj = utils::cast<PyObjectRef>(output).object_;
        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
        output_tensors[i] = tensor_tuple;
      } else {
        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
      }
    }
    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
    return base_ref;
  } else {
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
}
  1031. KernelGraphPtr SessionBasic::NewKernelGraph() {
  1032. auto graph = std::make_shared<KernelGraph>();
  1033. graph->set_graph_id(graph_sum_);
  1034. graphs_[graph_sum_++] = graph;
  1035. return graph;
  1036. }
  1037. } // namespace session
  1038. } // namespace mindspore