You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 47 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "ir/param_value.h"
  24. #include "kernel/common_utils.h"
  25. #include "operator/ops.h"
  26. #include "common/trans.h"
  27. #include "utils/context/ms_context.h"
  28. #include "utils/config_manager.h"
  29. #include "session/anf_runtime_algorithm.h"
  30. #include "kernel/oplib/oplib.h"
  31. #include "pre_activate/common/common_backend_optimization.h"
  32. #include "pre_activate/pass/const_input_to_attr_registry.h"
  33. #include "pre_activate/common/helper.h"
  34. #include "common/utils.h"
  35. #include "ir/dtype.h"
  36. #include "ir/anf.h"
  37. #include "ir/func_graph_cloner.h"
  38. namespace mindspore {
  39. namespace session {
  40. static std::shared_ptr<std::map<ParamValuePtr, ParameterPtr>> python_paras;
  41. void ClearPythonParasMap() { python_paras = nullptr; }
  42. namespace {
  43. const int kSummaryGetItem = 2;
  44. ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
  45. if (node == nullptr) {
  46. return nullptr;
  47. }
  48. auto parameter = node->cast<ParameterPtr>();
  49. if (parameter == nullptr || !parameter->has_default()) {
  50. return nullptr;
  51. }
  52. return parameter->default_param();
  53. }
// Builds one host-visible output value for `node`:`output_index`.
// - A value node without a device address returns its stored value directly.
// - A parameter without a device address returns the caller-supplied input
//   tensor at the same graph-input position.
// - Otherwise a tensor is created and either bound to the device address
//   (PyNative/GPU, internal outputs) or synced device->host.
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
                        const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  // if node is a value node, no need sync addr from device to host
  if (!AnfAlgo::OutputAddrExist(node, output_index)) {
    if (node->isa<ValueNode>()) {
      auto value_node = node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      return value_node->value();
    }
    if (node->isa<Parameter>()) {
      // A parameter without a device address must be one of the graph inputs:
      // hand back the host tensor the caller fed in at the same position.
      for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
        if (input_idx >= input_tensors.size()) {
          MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size();
        }
        if (graph.inputs()[input_idx] == node) {
          return input_tensors[input_idx];
        }
      }
      MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << "has no output addr";
    }
  }
  // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
  MS_EXCEPTION_IF_NULL(address);
  auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  TypeId type_id = kNumberTypeFloat32;
  type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  std::vector<int> temp_shape;
  if (graph.IsInternalOutput(node)) {
    // Internal outputs stay on device: return a 1-element placeholder tensor
    // that only carries the device address (not dirty, so no sync happens).
    temp_shape.emplace_back(1);
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    tensor->set_device_address(address);
    tensor->set_dirty(false);
    return tensor;
  }
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  // if in paynative mode,data only copyed to host when user want to print data
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) {
    // PyNative / GPU: keep the device address so data is pulled lazily.
    tensor->set_device_address(address);
    tensor->set_dirty(false);
  } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
                                        LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c())) {
    // Sync failure is logged but not fatal; the tensor is returned as-is.
    MS_LOG(INFO) << "Output sync device to host error!!!";
    tensor->set_dirty(false);
  }
  return tensor;
}
// Builds the (possibly nested) output structure for front node `anf`:
// a MakeTuple fans out recursively into a VectorRef, a zero-output node
// yields an empty VectorRef, and anything else becomes a single tensor/value.
BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
                             const std::vector<tensor::TensorPtr> &input_tensors) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]";
  // Skip transparent wrappers (e.g. TupleGetItem) to reach the real producer.
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(INFO) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // special handle for maketuple
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    // input(0) is the primitive, so data inputs start at index 1.
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
      ret.push_back(out);
    }
    return ret;
  }
  // if is graph return nothing ,the function should return a null anylist
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
}
  131. BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  132. const std::vector<tensor::TensorPtr> &input_tensors) {
  133. MS_EXCEPTION_IF_NULL(anf);
  134. if (!AnfAlgo::IsRealKernel(anf)) {
  135. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] should be a executable kernel";
  136. }
  137. if (anf->isa<ValueNode>()) {
  138. return CreateOneTensor(anf, 0, graph, input_tensors);
  139. }
  140. VectorRef ret;
  141. if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
  142. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
  143. auto out = CreateOneTensor(anf, i, graph, input_tensors);
  144. ret.emplace_back(out);
  145. }
  146. }
  147. return ret;
  148. }
  149. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  150. MS_EXCEPTION_IF_NULL(anf);
  151. MS_EXCEPTION_IF_NULL(graph);
  152. auto value_node = anf->cast<ValueNodePtr>();
  153. MS_EXCEPTION_IF_NULL(value_node);
  154. auto value = value_node->value();
  155. MS_EXCEPTION_IF_NULL(value);
  156. if (value->isa<None>()) {
  157. return nullptr;
  158. }
  159. auto new_value_node = graph->NewValueNode(value_node);
  160. graph->FrontBackendlMapAdd(anf, new_value_node);
  161. graph->AddValueNodeToGraph(new_value_node);
  162. return new_value_node;
  163. }
  164. size_t LoadCtrlInputTensor(const std::shared_ptr<KernelGraph> &graph, std::vector<tensor::TensorPtr> *inputs) {
  165. MS_EXCEPTION_IF_NULL(graph);
  166. MS_LOG(INFO) << "Load kInputCtrlTensors";
  167. auto inputs_params = graph->input_ctrl_tensors();
  168. if (inputs_params == nullptr) {
  169. return 0;
  170. }
  171. if (inputs_params->empty()) {
  172. MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  173. }
  174. auto tensor = (*inputs_params)[0];
  175. MS_EXCEPTION_IF_NULL(tensor);
  176. auto *val = static_cast<int32_t *>(tensor->data_c());
  177. MS_EXCEPTION_IF_NULL(val);
  178. *val = 0;
  179. tensor->set_dirty(true);
  180. // set loop_count to zero
  181. MS_EXCEPTION_IF_NULL(inputs);
  182. inputs->push_back(tensor);
  183. auto epoch_tensor = (*inputs_params)[1];
  184. MS_EXCEPTION_IF_NULL(epoch_tensor);
  185. auto *epoch_val = static_cast<int32_t *>(epoch_tensor->data_c());
  186. MS_EXCEPTION_IF_NULL(epoch_val);
  187. *epoch_val = graph->current_epoch();
  188. epoch_tensor->set_dirty(true);
  189. inputs->push_back(epoch_tensor);
  190. MS_LOG(INFO) << "Load epoch_val:" << *epoch_val;
  191. graph->set_current_epoch(graph->current_epoch() + 1);
  192. return inputs_params->size();
  193. }
  194. ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
  195. MS_EXCEPTION_IF_NULL(graph);
  196. MS_EXCEPTION_IF_NULL(input_tensor);
  197. auto value_node = std::make_shared<ValueNode>(input_tensor);
  198. MS_EXCEPTION_IF_NULL(value_node);
  199. // construct abstract of value node
  200. auto type_of_tensor = input_tensor->Dtype();
  201. auto shape_of_tensor = input_tensor->shape();
  202. auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  203. value_node->set_abstract(abstract);
  204. // add value node to graph
  205. auto input_value_node = graph->NewValueNode(value_node);
  206. graph->AddValueNodeToGraph(input_value_node);
  207. return input_value_node;
  208. }
