You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

session_basic.cc 45 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value.h"
  24. #include "kernel/common_utils.h"
  25. #include "operator/ops.h"
  26. #include "common/trans.h"
  27. #include "utils/context/ms_context.h"
  28. #include "utils/config_manager.h"
  29. #include "session/anf_runtime_algorithm.h"
  30. #include "kernel/oplib/oplib.h"
  31. #include "pre_activate/common/common_backend_optimization.h"
  32. #include "pre_activate/pass/const_input_to_attr_registry.h"
  33. #include "pre_activate/common/helper.h"
  34. #include "common/utils.h"
  35. #include "ir/dtype.h"
  36. #include "ir/anf.h"
  37. #include "ir/func_graph_cloner.h"
  38. namespace mindspore {
  39. namespace session {
// File-global cache mapping a front-end parameter's default ParamValue to the
// backend Parameter node created for it, so the same python parameter is
// reused instead of re-created across graph constructions. Mutable shared
// state; cleared via ClearPythonParasMap().
static std::shared_ptr<std::map<ParamValuePtr, ParameterPtr>> python_paras_;
// Drops the python-parameter cache (e.g. when starting a fresh compilation).
void ClearPythonParasMap() { python_paras_ = nullptr; }
  42. namespace {
// Input index of the data operand in a summary-related getitem node
// (presumably [primitive, tag, data]) — used later in this file; confirm at call sites.
const int kSummaryGetItem = 2;
  44. ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
  45. if (node == nullptr) {
  46. return nullptr;
  47. }
  48. auto parameter = node->cast<ParameterPtr>();
  49. if (parameter == nullptr || !parameter->has_default()) {
  50. return nullptr;
  51. }
  52. return parameter->default_param();
  53. }
// Builds the host-side result for output `output_index` of `node` in `graph`.
// - A ValueNode without a device address returns its value directly.
// - A Parameter without a device address must be a graph input; the matching
//   entry of `input_tensors` is returned.
// - Otherwise a new tensor is created that either aliases the device address
//   (internal outputs, PyNative mode, GPU target) or is synced device->host.
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
                        const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  // If node is a value node, there is no need to sync the addr from device to host.
  if (!AnfAlgo::OutputAddrExist(node, output_index)) {
    if (node->isa<ValueNode>()) {
      auto value_node = node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      return value_node->value();
    }
    if (node->isa<Parameter>()) {
      // A parameter with no device address must be one of the graph inputs:
      // hand back the host tensor that was fed for it.
      for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
        if (input_idx >= input_tensors.size()) {
          MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
        }
        if (graph.inputs()[input_idx] == node) {
          return input_tensors[input_idx];
        }
      }
      MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
    }
  }
  // If the process reaches here, the node is a real node (Parameter, or an executable CNode).
  auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  MS_EXCEPTION_IF_NULL(address);
  auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  TypeId type_id = kNumberTypeFloat32;
  type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  std::vector<int> temp_shape;
  if (graph.IsInternalOutput(node)) {
    // Internal outputs stay on device: expose a shape-[1] tensor that merely
    // carries the device address. NOTE(review): the [1] shape looks
    // intentional — confirm consumers only read the device address.
    temp_shape.emplace_back(1);
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    tensor->set_device_address(address);
    tensor->set_dirty(false);
    return tensor;
  }
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  // In PyNative mode (and on GPU) data is only copied to host when the user
  // actually wants to read it, so keep the device address instead of syncing now.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
    tensor->set_device_address(address);
    tensor->set_dirty(false);
  } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
                                        LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
    MS_LOG(INFO) << "Output sync device to host error!!!";
    tensor->set_dirty(false);
  }
  return tensor;
}
// Recursively builds the host-side result for graph output `anf`.
// MakeTuple outputs become a VectorRef of their elements, zero-output nodes
// become an empty VectorRef, and everything else goes through CreateOneTensor
// on the real node/index resolved behind `anf`.
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                             const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
  // Resolve `anf` to the real producing node and its output index.
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // Special handling for MakeTuple: collect every element recursively
  // (input 0 is the primitive, so start from 1).
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
      ret.push_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty any-list.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
}
// Builds a tuple-shaped result for real kernel `anf`: one tensor per output,
// collected into a VectorRef. A ValueNode yields a single tensor; a MakeTuple
// CNode yields an empty VectorRef here. Raises if `anf` is not a real kernel.
BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                            const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!AnfAlgo::IsRealKernel(anf)) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
  }
  if (anf->isa<ValueNode>()) {
    return CreateOneTensor(anf, 0, graph, input_tensors);
  }
  VectorRef ret;
  if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
    // One entry per output of the kernel.
    for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
      auto out = CreateOneTensor(anf, i, graph, input_tensors);
      ret.emplace_back(out);
    }
  }
  return ret;
}
  149. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  150. MS_EXCEPTION_IF_NULL(anf);
  151. MS_EXCEPTION_IF_NULL(graph);
  152. auto value_node = anf->cast<ValueNodePtr>();
  153. MS_EXCEPTION_IF_NULL(value_node);
  154. auto value = value_node->value();
  155. MS_EXCEPTION_IF_NULL(value);
  156. if (value->isa<None>()) {
  157. return nullptr;
  158. }
  159. auto new_value_node = graph->NewValueNode(value_node);
  160. graph->FrontBackendlMapAdd(anf, new_value_node);
  161. graph->AddValueNodeToGraph(new_value_node);
  162. return new_value_node;
  163. }