// Creates a graph parameter standing in for `input_tensor` when running a
// single op. `tensor_mask` marks weight inputs, which receive an empty
// default ParamValue so they are recognized as weights.
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
                                     int tensor_mask) {
  MS_EXCEPTION_IF_NULL(graph);
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (tensor_mask == kParameterWeightTensorMask) {
    auto param_value_new = std::make_shared<ParamValue>();
    param->set_default_param(param_value_new);
  }
  // set the kernel info of parameter
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  if (input_tensor->device_address().get() == nullptr) {
    // Host-only tensor: default format. Weight parameters get kTypeUnknown —
    // presumably so the device type is decided later during kernel selection;
    // TODO(review): confirm against the selection pass.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    // Tensor already has device data: inherit its device format and type.
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // construct abstract of parameter
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  param->set_abstract(abstract);
  return param;
}
  237. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  238. MS_LOG(INFO) << "Graph outputs:";
  239. const size_t max_deep = 10;
  240. if (recurse_level > max_deep) {
  241. MS_LOG(INFO) << "Recurse too deep";
  242. return;
  243. }
  244. std::string tab_str;
  245. for (size_t i = 0; i < recurse_level; i++) {
  246. tab_str = tab_str.append(" ");
  247. }
  248. if (any.is<AnyList>()) {
  249. (void)tab_str.append("{");
  250. MS_LOG(INFO) << tab_str;
  251. auto any_list = any.cast<AnyList>();
  252. for (auto &it : any_list) {
  253. DumpGraphOutput(it, recurse_level + 1);
  254. }
  255. (void)tab_str.append("}");
  256. MS_LOG(INFO) << tab_str;
  257. }
  258. (void)tab_str.append(any.ToString());
  259. MS_LOG(INFO) << tab_str;
  260. }
  261. bool ExistSummaryNode(const KernelGraph *graph) {
  262. MS_EXCEPTION_IF_NULL(graph);
  263. auto ret = graph->get_return();
  264. MS_EXCEPTION_IF_NULL(ret);
  265. auto all_nodes = DeepLinkedGraphSearch(ret);
  266. for (auto &n : all_nodes) {
  267. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  268. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  269. return true;
  270. }
  271. }
  272. return false;
  273. }
  274. } // namespace
  275. GraphId SessionBasic::graph_sum_ = 0;
  276. KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) {
  277. auto it = graphs_.find(graph_id);
  278. if (it == graphs_.end()) {
  279. MS_LOG(WARNING) << "Can't find graph " << graph_id;
  280. return nullptr;
  281. }
  282. return it->second;
  283. }