// Resets the graph's loop-count control tensor to zero and appends it to
// `inputs`. Returns the number of control tensors the graph declares, or 0
// when the graph has none. Raises on an empty (but non-null) control list.
size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Load kInputCtrlTensors";
  auto inputs_params = graph->input_ctrl_tensors();
  if (inputs_params == nullptr) {
    return 0;
  }
  if (inputs_params->empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  }
  // The first control tensor holds the int32 loop counter.
  auto tensor = (*inputs_params)[0];
  MS_EXCEPTION_IF_NULL(tensor);
  auto *val = static_cast<int32_t *>(tensor->data_c());
  MS_EXCEPTION_IF_NULL(val);
  // Set loop_count to zero.
  *val = 0;
  tensor->set_dirty(true);  // host copy changed; must be re-synced to device
  MS_EXCEPTION_IF_NULL(inputs);
  inputs->push_back(tensor);
  // NOTE(review): only the first control tensor is pushed to `inputs`, yet the
  // full count is returned — confirm callers rely on this asymmetry.
  return inputs_params->size();
}
  185. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  186. MS_EXCEPTION_IF_NULL(graph);
  187. MS_EXCEPTION_IF_NULL(input_tensor);
  188. auto value_node = std::make_shared<ValueNode>(input_tensor);
  189. MS_EXCEPTION_IF_NULL(value_node);
  190. // construct abstract of value node
  191. auto type_of_tensor = input_tensor->Dtype();
  192. auto shape_of_tensor = input_tensor->shape();
  193. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  194. value_node->set_abstract(abstract);
  195. // add value node to graph
  196. auto input_value_node = graph->NewValueNode(value_node);
  197. graph->AddValueNodeToGraph(input_value_node);
  198. return input_value_node;
  199. }
  200. ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
  201. int tensor_mask) {
  202. MS_EXCEPTION_IF_NULL(graph);
  203. auto param = graph->NewParameter();
  204. MS_EXCEPTION_IF_NULL(param);
  205. if (tensor_mask == kParameterWeightTensorMask) {
  206. auto param_value_new = std::make_shared<ParamValue>();
  207. param->set_default_param(param_value_new);
  208. }
  209. // set the kernel info of parameter
  210. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  211. MS_EXCEPTION_IF_NULL(input_tensor);
  212. if (input_tensor->device_address().get() == nullptr) {
  213. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  214. TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
  215. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  216. } else {
  217. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
  218. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  219. }
  220. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  221. // construct abstract of parameter
  222. auto type_of_tensor = input_tensor->Dtype();
  223. auto shape_of_tensor = input_tensor->shape();
  224. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  225. param->set_abstract(abstract);
  226. return param;
  227. }
  228. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  229. MS_LOG(INFO) << "Graph outputs:";
  230. const size_t max_deep = 10;
  231. if (recurse_level > max_deep) {
  232. MS_LOG(INFO) << "Recurse too deep";
  233. return;
  234. }
  235. std::string tab_str;
  236. for (size_t i = 0; i < recurse_level; i++) {
  237. tab_str = tab_str.append(" ");
  238. }
  239. if (any.is<AnyList>()) {
  240. (void)tab_str.append("{");
  241. MS_LOG(INFO) << tab_str;
  242. auto any_list = any.cast<AnyList>();
  243. for (auto &it : any_list) {
  244. DumpGraphOutput(it, recurse_level + 1);
  245. }
  246. (void)tab_str.append("}");
  247. MS_LOG(INFO) << tab_str;
  248. }
  249. (void)tab_str.append(any.ToString());
  250. MS_LOG(INFO) << tab_str;
  251. }
  252. bool ExistSummaryNode(const KernelGraph *graph) {
  253. MS_EXCEPTION_IF_NULL(graph);
  254. auto ret = graph->get_return();
  255. MS_EXCEPTION_IF_NULL(ret);
  256. auto all_nodes = DeepLinkedGraphSearch(ret);
  257. for (auto &n : all_nodes) {
  258. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  259. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  260. return true;
  261. }
  262. }
  263. return false;
  264. }
  265. } // namespace
// Class-wide counter of constructed kernel graphs; NewKernelGraph presumably
// uses it as the next graph id — shared across all SessionBasic instances.
GraphId SessionBasic::graph_sum_ = 0;
  267. KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) {
  268. auto it = graphs_.find(graph_id);
  269. if (it == graphs_.end()) {
  270. MS_LOG(WARNING) << "Can't find graph " << graph_id;
  271. return nullptr;
  272. }
  273. return it->second;
  274. }
// When `out_node` is an internal output of a previously-built graph, wire the
// new graph's `parameter` to it: copy the producer's output format/device
// type into the parameter's kernel build info and share the producer's output
// device address, so the parameter reads the producing kernel's buffer.
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;  // out_node does not belong to any backend graph we know
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  // Resolve to the real kernel and output index behind the internal output.
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  // Only alias when the producer is a CNode that really is the graph's final
  // internal output kernel.
  if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
      node_graph->IsFinalOutputKernel(ref_real_node)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    if (address == nullptr) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    // Mirror the producer's output format and device data type.
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
    auto d_kernel_info = parameter->kernel_info();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    // Share the producer's output buffer with the parameter.
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
  }
}
// Creates one backend Parameter per output of `node` (several for a tuple
// output), appending each to the graph's inputs/valid-input lists. Returns
// the new parameters in output order.
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input,
                                                               KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> parameters;
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, its input0 is a cnode too, so it doesn't have a
  // primitive; expand it to all of its real outputs.
  if (!AnfAlgo::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  }
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  // Helper: create a parameter with the given abstract and record it in
  // `parameters` plus the graph's input/valid-input lists (mutated by capture).
  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
    auto parameter = graph->NewParameter();
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(abstract);
    auto new_parameter = graph->NewParameter(parameter);
    parameters.push_back(new_parameter);
    valid_inputs->push_back(valid_input);
    graph_inputs->push_back(new_parameter);
  };
  for (const auto &out_node : pre_graph_out) {
    MS_EXCEPTION_IF_NULL(out_node);
    auto abstract = out_node->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    // Create multiple parameters if this is a tuple-output real kernel.
    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_abstract);
      MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
        create_parameter((*tuple_abstract)[output_idx]);
      }
      continue;
    }
    // Create a single parameter for a plain (non-tuple) output, then try to
    // wire it to a previous graph's internal output.
    create_parameter(out_node->abstract());
    InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
  }
  return parameters;
}
// Mirrors front-end Parameter `anf` into `graph`, reusing a previously-created
// backend parameter when the same python default value has been seen before
// (cached in the file-global python_paras_ map). Also appends the parameter to
// the graph's inputs and `valid_input` to its valid-input list.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_value = GetParamDefaultValue(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // If this python parameter already has a backend parameter, reuse it.
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
  }
  auto iter = python_paras_->find(param_value);
  if (iter != python_paras_->end()) {
    new_parameter = iter->second;
  } else {
    // Build a fresh backend parameter under a debug trace so it keeps the
    // front node's debug info; cache it only when a default value exists.
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    if (param_value != nullptr) {
      (*python_paras_)[param_value] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
  392. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  393. MS_EXCEPTION_IF_NULL(anf);
  394. MS_EXCEPTION_IF_NULL(graph);
  395. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  396. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  397. if (parameters.empty()) {
  398. MS_LOG(EXCEPTION) << "No parameter exist!!";
  399. }
  400. if (parameters.size() == 1) {
  401. return parameters[0];
  402. }
  403. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  404. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  405. auto make_tuple = graph->NewCNode(make_tuple_input);
  406. MS_EXCEPTION_IF_NULL(make_tuple);
  407. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  408. return make_tuple;
  409. }
// Clones front-end CNode `cnode` into `graph`, translating each input:
// already-mapped nodes are reused; value nodes, parameters, and cnodes from
// other graphs are converted as needed. Sets *from_other_graph when any input
// came from a different graph; `other_graph_cnode` caches those conversions.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // Get the primitive of the old node; input 0 of the new cnode is either a
  // copy of that primitive or a clone of the called func graph.
  std::vector<AnfNodePtr> cnode_inputs;
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  if (prim != nullptr) {
    // Push attr to inputs[0] of the new cnode.
    cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  } else {
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  }
  auto origin_inputs = cnode->inputs();
  // A Depend whose real input is a ValueNode can be short-circuited: its
  // attached node replaces the value, see the optimize_depend branch below.
  bool optimize_depend = false;
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
      origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
    optimize_depend = true;
  }
  // If there are multiple depends, only the first depend is selected as parameter.
  for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
    auto anf = origin_inputs[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // Case 1: the input has been created in this graph before — reuse it.
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // Case 2: already converted from another graph — reuse the cached node.
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // Case 3: a plain value node (None values yield nullptr and are skipped).
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) {
      // Case 4: a single-output parameter; map it front->backend only when it
      // does not already belong to another graph.
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
      // Case 5: short-circuit the Depend — use its real (value) input instead
      // of the attached node.
      cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
      continue;
    } else if (anf->isa<AnfNode>()) {
      // Case 6: the input node is a cnode from another graph; convert it to a
      // parameter of this graph and remember the conversion.
      *from_other_graph = true;
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Build the new cnode under a debug trace so it keeps the original's debug info.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  481. static std::vector<AnfNodePtr> CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  482. MS_EXCEPTION_IF_NULL(cnode);
  483. MS_EXCEPTION_IF_NULL(graph);
  484. // create primitive of cnode:call(partial or switch)
  485. std::vector<AnfNodePtr> cnode_inputs = {
  486. graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  487. auto attr_input = cnode->input(kAnfPrimitiveIndex);
  488. MS_EXCEPTION_IF_NULL(attr_input);
  489. auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  490. if (cnode_input == nullptr) {
  491. MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
  492. << ", but input[0] has not been created.";
  493. }
  494. // if the node is partial, insert the inputs of partial to the call
  495. if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
  496. auto partial_node = attr_input->cast<CNodePtr>();
  497. MS_EXCEPTION_IF_NULL(partial_node);
  498. auto partial_inputs = partial_node->inputs();
  499. std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
  500. std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
  501. MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
  502. return graph->GetBackendAnfByFrontAnf(node);
  503. });
  504. return cnode_inputs;
  505. } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
  506. cnode_inputs.emplace_back(cnode_input);
  507. return cnode_inputs;
  508. }
  509. MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
  510. }