// If `out_node` is an internal output of a previously built graph, copies that
// output's kernel build info (format/device type) and device address onto
// `parameter`, so the next graph can consume the data in place.
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  // Resolve to the real kernel (and output index) that produces the value.
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  // Only reuse when the producer is still the graph's final internal output.
  if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
      node_graph->IsFinalOutputKernel(ref_real_node)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    if (address == nullptr) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    // Mirror the producing kernel's output format/device type on the parameter.
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    parameter->set_kernel_info(std::make_shared<device::KernelInfo>());
    auto d_kernel_info = parameter->kernel_info();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
  }
}
// Creates one backend parameter per output of `node` (a node produced outside
// `graph`). Tuple-output real kernels expand into one parameter per element.
// Every new parameter is appended to the graph's inputs and valid-input flags.
std::vector<AnfNodePtr> SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input,
                                                               KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> parameters;
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  if (!AnfAlgo::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  }
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  // Helper: build a parameter with the given abstract and register it as a
  // graph input with the caller's valid_input flag.
  auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
    auto parameter = graph->NewParameter();
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(abstract);
    auto new_parameter = graph->NewParameter(parameter);
    parameters.push_back(new_parameter);
    valid_inputs->push_back(valid_input);
    graph_inputs->push_back(new_parameter);
  };
  for (const auto &out_node : pre_graph_out) {
    MS_EXCEPTION_IF_NULL(out_node);
    auto abstract = out_node->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    // create multiple parameters if is a tuple output real kernel
    if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
      auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
      MS_EXCEPTION_IF_NULL(tuple_abstract);
      MS_LOG(INFO) << "Tuple_size [" << tuple_abstract->size() << "]";
      for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
        create_parameter((*tuple_abstract)[output_idx]);
      }
      continue;
    }
    // create single parameter if is a abstract real kernel
    create_parameter(out_node->abstract());
    // Single outputs may be internal outputs of a previous graph; if so, wire
    // up the device info on the freshly created parameter.
    InitInternalOutputParameter(out_node, parameters[parameters.size() - 1]);
  }
  return parameters;
}
// Mirrors a front-end Parameter into `graph`. Parameters sharing the same
// python default value reuse one backend parameter via the file-level
// python_paras cache; others get a fresh backend parameter.
ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input,
                                                           KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_value = GetParamDefaultValue(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
  if (python_paras == nullptr) {
    python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
  }
  auto iter = python_paras->find(param_value);
  if (iter != python_paras->end()) {
    new_parameter = iter->second;
  } else {
    TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
    // Only cache parameters with an actual default value; a nullptr key would
    // alias unrelated parameters that all lack defaults.
    if (param_value != nullptr) {
      (*python_paras)[param_value] = new_parameter;
    }
    TraceManager::EndTrace();
  }
  // Even a reused parameter is (re)appended to this graph's input lists.
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(valid_input);
  return new_parameter;
}
  401. AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  402. MS_EXCEPTION_IF_NULL(anf);
  403. MS_EXCEPTION_IF_NULL(graph);
  404. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  405. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  406. if (parameters.empty()) {
  407. MS_LOG(EXCEPTION) << "No parameter exist!!";
  408. }
  409. if (parameters.size() == 1) {
  410. return parameters[0];
  411. }
  412. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  413. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  414. auto make_tuple = graph->NewCNode(make_tuple_input);
  415. MS_EXCEPTION_IF_NULL(make_tuple);
  416. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  417. return make_tuple;
  418. }