// Clones front-end CNode `cnode` into `graph` for the multi-graph path.
// Input[0] determines the shape of the new node: a graph kernel clones its
// func graph; a FuncGraph value becomes a `call` to a ValueNode<KernelGraph>;
// a CNode input[0] becomes a call to partial/switch; otherwise the node's
// primitive is copied. Data inputs must already have backend counterparts
// (None values are skipped).
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (AnfAlgo::IsGraphKernel(cnode)) {
    // Graph kernel: clone the fused func graph as input[0].
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  } else if (IsValueNode<FuncGraph>(attr_input)) {
    // Create a `call` primitive as input[0] of the new cnode.
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // Create (or reuse) a ValueNode<KernelGraph> as the call's callee input.
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // Input[0] is itself a cnode: build a call to a partial or switch.
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
  } else {
    // Plain primitive cnode: copy the primitive of the old node as input[0].
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->input(input_idx);
    MS_EXCEPTION_IF_NULL(anf);
    // Each data input must already have a backend counterpart.
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (IsValueNode<None>(anf)) {
      continue;  // None inputs are dropped
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Build the new cnode under a debug trace so it keeps the original's debug info.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  560. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  561. MS_EXCEPTION_IF_NULL(anf);
  562. MS_EXCEPTION_IF_NULL(graph);
  563. auto value_node = anf->cast<ValueNodePtr>();
  564. MS_EXCEPTION_IF_NULL(value_node);
  565. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  566. MS_EXCEPTION_IF_NULL(sub_func_graph);
  567. if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
  568. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  569. }
  570. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  571. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  572. new_value_node->set_abstract(value_node->abstract());
  573. // create new kernel_info of new value_node
  574. auto kernel_info = std::make_shared<device::KernelInfo>();
  575. kernel_info->SetFeatureMapFlag(false);
  576. new_value_node->set_kernel_info(kernel_info);
  577. // create kernel_build_info for new value node
  578. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  579. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  580. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  581. graph->FrontBackendlMapAdd(anf, new_value_node);
  582. return new_value_node;
  583. }
// Mirrors front-end Parameter `anf` into `graph`, reusing the cached backend
// parameter for the same python default value when one exists. Unlike
// CreateNewParameterFromParameter (which shares this cache logic), this does
// NOT register the parameter in the graph's input/valid-input lists.
ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  auto param_value = GetParamDefaultValue(anf);
  ParameterPtr new_parameter = nullptr;
  // Lazily create the file-global python-parameter cache.
  if (python_paras_ == nullptr) {
    python_paras_ = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
  }
  auto iter = python_paras_->find(param_value);
  if (iter != python_paras_->end()) {
    new_parameter = iter->second;
  } else {
    // Build a fresh backend parameter under a debug trace so it keeps the
    // front node's debug info; cache it only when a default value exists.
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    if (param_value != nullptr) {
      (*python_paras_)[param_value] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  return new_parameter;
}
// Translates the front-end node list `lst` into a fresh KernelGraph: clones
// each CNode into the graph, wires `outputs` through a trailing MakeTuple,
// attaches a manager, fixes the execution order, flags summary nodes, and
// runs the common backend optimizations. Returns the new graph.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // Create a new cnode object.
    bool from_other_graph = false;
    // Only the first Depend that comes from another graph counts as a valid
    // input; subsequent ones are marked invalid.
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // Record the mapping between the front-end anf node and the new backend node.
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // Add a make_tuple at the end of the graph as its output.
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  // Attach a manager so graph-wide node bookkeeping works.
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
  654. void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
  655. MS_EXCEPTION_IF_NULL(node);
  656. MS_EXCEPTION_IF_NULL(graph);
  657. auto cnode = node->cast<CNodePtr>();
  658. MS_EXCEPTION_IF_NULL(cnode);
  659. // create a new cnode object
  660. auto new_cnode = CreateNewCNode(cnode, graph.get());
  661. MS_EXCEPTION_IF_NULL(new_cnode);
  662. new_cnode->set_abstract(cnode->abstract());
  663. new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
  664. new_cnode->set_scope(cnode->scope());
  665. graph->FrontBackendlMapAdd(node, new_cnode);
  666. if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
  667. graph->set_return(new_cnode);
  668. }
  669. }
// Recursively build backend KernelGraphs from a front-end FuncGraph. Child
// FuncGraphs referenced via ValueNode<FuncGraph> are constructed first; if a
// child is already in front_backend_graph_map_ the reference is a jump back
// (back-edge). Every constructed graph is appended to `all_out_graph`, with
// the graph for `func_graph` itself pushed last.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                                std::vector<KernelGraphPtr> *all_out_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // register before recursing so self/mutual references are detected as back-edges
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  bool is_trace_back = false;
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (node->isa<Parameter>()) {
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(node, new_parameter);
      continue;
    } else if (node->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(node)) {
        // if input is a common value node,
        (void)CreateNewValueNode(node, graph.get());
      } else {
        // if input is a ValueNode<FuncGraph>
        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
          // child already constructed: this reference is a jump back
          is_trace_back = true;
        } else {
          (void)ConstructKernelGraph(child_graph, all_out_graph);
        }
        (void)CreateValueNodeKernelGraph(node, graph.get());
      }
      continue;
    } else {
      CreateCNodeKernelGraph(node, graph);
    }
  }
  // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
  graph->set_output_null(is_trace_back);
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  all_out_graph->push_back(graph);
  return graph;
}
  719. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  720. MS_EXCEPTION_IF_NULL(graph);
  721. auto graph_inputs = graph->MutableInputs();
  722. MS_EXCEPTION_IF_NULL(graph_inputs);
  723. graph_inputs->clear();
  724. for (auto &parameter : parameters) {
  725. MS_EXCEPTION_IF_NULL(parameter);
  726. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  727. if (backend_parameter == nullptr) {
  728. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  729. auto new_parameter = CreateNewParameter(parameter, graph);
  730. graph_inputs->push_back(new_parameter);
  731. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  732. continue;
  733. }
  734. MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
  735. graph_inputs->push_back(backend_parameter);
  736. }
  737. }
  738. // run graph steps