// Clones front-end cnode `cnode` into `graph`, translating each input in
// order: already-mapped nodes are reused; value nodes and parameters are
// recreated; inputs that are cnodes of another graph become new parameters,
// in which case *from_other_graph is set and the mapping is recorded in
// *other_graph_cnode so siblings can reuse it.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // get primitive of old node
  std::vector<AnfNodePtr> cnode_inputs;
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  if (prim != nullptr) {
    // push attr to inputs[0] of new cnode
    cnode_inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  } else {
    // No primitive: input(0) holds a func graph (e.g. a call); clone it.
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  }
  auto origin_inputs = cnode->inputs();
  bool optimize_depend = false;
  // A Depend whose real input is a value node can be simplified: the attached
  // node is substituted for the value input in the loop below.
  if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
      origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
    optimize_depend = true;
  }
  // if has multiple depends,only select first depend as parameter
  for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
    auto anf = origin_inputs[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // if input is a value node,
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        // nullptr means the value was None; such inputs are dropped.
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        // Parameter belongs to another graph: record the mapping locally only.
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
      cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
      continue;
    } else {
      *from_other_graph = true;
      // the input node is a cnode from other graph
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
    }
  }
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Normalizes one switch branch input: Partial/Call cnodes pass through
// unchanged, func-graph value nodes are wrapped in a Partial, and any other
// non-cnode value is wrapped in a fresh child kernel graph whose output is
// the backend counterpart of that value.
CNodePtr SessionBasic::CreateSwitchInput(const AnfNodePtr &node_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node_input);
  MS_EXCEPTION_IF_NULL(graph);
  // switch input generalizes partial
  if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial) ||
      AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimCall)) {
    return node_input->cast<CNodePtr>();
  }
  if (node_input->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "If switch input is " << node_input->DebugString() << ", it mast be partial or call.";
  }
  std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
    // A func graph branch is already callable; just wrap it in a Partial.
    partial_inputs.emplace_back(node_input);
    auto partial_node = graph->NewCNode(partial_inputs);
    return partial_node;
  }
  // Plain value: make it callable by putting it into a one-output sub graph.
  KernelGraphPtr kernel_graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  kernel_graph->set_output(graph->GetBackendAnfByFrontAnf(node_input));
  partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
  auto partial_node = graph->NewCNode(partial_inputs);
  return partial_node;
}
  512. CNodePtr SessionBasic::HandleSwitchInputs(const AnfNodePtr &anf_node, KernelGraph *graph) {
  513. MS_EXCEPTION_IF_NULL(anf_node);
  514. MS_EXCEPTION_IF_NULL(graph);
  515. auto node = anf_node->cast<CNodePtr>();
  516. MS_EXCEPTION_IF_NULL(node);
  517. if (node->inputs().size() < kSwitchInputSize) {
  518. MS_LOG(EXCEPTION) << "Switch input size less than " << kSwitchInputSize;
  519. }
  520. auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimSwitch->name()));
  521. std::vector<AnfNodePtr> switch_inputs = {primitive, node->input(1)};
  522. for (size_t index = 2; index < node->inputs().size(); index++) {
  523. auto input = CreateSwitchInput(node->input(index), graph);
  524. switch_inputs.emplace_back(input);
  525. }
  526. auto switch_node = graph->NewCNode(switch_inputs);
  527. return switch_node;
  528. }