// Copy host input tensors into the device memory bound to the graph's input
// parameters. An input is re-synced only when it is dirty or bound to a
// different device address; clean, already-bound inputs are skipped.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    // control tensors are appended to `inputs`; returns the control-tensor count
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  // input_ctrl_size starts at 1, so without control tensors this reduces to
  // inputs.size() != input_nodes.size()
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // only parameters that already have device memory allocated can be loaded
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        // pynative infer: sync whenever the tensor has no (or a different) device address
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty()) {
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // tensor is bound to another device address: refresh its host data first
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
          // bind the tensor to the parameter's device address so later runs can skip the copy
          tensor->set_device_address(device_address);
        }
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c())) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    tensor->set_dirty(false);
  }
}
  790. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  791. const std::vector<tensor::TensorPtr> &input_tensors) const {
  792. MS_EXCEPTION_IF_NULL(kernel_graph);
  793. MS_EXCEPTION_IF_NULL(outputs);
  794. if (!kernel_graph->child_graph_order().empty()) {
  795. // use the last child graph output as the root graph output
  796. UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
  797. return;
  798. }
  799. auto anf_outputs = kernel_graph->outputs();
  800. for (auto &item : anf_outputs) {
  801. MS_EXCEPTION_IF_NULL(item);
  802. MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
  803. if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
  804. outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
  805. continue;
  806. }
  807. outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  808. }
  809. }
// Register the callback that Summary() invokes with the collected summary tensors.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
  814. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
  815. void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
  816. MS_LOG(DEBUG) << "Update summary Start";
  817. MS_EXCEPTION_IF_NULL(graph);
  818. if (!graph->summary_node_exist()) {
  819. return;
  820. }
  821. auto summary = graph->summary_nodes();
  822. auto apply_list = TopoSort(graph->get_return());
  823. for (auto &n : apply_list) {
  824. MS_EXCEPTION_IF_NULL(n);
  825. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  826. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  827. auto cnode = n->cast<CNodePtr>();
  828. MS_EXCEPTION_IF_NULL(cnode);
  829. if (cnode->inputs().size() <= kSummaryGetItem) {
  830. MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
  831. }
  832. auto node = cnode->input(kSummaryGetItem);
  833. MS_EXCEPTION_IF_NULL(node);
  834. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
  835. MS_EXCEPTION_IF_NULL(item_with_index.first);
  836. if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
  837. MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
  838. }
  839. summary[n->fullname_with_scope()] = item_with_index;
  840. }
  841. }
  842. graph->set_summary_nodes(summary);
  843. MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
  844. }