// Builds the input list of a Call cnode whose input(0) is itself a cnode:
// a Partial has its arguments spliced into the call; a Switch is rebuilt via
// HandleSwitchInputs. Anything else is an error.
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  // create primitive of cnode:call(partial or switch)
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  if (cnode_input == nullptr) {
    MS_LOG(EXCEPTION) << "CNode input[0] is CNode:" << attr_input->DebugString()
                      << ", but input[0] has not been created.";
  }
  // if the node is partial, insert the inputs of partial to the call
  if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
    // Note: the front-end partial's inputs are walked, and each is translated
    // to its backend counterpart before being appended to the call.
    auto partial_node = attr_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(partial_node);
    auto partial_inputs = partial_node->inputs();
    std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
                   std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
                     MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
                     return graph->GetBackendAnfByFrontAnf(node);
                   });
    return cnode_inputs;
  } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
    auto switch_node = HandleSwitchInputs(cnode_input, graph);
    cnode_inputs.emplace_back(switch_node);
    return cnode_inputs;
  }
  MS_LOG(EXCEPTION) << "CNode input[0] must be partial or switch.";
}
// Clones front-end cnode `cnode` into `graph` (control-flow aware variant):
// graph kernels keep a cloned func graph; func-graph calls become Call nodes
// targeting a ValueNode<KernelGraph>; a cnode in the attr slot is treated as
// partial/switch; plain primitives are copied. Unlike the 5-arg overload,
// every data input must already have a backend counterpart.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (AnfAlgo::IsGraphKernel(cnode)) {
    // Graph kernel: input(0) is a func graph that is cloned wholesale.
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  } else if (IsValueNode<FuncGraph>(attr_input)) {
    // create primitive of cnode:call
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // create a ValueNode<KernelGraph> as input of cnode:call
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  } else if (attr_input->isa<CNode>()) {
    // input(0) is itself a cnode: must be partial or switch.
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
  } else {
    // get primitive of old node
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // push attr to inputs[0] of new cnode
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  // Data inputs: each one must already be mapped to a backend node.
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->input(input_idx);
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (IsValueNode<None>(anf)) {
      // None inputs carry no data and are dropped.
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
  609. ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  610. MS_EXCEPTION_IF_NULL(anf);
  611. MS_EXCEPTION_IF_NULL(graph);
  612. auto value_node = anf->cast<ValueNodePtr>();
  613. MS_EXCEPTION_IF_NULL(value_node);
  614. auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  615. MS_EXCEPTION_IF_NULL(sub_func_graph);
  616. if (front_backend_graph_map_.find(sub_func_graph) == front_backend_graph_map_.end()) {
  617. MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  618. }
  619. auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph];
  620. ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  621. new_value_node->set_abstract(value_node->abstract());
  622. // create new kernel_info of new value_node
  623. auto kernel_info = std::make_shared<device::KernelInfo>();
  624. kernel_info->SetFeatureMapFlag(false);
  625. new_value_node->set_kernel_info(kernel_info);
  626. // create kernel_build_info for new value node
  627. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  628. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  629. AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
  630. graph->FrontBackendlMapAdd(anf, new_value_node);
  631. return new_value_node;
  632. }
  633. ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  634. MS_EXCEPTION_IF_NULL(anf);
  635. MS_EXCEPTION_IF_NULL(graph);
  636. if (!anf->isa<Parameter>()) {
  637. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  638. }
  639. auto param_value = GetParamDefaultValue(anf);
  640. ParameterPtr new_parameter = nullptr;
  641. if (python_paras == nullptr) {
  642. python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
  643. }
  644. auto iter = python_paras->find(param_value);
  645. if (iter != python_paras->end()) {
  646. new_parameter = iter->second;
  647. } else {
  648. TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
  649. new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  650. if (param_value != nullptr) {
  651. (*python_paras)[param_value] = new_parameter;
  652. }
  653. TraceManager::EndTrace();
  654. }
  655. return new_parameter;
  656. }
// Build a backend KernelGraph from a flat, topologically ordered list of
// front-end CNodes (`lst`) plus the desired graph outputs. This is the
// whole-graph (non-control-flow) compile path.
// Returns the new graph after default exec-order assignment and common
// backend optimization.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  // Scratch map shared across CreateNewCNode calls for cnodes that belong to
  // another graph (see the from_other_graph handling below).
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  // Counts Depend nodes whose inputs came from another graph.
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    bool from_other_graph = false;
    // only first depend from other graph can create
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    // Carry over inferred abstract and scope from the front-end node.
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  // Attach a manager so value nodes / users can be queried during optimization.
  FuncGraphManagerPtr manager = MakeManager({graph});
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  opt::BackendCommonOptimization(graph);
  return graph;
}
  703. void SessionBasic::CreateCNodeKernelGraph(const AnfNodePtr node, KernelGraphPtr graph) {
  704. MS_EXCEPTION_IF_NULL(node);
  705. MS_EXCEPTION_IF_NULL(graph);
  706. auto cnode = node->cast<CNodePtr>();
  707. MS_EXCEPTION_IF_NULL(cnode);
  708. // create a new cnode object
  709. auto new_cnode = CreateNewCNode(cnode, graph.get());
  710. MS_EXCEPTION_IF_NULL(new_cnode);
  711. new_cnode->set_abstract(cnode->abstract());
  712. new_cnode->set_fullname_with_scope(cnode->fullname_with_scope());
  713. new_cnode->set_scope(cnode->scope());
  714. graph->FrontBackendlMapAdd(node, new_cnode);
  715. if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) {
  716. graph->set_return(new_cnode);
  717. }
  718. }
// Recursively build KernelGraphs for a front-end FuncGraph and every child
// FuncGraph it references. Each constructed graph is appended to
// `all_out_graph`; the graph for `func_graph` itself is returned.
std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
                                                                std::vector<KernelGraphPtr> *all_out_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(all_out_graph);
  auto node_list = TopoSort(func_graph->get_return());
  auto graph = NewKernelGraph();
  MS_EXCEPTION_IF_NULL(graph);
  // Register before recursing so a back-edge to this graph is detected below.
  front_backend_graph_map_[func_graph] = graph;
  MS_LOG(INFO) << "Create graph: " << graph->graph_id();
  bool is_trace_back = false;
  for (const auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (node->isa<Parameter>()) {
      // Front-end parameter: mirror it as a graph input.
      auto graph_inputs = graph->MutableInputs();
      MS_EXCEPTION_IF_NULL(graph_inputs);
      auto new_parameter = CreateNewParameter(node, graph.get());
      graph_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(node, new_parameter);
      continue;
    } else if (node->isa<ValueNode>()) {
      if (!IsValueNode<FuncGraph>(node)) {
        // if input is a common value node,
        (void)CreateNewValueNode(node, graph.get());
      } else {
        // if input is a ValueNode<FuncGraph>
        FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
        // Already-seen child graph means an unconditional jump back (a loop).
        if (front_backend_graph_map_.find(child_graph) != front_backend_graph_map_.end()) {
          is_trace_back = true;
        } else {
          (void)ConstructKernelGraph(child_graph, all_out_graph);
        }
        (void)CreateValueNodeKernelGraph(node, graph.get());
      }
      continue;
    } else {
      CreateCNodeKernelGraph(node, graph);
    }
  }
  // if a graph jump back unconditionally, return op of this graph will never be executed, so output is null.
  graph->set_output_null(is_trace_back);
  AddParameterToGraphInputs(func_graph->parameters(), graph.get());
  graph->SetExecOrderByDefault();
  if (ExistSummaryNode(graph.get())) {
    graph->set_summary_node_exist(true);
  }
  all_out_graph->push_back(graph);
  return graph;
}
  768. void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
  769. MS_EXCEPTION_IF_NULL(graph);
  770. auto graph_inputs = graph->MutableInputs();
  771. MS_EXCEPTION_IF_NULL(graph_inputs);
  772. graph_inputs->clear();
  773. for (auto &parameter : parameters) {
  774. MS_EXCEPTION_IF_NULL(parameter);
  775. auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
  776. if (backend_parameter == nullptr) {
  777. // for example "def f(x,y,z) {return x + y}", parameter z in unused
  778. auto new_parameter = CreateNewParameter(parameter, graph);
  779. graph_inputs->push_back(new_parameter);
  780. MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
  781. continue;
  782. }
  783. MS_LOG(INFO) << "Graph[" << graph->graph_id() << "],parameter:" << parameter->DebugString();
  784. graph_inputs->push_back(backend_parameter);
  785. }
  786. }
  787. // run graph steps
// Copy host-side input tensors into the device memory already bound to the
// graph's parameter nodes. Control tensors (loop-count etc.) may be prepended
// by LoadCtrlInputTensor before the size check.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  // Work on a local copy: LoadCtrlInputTensor may append control tensors.
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  // Default of 2 matches the "- 2" in the size check below when there are no
  // control tensors; presumably two implicit ctrl slots — TODO confirm.
  size_t input_ctrl_size = 2;
  MS_EXCEPTION_IF_NULL(kernel_graph);
  if (kernel_graph->input_ctrl_tensors()) {
    input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
  }
  auto input_nodes = kernel_graph->inputs();
  if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already own a device address can receive data.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        // PyNative infer: sync whenever the tensor is not already bound to
        // this exact device address.
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty()) {
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // Tensor lives at a different address: pull its data to host first,
          // then push it to the graph's address.
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
          // Rebind the tensor to the graph's device address so later runs can
          // skip the copy.
          tensor->set_device_address(device_address);
        }
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c())) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    // Host data is now on device (or was never needed): clear the dirty flag.
    tensor->set_dirty(false);
  }
}
  839. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  840. const std::vector<tensor::TensorPtr> &input_tensors) const {
  841. MS_EXCEPTION_IF_NULL(kernel_graph);
  842. MS_EXCEPTION_IF_NULL(outputs);
  843. if (!kernel_graph->child_graph_order().empty()) {
  844. // use the last child graph output as the root graph output
  845. UpdateOutputs(kernel_graph->child_graph_order().back(), outputs, input_tensors);
  846. return;
  847. }
  848. auto anf_outputs = kernel_graph->outputs();
  849. for (auto &item : anf_outputs) {
  850. MS_EXCEPTION_IF_NULL(item);
  851. MS_LOG(INFO) << "Update output[" << item->DebugString() << "]";
  852. if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
  853. outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
  854. continue;
  855. }
  856. outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  857. }
  858. }