// Pull the value of every summary node from device memory into host tensors
// and hand them to the registered summary callback. No-op when no callback
// is registered or the graph has no summary nodes.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = graph->summary_node_exist();
  if (!exist_summary) {
    return;
  }
  GetSummaryNodes(graph);
  auto summary_outputs = graph->summary_nodes();
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    // skip outputs whose device memory has not been allocated
    if (!address->GetPtr()) {
      continue;
    }
    // copy the kernel output from device into the host tensor; failure is
    // logged but deliberately non-fatal (best-effort summary collection)
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c())) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
// Map each front-end output to its backend node and pack them all into a
// MakeTuple cnode, which becomes the graph output. In graph mode, an output
// whose users are all real kernels on the same target as its producer is
// additionally registered as an internal output.
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> output_args;
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    MS_LOG(INFO) << "Output:" << output->DebugString();
  }
  auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
    auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
    if (backend_anf != nullptr) {
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      // pynative mode never registers internal outputs
      if (context_ptr->execution_mode() == kPynativeMode) {
        return backend_anf;
      }
      auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
      auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
      MS_EXCEPTION_IF_NULL(out);
      auto out_func_graph = out->func_graph();
      MS_EXCEPTION_IF_NULL(out_func_graph);
      auto out_func_graph_manager = out_func_graph->manager();
      if (out_func_graph_manager == nullptr) {
        return backend_anf;
      }
      auto node_users = out_func_graph_manager->node_users();
      auto users = node_users[out];
      // internal iff every user is a real kernel on the producer's target
      bool internal_output = true;
      std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
      for (auto user : users) {
        if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
          internal_output = false;
          break;
        }
      }
      if (internal_output) {
        MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
        graph->AddInternalOutput(out, backend_real_kernel.first);
      }
      return backend_anf;
    }
    MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  };
  output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
                       [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  return graph->NewCNode(output_args);
}
  928. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  929. MS_LOG(INFO) << "Start!";
  930. std::vector<AnfNodePtr> make_tuple_inputs;
  931. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  932. MS_EXCEPTION_IF_NULL(graph);
  933. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  934. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  935. auto idx = NewValueNode(SizeToInt(output_index));
  936. MS_EXCEPTION_IF_NULL(idx);
  937. auto imm = std::make_shared<Int32Imm>(output_index);
  938. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  939. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  940. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  941. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  942. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  943. make_tuple_inputs.push_back(getitem);
  944. }
  945. } else {
  946. make_tuple_inputs.push_back(cnode);
  947. }
  948. // create output
  949. auto g_output = graph->NewCNode(make_tuple_inputs);
  950. graph->set_output(g_output);
  951. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  952. MS_EXCEPTION_IF_NULL(context_);
  953. FuncGraphManagerPtr manager = context_->manager();
  954. if (manager != nullptr) {
  955. manager->AddFuncGraph(graph);
  956. graph->set_manager(manager);
  957. }
  958. MS_LOG(INFO) << "Finish!";
  959. }
  960. std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
  961. const std::vector<tensor::TensorPtr> &input_tensors,
  962. const std::vector<int> &tensors_mask) {
  963. auto graph = std::make_shared<KernelGraph>();
  964. std::vector<AnfNodePtr> inputs;
  965. // set input[0]
  966. PrimitivePtr op_prim = op_run_info.py_primitive;
  967. MS_EXCEPTION_IF_NULL(op_prim);
  968. inputs.push_back(std::make_shared<ValueNode>(op_prim));
  969. // set input parameter
  970. MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  971. if (input_tensors.size() != tensors_mask.size()) {
  972. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
  973. << tensors_mask.size();
  974. }
  975. for (size_t i = 0; i < input_tensors.size(); ++i) {
  976. if (tensors_mask[i] == kValueNodeTensorMask) {
  977. auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
  978. inputs.push_back(value_node);
  979. continue;
  980. }
  981. auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
  982. inputs.push_back(parameter);
  983. auto mutable_inputs = graph->MutableInputs();
  984. MS_EXCEPTION_IF_NULL(mutable_inputs);
  985. mutable_inputs->push_back(parameter);
  986. }
  987. // set execution order
  988. auto cnode = graph->NewCNode(inputs);
  989. MS_EXCEPTION_IF_NULL(cnode);
  990. // set abstract,which include inferred shapes and types
  991. cnode->set_abstract(op_run_info.abstract);
  992. // set execution order
  993. std::vector<CNodePtr> exe_order = {cnode};
  994. graph->set_execution_order(exe_order);
  995. // set output
  996. CreateOutputNode(cnode, graph);
  997. return graph;
  998. }
  999. BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  1000. if (utils::isa<VectorRef>(base_ref)) {
  1001. auto ref_list = utils::cast<VectorRef>(base_ref);
  1002. py::tuple output_tensors(ref_list.size());
  1003. for (size_t i = 0; i < ref_list.size(); ++i) {
  1004. auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef
  1005. if (utils::isa<tensor::TensorPtr>(output)) {
  1006. auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
  1007. MS_EXCEPTION_IF_NULL(tensor_ptr);
  1008. output_tensors[i] = tensor_ptr;
  1009. } else if (utils::isa<PyObjectRef>(output)) {
  1010. py::object obj = utils::cast<PyObjectRef>(output).object_;
  1011. py::tuple tensor_tuple = py::cast<py::tuple>(obj);
  1012. output_tensors[i] = tensor_tuple;
  1013. } else {
  1014. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  1015. }
  1016. }
  1017. return output_tensors; // turn tuple to py::object and store in PyObjectRef
  1018. } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
  1019. return base_ref;
  1020. } else {
  1021. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  1022. }
  1023. }
// Allocate a fresh KernelGraph, assign it the next graph id, and register it
// in the session's graph table.
KernelGraphPtr SessionBasic::NewKernelGraph() {
  auto graph = std::make_shared<KernelGraph>();
  graph->set_graph_id(graph_sum_);
  graphs_[graph_sum_++] = graph;
  return graph;
}
  1030. } // namespace session
  1031. } // namespace mindspore