// Register the callback that Summary() invokes with the fetched summary tensors.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
// Reorder the execution list in place via AnfAlgo::ReorderExecList.
void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); }
// Collect all summary nodes (Scalar/Tensor/Image/Histogram) of `graph` and map
// each node's full scope name to the real kernel (node, output-index) that
// produces its data. Result is stored back via set_summary_nodes.
void SessionBasic::GetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    return;
  }
  auto summary = graph->summary_nodes();
  auto apply_list = TopoSort(graph->get_return());
  for (auto &n : apply_list) {
    MS_EXCEPTION_IF_NULL(n);
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      auto cnode = n->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      if (cnode->inputs().size() <= kSummaryGetItem) {
        MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
      }
      // kSummaryGetItem indexes the data input of the summary node.
      auto node = cnode->input(kSummaryGetItem);
      MS_EXCEPTION_IF_NULL(node);
      // Trace through to the real kernel that produces the summarized value.
      auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
      MS_EXCEPTION_IF_NULL(item_with_index.first);
      if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
        MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
      }
      summary[n->fullname_with_scope()] = item_with_index;
    }
  }
  graph->set_summary_nodes(summary);
  MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
}
// Fetch all summary tensors from device memory and hand them to the registered
// summary callback. No-op when no callback is registered or the graph has no
// summary nodes. Sync failures are logged and the tensor is still reported
// (deliberate best-effort behavior).
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = graph->summary_node_exist();
  if (!exist_summary) {
    return;
  }
  GetSummaryNodes(graph);
  auto summary_outputs = graph->summary_nodes();
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    // Host tensor that will receive the device data.
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    // Skip outputs whose device memory was never allocated.
    if (!address->GetPtr()) {
      continue;
    }
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c())) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  930. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  931. MS_EXCEPTION_IF_NULL(graph);
  932. std::vector<AnfNodePtr> output_args;
  933. for (const auto &output : outputs) {
  934. MS_EXCEPTION_IF_NULL(output);
  935. MS_LOG(INFO) << "Output:" << output->DebugString();
  936. }
  937. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  938. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  939. if (backend_anf != nullptr) {
  940. auto context_ptr = MsContext::GetInstance();
  941. MS_EXCEPTION_IF_NULL(context_ptr);
  942. if (context_ptr->execution_mode() == kPynativeMode) {
  943. return backend_anf;
  944. }
  945. auto front_real_kernel = AnfAlgo::VisitKernel(out, 0);
  946. auto backend_real_kernel = AnfAlgo::VisitKernel(backend_anf, 0);
  947. MS_EXCEPTION_IF_NULL(out);
  948. auto out_func_graph = out->func_graph();
  949. MS_EXCEPTION_IF_NULL(out_func_graph);
  950. auto out_func_graph_manager = out_func_graph->manager();
  951. if (out_func_graph_manager == nullptr) {
  952. return backend_anf;
  953. }
  954. auto node_users = out_func_graph_manager->node_users();
  955. auto users = node_users[out];
  956. bool internal_output = true;
  957. std::string kernel_target = GetCNodeTarget(front_real_kernel.first);
  958. for (auto user : users) {
  959. if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) {
  960. internal_output = false;
  961. break;
  962. }
  963. }
  964. if (internal_output) {
  965. MS_LOG(INFO) << "Internal output1: " << out->DebugString() << "To " << backend_real_kernel.first->DebugString();
  966. graph->AddInternalOutput(out, backend_real_kernel.first);
  967. }
  968. return backend_anf;
  969. }
  970. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  971. };
  972. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  973. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  974. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  975. return graph->NewCNode(output_args);
  976. }
  977. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  978. MS_LOG(INFO) << "Start!";
  979. std::vector<AnfNodePtr> make_tuple_inputs;
  980. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  981. MS_EXCEPTION_IF_NULL(graph);
  982. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  983. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  984. auto idx = NewValueNode(SizeToInt(output_index));
  985. MS_EXCEPTION_IF_NULL(idx);
  986. auto imm = std::make_shared<Int32Imm>(output_index);
  987. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  988. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  989. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  990. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  991. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  992. make_tuple_inputs.push_back(getitem);
  993. }
  994. } else {
  995. make_tuple_inputs.push_back(cnode);
  996. }
  997. // create output
  998. auto g_output = graph->NewCNode(make_tuple_inputs);
  999. graph->set_output(g_output);
  1000. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  1001. MS_EXCEPTION_IF_NULL(context_);
  1002. FuncGraphManagerPtr manager = context_->manager();
  1003. if (manager != nullptr) {
  1004. manager->AddFuncGraph(graph);
  1005. graph->set_manager(manager);
  1006. }
  1007. MS_LOG(INFO) << "Finish!";
  1008. }
  1009. std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
  1010. const std::vector<tensor::TensorPtr> &input_tensors,
  1011. const std::vector<int> &tensors_mask) {
  1012. auto graph = std::make_shared<KernelGraph>();
  1013. std::vector<AnfNodePtr> inputs;
  1014. // set input[0]
  1015. PrimitivePtr op_prim = op_run_info.py_primitive;
  1016. MS_EXCEPTION_IF_NULL(op_prim);
  1017. inputs.push_back(std::make_shared<ValueNode>(op_prim));
  1018. // set input parameter
  1019. MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  1020. if (input_tensors.size() != tensors_mask.size()) {
  1021. MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
  1022. << tensors_mask.size();
  1023. }
  1024. for (size_t i = 0; i < input_tensors.size(); ++i) {
  1025. if (tensors_mask[i] == kValueNodeTensorMask) {
  1026. auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
  1027. inputs.push_back(value_node);
  1028. continue;
  1029. }
  1030. auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
  1031. inputs.push_back(parameter);
  1032. auto mutable_inputs = graph->MutableInputs();
  1033. MS_EXCEPTION_IF_NULL(mutable_inputs);
  1034. mutable_inputs->push_back(parameter);
  1035. }
  1036. // set execution order
  1037. auto cnode = graph->NewCNode(inputs);
  1038. MS_EXCEPTION_IF_NULL(cnode);
  1039. // set abstract,which include inferred shapes and types
  1040. cnode->set_abstract(op_run_info.abstract);
  1041. // set execution order
  1042. std::vector<CNodePtr> exe_order = {cnode};
  1043. graph->set_execution_order(exe_order);
  1044. // set output
  1045. CreateOutputNode(cnode, graph);
  1046. return graph;
  1047. }
  1048. BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  1049. if (utils::isa<VectorRef>(base_ref)) {
  1050. auto ref_list = utils::cast<VectorRef>(base_ref);
  1051. py::tuple output_tensors(ref_list.size());
  1052. for (size_t i = 0; i < ref_list.size(); ++i) {
  1053. auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef
  1054. if (utils::isa<tensor::TensorPtr>(output)) {
  1055. auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
  1056. MS_EXCEPTION_IF_NULL(tensor_ptr);
  1057. output_tensors[i] = tensor_ptr;
  1058. } else if (utils::isa<PyObjectRef>(output)) {
  1059. py::object obj = utils::cast<PyObjectRef>(output).object_;
  1060. py::tuple tensor_tuple = py::cast<py::tuple>(obj);
  1061. output_tensors[i] = tensor_tuple;
  1062. } else {
  1063. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  1064. }
  1065. }
  1066. return output_tensors; // turn tuple to py::object and store in PyObjectRef
  1067. } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
  1068. return base_ref;
  1069. } else {
  1070. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  1071. }
  1072. }
  1073. KernelGraphPtr SessionBasic::NewKernelGraph() {
  1074. auto graph = std::make_shared<KernelGraph>();
  1075. graph->set_graph_id(graph_sum_);
  1076. graphs_[graph_sum_++] = graph;
  1077. return graph;
  1078. }
  1079. } // namespace session
  1080. } // namespace mindspore