You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

session_basic.cc 30 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "session/session_basic.h"
  17. #include <utility>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <unordered_set>
  21. #include "pipeline/parse/data_converter.h"
  22. #include "ir/manager.h"
  23. #include "operator/ops.h"
  24. #include "common/trans.h"
  25. #include "utils/context/ms_context.h"
  26. #include "utils/config_manager.h"
  27. #include "session/anf_runtime_algorithm.h"
  28. #include "kernel/oplib/oplib.h"
  29. #include "pre_activate/common/common_backend_optimization.h"
  30. #include "pre_activate/pass/const_input_to_attr_registry.h"
  31. #include "pre_activate/common/helper.h"
  32. #include "common/utils.h"
  33. #include "ir/dtype.h"
  34. namespace mindspore {
  35. namespace session {
  36. namespace {
  37. const int kSummaryGetItem = 2;
  38. void GetSummaryNodes(const KernelGraph *graph, std::unordered_map<std::string, std::pair<AnfNodePtr, int>> *summary) {
  39. MS_LOG(DEBUG) << "Update summary Start";
  40. MS_EXCEPTION_IF_NULL(graph);
  41. MS_EXCEPTION_IF_NULL(summary);
  42. summary->clear();
  43. auto apply_list = TopoSort(graph->get_return());
  44. for (auto &n : apply_list) {
  45. MS_EXCEPTION_IF_NULL(n);
  46. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  47. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  48. int index = 0;
  49. auto cnode = n->cast<CNodePtr>();
  50. MS_EXCEPTION_IF_NULL(cnode);
  51. if (cnode->inputs().size() <= kSummaryGetItem) {
  52. MS_LOG(EXCEPTION) << "the node Summary should have 2 inputs at least!";
  53. }
  54. auto node = cnode->input(kSummaryGetItem);
  55. MS_EXCEPTION_IF_NULL(node);
  56. if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
  57. auto c = node->cast<CNodePtr>();
  58. MS_EXCEPTION_IF_NULL(c);
  59. if (c->inputs().size() != kTupleGetItemInputSize) {
  60. MS_LOG(EXCEPTION) << "the node tuple_get_item must have 2 inputs!";
  61. }
  62. MS_EXCEPTION_IF_NULL(c->input(kInputNodeOutputIndexInTupleGetItem));
  63. auto value_node = c->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>();
  64. auto value = value_node->value();
  65. MS_EXCEPTION_IF_NULL(value);
  66. Int32ImmPtr int_imm_ptr = value->cast<Int32ImmPtr>();
  67. MS_EXCEPTION_IF_NULL(int_imm_ptr);
  68. index = int_imm_ptr->value();
  69. node = c->input(kRealInputNodeIndexInTupleGetItem);
  70. }
  71. std::pair<AnfNodePtr, int> output_pair(node, index);
  72. // get full name with scope will add scalar or tensor or image summary tag.
  73. (*summary)[n->fullname_with_scope()] = output_pair;
  74. }
  75. }
  76. MS_LOG(DEBUG) << "Update summary end size: " << (*summary).size();
  77. }
  78. bool ExistSummaryNode(const KernelGraph *graph) {
  79. auto ret = graph->get_return();
  80. MS_EXCEPTION_IF_NULL(ret);
  81. auto all_nodes = DeepLinkedGraphSearch(ret);
  82. for (auto &n : all_nodes) {
  83. if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
  84. IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
  85. return true;
  86. }
  87. }
  88. return false;
  89. }
  90. BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
  91. const std::vector<tensor::TensorPtr> &input_tensors) {
  92. MS_EXCEPTION_IF_NULL(node);
  93. MS_LOG(INFO) << "create tensor for output[" << node->DebugString() << "] index[" << output_index << "]";
  94. // if node is a value node, no need sync addr from device to host
  95. if (!AnfAlgo::OutputAddrExist(node, output_index)) {
  96. if (node->isa<ValueNode>()) {
  97. auto value_node = node->cast<ValueNodePtr>();
  98. MS_EXCEPTION_IF_NULL(value_node);
  99. return value_node->value();
  100. }
  101. if (node->isa<Parameter>()) {
  102. for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) {
  103. if (input_idx > input_tensors.size()) {
  104. MS_LOG(EXCEPTION) << "input idx:" << input_idx << "out of range:" << input_tensors.size();
  105. }
  106. if (graph.inputs()[input_idx] == node) {
  107. return input_tensors[input_idx];
  108. }
  109. }
  110. MS_LOG(EXCEPTION) << "parameter : " << node->DebugString() << "has no output addr";
  111. }
  112. }
  113. // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode)
  114. auto address = AnfAlgo::GetOutputAddr(node, output_index);
  115. MS_EXCEPTION_IF_NULL(address);
  116. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  117. TypeId type_id = kNumberTypeFloat32;
  118. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  119. std::vector<int> temp_shape;
  120. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  121. tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  122. // if in paynative mode,data only copyed to host when user want to print data
  123. auto ms_context = MsContext::GetInstance();
  124. MS_EXCEPTION_IF_NULL(ms_context);
  125. if (ms_context->execution_mode() == kPynativeMode) {
  126. tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
  127. } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
  128. LongToSize(tensor->data().nbytes()), tensor->data_type(),
  129. tensor->data_c(true))) {
  130. MS_LOG(INFO) << "output sync device to host error!!!";
  131. tensor->set_dirty(false);
  132. }
  133. return tensor;
  134. }
  135. BaseRef CreatTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  136. const std::vector<tensor::TensorPtr> &input_tensors) {
  137. MS_EXCEPTION_IF_NULL(anf);
  138. MS_LOG(INFO) << "create tensor for output[" << anf->DebugString() << "]";
  139. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  140. MS_EXCEPTION_IF_NULL(item_with_index.first);
  141. // special handle for maketuple
  142. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  143. auto cnode = item_with_index.first->cast<CNodePtr>();
  144. MS_EXCEPTION_IF_NULL(cnode);
  145. VectorRef ret;
  146. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  147. auto out = CreatTensorForOutput(cnode->input(i), graph, input_tensors);
  148. ret.push_back(out);
  149. }
  150. return ret;
  151. }
  152. // if is graph return nothing ,the function should return a null anylist
  153. size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  154. if (size == 0) {
  155. return VectorRef();
  156. }
  157. return CreateOneTensor(item_with_index.first, item_with_index.second, graph, input_tensors);
  158. }
  159. BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
  160. const std::vector<tensor::TensorPtr> &input_tensors) {
  161. MS_EXCEPTION_IF_NULL(anf);
  162. if (!AnfAlgo::IsRealKernel(anf)) {
  163. MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] should be a executable kernel";
  164. }
  165. if (anf->isa<ValueNode>()) {
  166. return CreateOneTensor(anf, 0, graph, input_tensors);
  167. }
  168. VectorRef ret;
  169. if (anf->isa<CNode>() && AnfAlgo::GetCNodeName(anf) != prim::kPrimMakeTuple->name()) {
  170. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf); ++i) {
  171. auto out = CreateOneTensor(anf, i, graph, input_tensors);
  172. ret.emplace_back(out);
  173. }
  174. }
  175. return ret;
  176. }
  177. ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  178. auto value_node = anf->cast<ValueNodePtr>();
  179. MS_EXCEPTION_IF_NULL(value_node);
  180. auto value = value_node->value();
  181. MS_EXCEPTION_IF_NULL(value);
  182. if (value->isa<None>()) {
  183. return nullptr;
  184. }
  185. auto new_value_node = graph->NewValueNode(value_node);
  186. graph->FrontBackendlMapAdd(anf, new_value_node);
  187. graph->AddValueNodeToGraph(new_value_node);
  188. return new_value_node;
  189. }
  190. ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  191. MS_EXCEPTION_IF_NULL(anf);
  192. if (!anf->isa<Parameter>()) {
  193. MS_LOG(EXCEPTION) << "anf[" << anf->DebugString() << "] is not a parameter";
  194. }
  195. auto graph_inputs = graph->MutableInputs();
  196. MS_EXCEPTION_IF_NULL(graph_inputs);
  197. auto valid_inputs = graph->MutableValidInputs();
  198. MS_EXCEPTION_IF_NULL(valid_inputs);
  199. ParameterPtr new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  200. graph_inputs->push_back(new_parameter);
  201. valid_inputs->push_back(valid_input);
  202. return new_parameter;
  203. }
  204. std::vector<AnfNodePtr> CreateParameterFromTuple(const AnfNodePtr &node, bool valid_input, KernelGraph *graph) {
  205. MS_EXCEPTION_IF_NULL(node);
  206. MS_EXCEPTION_IF_NULL(graph);
  207. std::vector<AnfNodePtr> parameters;
  208. std::vector<AnfNodePtr> pre_graph_out = {node};
  209. // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
  210. if (!AnfAlgo::IsRealKernel(node)) {
  211. pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem});
  212. }
  213. auto valid_inputs = graph->MutableValidInputs();
  214. MS_EXCEPTION_IF_NULL(valid_inputs);
  215. auto graph_inputs = graph->MutableInputs();
  216. MS_EXCEPTION_IF_NULL(graph_inputs);
  217. auto create_parameter = [&](const AbstractBasePtr &abstract) -> void {
  218. auto parameter = graph->NewParameter();
  219. MS_EXCEPTION_IF_NULL(parameter);
  220. parameter->set_abstract(abstract);
  221. auto new_parameter = graph->NewParameter(parameter);
  222. parameters.push_back(new_parameter);
  223. valid_inputs->push_back(valid_input);
  224. graph_inputs->push_back(new_parameter);
  225. };
  226. for (const auto &out_node : pre_graph_out) {
  227. MS_EXCEPTION_IF_NULL(out_node);
  228. auto abstract = out_node->abstract();
  229. MS_EXCEPTION_IF_NULL(abstract);
  230. // create multiple parameters if is a tuple output real kernel
  231. if (abstract->isa<abstract::AbstractTuple>() && !AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
  232. auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
  233. MS_EXCEPTION_IF_NULL(tuple_abstract);
  234. MS_LOG(INFO) << "tuple_size [" << tuple_abstract->size() << "]";
  235. for (size_t output_idx = 0; output_idx < tuple_abstract->size(); output_idx++) {
  236. create_parameter((*tuple_abstract)[output_idx]);
  237. }
  238. continue;
  239. }
  240. // create single parameter if is a abstract real kernel
  241. create_parameter(out_node->abstract());
  242. }
  243. return parameters;
  244. }
  245. AnfNodePtr CreateNewParameterFromCNode(const AnfNodePtr &anf, bool valid_input, KernelGraph *graph) {
  246. MS_EXCEPTION_IF_NULL(anf);
  247. if (!anf->isa<CNode>()) {
  248. MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a cnode";
  249. }
  250. MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  251. auto parameters = CreateParameterFromTuple(anf, valid_input, graph);
  252. if (parameters.empty()) {
  253. MS_LOG(EXCEPTION) << "No parameter exist!!";
  254. }
  255. if (parameters.size() == 1) {
  256. return parameters[0];
  257. }
  258. std::vector<AnfNodePtr> make_tuple_input = {NewValueNode(prim::kPrimMakeTuple)};
  259. (void)std::copy(parameters.begin(), parameters.end(), std::back_inserter(make_tuple_input));
  260. auto make_tuple = graph->NewCNode(make_tuple_input);
  261. MS_EXCEPTION_IF_NULL(make_tuple);
  262. MS_LOG(INFO) << "New make tuple [" << make_tuple->DebugString() << "] of parameters";
  263. return make_tuple;
  264. }
  265. bool NeedInsertSwitch() {
  266. auto context_ptr = MsContext::GetInstance();
  267. MS_EXCEPTION_IF_NULL(context_ptr);
  268. return (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
  269. ConfigManager::GetInstance().iter_num() > 1);
  270. }
  271. size_t LoadCtrlInputTensor(const std::shared_ptr<Context> &context, std::vector<tensor::TensorPtr> *inputs) {
  272. MS_EXCEPTION_IF_NULL(context);
  273. if (!NeedInsertSwitch()) {
  274. (void)context->results_.erase(kInputCtrlTensors);
  275. return 0;
  276. }
  277. MS_LOG(INFO) << "Load kInputCtrlTensors";
  278. auto inputs_params =
  279. context->GetResult(kInputCtrlTensors).cast<const std::shared_ptr<std::vector<tensor::TensorPtr>>>();
  280. MS_EXCEPTION_IF_NULL(inputs_params);
  281. if (inputs_params->empty()) {
  282. MS_LOG(EXCEPTION) << "Illegal empty inputs_params";
  283. }
  284. auto tensor = (*inputs_params)[0];
  285. MS_EXCEPTION_IF_NULL(tensor);
  286. auto *val = static_cast<int32_t *>(tensor->data_c(true));
  287. MS_EXCEPTION_IF_NULL(val);
  288. *val = 0;
  289. tensor->set_dirty(true);
  290. // set loop_count to zero
  291. MS_EXCEPTION_IF_NULL(inputs);
  292. inputs->push_back(tensor);
  293. return inputs_params->size();
  294. }
  295. ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
  296. bool is_weight) {
  297. auto param = graph->NewParameter();
  298. MS_EXCEPTION_IF_NULL(param);
  299. if (is_weight) {
  300. py::object obj;
  301. param->set_default_param(obj);
  302. }
  303. // set the kernel info of parameter
  304. auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  305. MS_EXCEPTION_IF_NULL(input_tensor);
  306. if (input_tensor->device_address().get() == nullptr) {
  307. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
  308. TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
  309. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  310. } else {
  311. kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{input_tensor->device_address()->format()});
  312. kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
  313. }
  314. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  315. // ftruct abstract of parameter
  316. auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor);
  317. param->set_abstract(abstract);
  318. return param;
  319. }
  320. void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  321. MS_LOG(INFO) << "graph outputs:";
  322. const size_t max_deep = 10;
  323. if (recurse_level > max_deep) {
  324. MS_LOG(INFO) << "recurse too deep";
  325. return;
  326. }
  327. std::string tab_str;
  328. for (size_t i = 0; i < recurse_level; i++) {
  329. tab_str = tab_str.append(" ");
  330. }
  331. if (any.is<AnyList>()) {
  332. (void)tab_str.append("{");
  333. MS_LOG(INFO) << tab_str;
  334. auto any_list = any.cast<AnyList>();
  335. for (auto &it : any_list) {
  336. DumpGraphOutput(it, recurse_level + 1);
  337. }
  338. (void)tab_str.append("}");
  339. MS_LOG(INFO) << tab_str;
  340. }
  341. (void)tab_str.append(any.ToString());
  342. MS_LOG(INFO) << tab_str;
  343. }
  344. } // namespace
// Monotonically increasing counter used to assign a unique id to each kernel
// graph constructed by any session.
GraphId SessionBasic::graph_sum_ = 0;
// Clones a front-end cnode into the backend graph. Each input is resolved in
// order of preference: an already-translated backend node, a cached node from
// another graph, a fresh ValueNode, a fresh Parameter, or — for a cnode owned
// by another graph — a substitute parameter (which also sets *from_other_graph).
// `other_graph_cnode` caches cross-graph substitutions for reuse.
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, KernelGraph *graph,
                                      bool *from_other_graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(from_other_graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  *from_other_graph = false;
  // get primitive of old node
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(prim);
  // push attr to inputs[0] of new cnode; the primitive is deep-copied so later
  // attribute edits do not leak back into the front-end node.
  std::vector<AnfNodePtr> cnode_inputs = {std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim))};
  // if has multiple depends,only select first depend as parameter
  for (size_t input_idx = 1; input_idx < cnode->inputs().size(); input_idx++) {
    auto anf = cnode->inputs()[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      // reuse the substitute parameter created for this cross-graph node earlier
      cnode_inputs.push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // if input is a value node, mirror it into this graph
      // (None values yield nullptr and are silently skipped)
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      // if anf is a parameter, clone it; a parameter already owned by another
      // graph is only cached, not added to this graph's front/backend map
      auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
      cnode_inputs.push_back(new_parameter);
      if (GetGraphIdByNode(anf) == kInvalidGraphId) {
        graph->FrontBackendlMapAdd(anf, new_parameter);
      } else {
        (*other_graph_cnode)[anf] = new_parameter;
      }
      continue;
    } else if (anf->isa<CNode>()) {
      *from_other_graph = true;
      // the input node is a cnode from other graph
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, valid_input, graph);
      cnode_inputs.push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
      continue;
    }
    MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
  }
  // Preserve the original node's debug info on the newly created cnode.
  TraceManager::DebugTrace(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNode(cnode_inputs);
  TraceManager::EndTrace();
  return new_cnode;
}
// Builds a backend KernelGraph from the topologically ordered front-end node
// list `lst`: every front cnode is translated via CreateNewCNode, the graph
// output is assembled from `outputs`, the graph is attached to the func-graph
// manager, and common backend optimization passes are run. The finished graph
// is registered under a fresh id from graph_sum_.
KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
  std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
  auto graph = std::make_shared<KernelGraph>();
  graph->set_graph_id(graph_sum_);
  MS_LOG(INFO) << "Create graph: " << graph_sum_;
  size_t from_other_graph_depend_num = 0;
  for (const auto &node : lst) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
    if (!node->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // create a new cnode object
    bool from_other_graph = false;
    // only first depend from other graph can create; later cross-graph Depend
    // nodes get valid_input == false so they do not register real inputs
    bool valid_input = true;
    if (from_other_graph_depend_num != 0 && AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) {
      valid_input = false;
    }
    auto new_cnode = CreateNewCNode(cnode, valid_input, graph.get(), &from_other_graph, &other_graph_cnode);
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) && from_other_graph) {
      from_other_graph_depend_num++;
    }
    MS_EXCEPTION_IF_NULL(new_cnode);
    new_cnode->set_abstract(cnode->abstract());
    new_cnode->set_scope(cnode->scope());
    // record map relations between anf from ME and new anf node used in backend
    graph->FrontBackendlMapAdd(node, new_cnode);
  }
  // add a make_tuple at the end of graph as output
  graph->set_output(ConstructOutput(outputs, graph));
  MS_EXCEPTION_IF_NULL(context_);
  FuncGraphManagerPtr manager = context_->manager();
  if (manager) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  graph->SetExecOrderByDefault();
  opt::BackendCommonOptimization(graph);
  graphs_[graph_sum_++] = graph;
  return graph;
}
  446. // run graph steps
// Copies host input tensors to the device memory of the graph's parameters.
// Control tensors (loop counters) cached in context_ are appended to the input
// list first. A parameter's device memory is only re-synced when needed: dirty
// data, non-weight parameters, or a device-address mismatch.
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                                 const std::vector<tensor::TensorPtr> &inputs_const) const {
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  // NOTE(review): defaults to 1 so that (inputs.size() + 1) - 1 == inputs.size()
  // when no ctrl tensors exist; when LoadCtrlInputTensor runs it both appends
  // the ctrl tensor to `inputs` and returns the cached count — confirm the
  // size check below is intended for counts other than 0/1.
  size_t input_ctrl_size = 1;
  MS_EXCEPTION_IF_NULL(context_);
  if (context_->HasResult(kInputCtrlTensors)) {
    input_ctrl_size = LoadCtrlInputTensor(context_, &inputs);
  }
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto input_nodes = kernel_graph->inputs();
  if ((inputs.size() + input_ctrl_size) - 1 != input_nodes.size()) {
    MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
                      << ", input_ctrl_size:" << input_ctrl_size;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    // Only parameters that already own device memory are loaded here.
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto pk_node = input_node->cast<ParameterPtr>();
      auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
      bool need_sync = false;
      if (ms_context->enable_pynative_infer()) {
        // Pynative inference: sync whenever the tensor has no device address or
        // points at different device memory than the parameter.
        if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
          need_sync = true;
        }
      } else {
        if (tensor->is_dirty() || !AnfAlgo::IsParameterWeight(pk_node)) {
          need_sync = true;
        } else if (tensor->device_address() != device_address) {
          // Weight bound to another device address: pull its data to host first
          // so the copy below pushes current values.
          (void)tensor->data_sync();
          need_sync = true;
        }
      }
      if (need_sync) {
        tensor->set_device_address(device_address);
        MS_EXCEPTION_IF_NULL(device_address);
        if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
                                              LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                              tensor->data_c(false))) {
          MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
        }
      }
    }
    // Host copy is now authoritative on device.
    tensor->set_dirty(false);
  }
}
  497. void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
  498. const std::vector<tensor::TensorPtr> &input_tensors) const {
  499. MS_EXCEPTION_IF_NULL(kernel_graph);
  500. MS_EXCEPTION_IF_NULL(outputs);
  501. auto anf_outputs = kernel_graph->outputs();
  502. for (auto &item : anf_outputs) {
  503. MS_LOG(INFO) << "update output[" << item->DebugString() << "]";
  504. MS_EXCEPTION_IF_NULL(item);
  505. if (AnfAlgo::IsTupleOutput(item) && AnfAlgo::IsRealKernel(item)) {
  506. outputs->emplace_back(CreatTupleForOutput(item, *kernel_graph, input_tensors));
  507. continue;
  508. }
  509. outputs->emplace_back(CreatTensorForOutput(item, *kernel_graph, input_tensors));
  510. }
  511. }
// Registers the callback invoked by Summary() with the collected summary
// tensors; the callback must be non-null.
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  summary_callback_ = callback;
}
  516. void SessionBasic::Reorder(std::vector<CNodePtr> *node_list) {
  517. MS_EXCEPTION_IF_NULL(node_list);
  518. std::vector<CNodePtr> all_opt_list;
  519. std::vector<CNodePtr> non_opt_list;
  520. for (const auto &node : *node_list) {
  521. MS_EXCEPTION_IF_NULL(node);
  522. if (kOptOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kOptOperatorSet.end()) {
  523. all_opt_list.emplace_back(node);
  524. } else {
  525. non_opt_list.emplace_back(node);
  526. }
  527. }
  528. node_list->clear();
  529. (void)std::copy(non_opt_list.begin(), non_opt_list.end(), std::back_inserter(*node_list));
  530. (void)std::copy(all_opt_list.begin(), all_opt_list.end(), std::back_inserter(*node_list));
  531. }
// Gathers every summary node's output from device into host tensors and hands
// the resulting name->tensor map to the registered summary callback. A no-op
// when no callback is registered or the graph has no summary nodes.
void SessionBasic::Summary(KernelGraph *graph) {
  if (summary_callback_ == nullptr) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  bool exist_summary = ExistSummaryNode(graph);
  if (!exist_summary) {
    return;
  }
  std::unordered_map<std::string, std::pair<AnfNodePtr, int>> summary_outputs;
  GetSummaryNodes(graph, &summary_outputs);
  std::map<std::string, tensor::TensorPtr> params_list;
  // fetch outputs apply kernel in session & run callback functions
  for (auto &output_item : summary_outputs) {
    auto node = output_item.second.first;
    size_t index = IntToSize(output_item.second.second);
    auto address = AnfAlgo::GetOutputAddr(node, index);
    auto shape = AnfAlgo::GetOutputInferShape(node, index);
    TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
    std::vector<int> temp_shape;
    (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
    tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
    MS_EXCEPTION_IF_NULL(address);
    // A failed sync is logged but not fatal: the callback still receives the
    // (possibly stale) host tensor.
    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
                                   tensor->data_type(), tensor->data_c(true))) {
      MS_LOG(ERROR) << "Failed to sync output from device to host.";
    }
    tensor->set_dirty(false);
    params_list[output_item.first] = tensor;
  }
  // call callback function here
  summary_callback_(0, params_list);
}
  565. CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
  566. MS_EXCEPTION_IF_NULL(graph);
  567. std::vector<AnfNodePtr> output_args;
  568. auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr {
  569. auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
  570. if (backend_anf != nullptr) {
  571. return backend_anf;
  572. }
  573. for (const auto &output : outputs) {
  574. MS_LOG(INFO) << "output:" << output->DebugString();
  575. }
  576. MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
  577. };
  578. output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
  579. (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
  580. [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
  581. return graph->NewCNode(output_args);
  582. }
  583. void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
  584. MS_LOG(INFO) << "Start!";
  585. std::vector<AnfNodePtr> make_tuple_inputs;
  586. make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
  587. if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
  588. for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
  589. auto idx = NewValueNode(SizeToInt(output_index));
  590. MS_EXCEPTION_IF_NULL(idx);
  591. auto imm = std::make_shared<Int32Imm>(output_index);
  592. idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
  593. MS_EXCEPTION_IF_NULL(graph);
  594. auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
  595. std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
  596. std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
  597. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
  598. make_tuple_inputs.push_back(getitem);
  599. }
  600. } else {
  601. make_tuple_inputs.push_back(cnode);
  602. }
  603. // create output
  604. auto g_output = graph->NewCNode(make_tuple_inputs);
  605. graph->set_output(g_output);
  606. // set graph manager,which now is only used to get valuenodes and hardware optimizing
  607. MS_EXCEPTION_IF_NULL(context_);
  608. FuncGraphManagerPtr manager = context_->manager();
  609. if (manager != nullptr) {
  610. manager->AddFuncGraph(graph);
  611. graph->set_manager(manager);
  612. }
  613. MS_LOG(INFO) << "Finish!";
  614. }
// Builds a one-node KernelGraph for pynative single-op execution: the op's
// primitive plus one parameter per input tensor (tensors_mask marks weights),
// with the lone cnode as the whole execution order and graph output.
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                                  const std::vector<tensor::TensorPtr> &input_tensors,
                                                                  const std::vector<bool> &tensors_mask) {
  auto graph = std::make_shared<KernelGraph>();
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  PrimitivePtr op_prim = op_run_info.py_primitive;
  MS_EXCEPTION_IF_NULL(op_prim);
  inputs.push_back(std::make_shared<ValueNode>(op_prim));
  // set input parameter
  MS_LOG(INFO) << "Input tensor size: " << input_tensors.size();
  if (input_tensors.size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto parameter = ConstructRunOpParameter(graph, input_tensors.at(i), tensors_mask[i]);
    inputs.push_back(parameter);
    graph->MutableInputs()->push_back(parameter);
  }
  // create the single cnode of the graph
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  // set abstract,which include inferred shapes and types
  cnode->set_abstract(op_run_info.abstract);
  // set execution order
  std::vector<CNodePtr> exe_order = {cnode};
  graph->set_execution_order(exe_order);
  // set output
  CreateOutputNode(cnode, graph);
  return graph;
}
  647. BaseRef SessionBasic::TransformBaseRefListToTuple(const BaseRef &base_ref) {
  648. if (utils::isa<VectorRef>(base_ref)) {
  649. auto ref_list = utils::cast<VectorRef>(base_ref);
  650. py::tuple output_tensors(ref_list.size());
  651. for (size_t i = 0; i < ref_list.size(); ++i) {
  652. auto output = TransformBaseRefListToTuple(ref_list[i]); // use pyObjectRef
  653. if (utils::isa<tensor::TensorPtr>(output)) {
  654. auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
  655. MS_EXCEPTION_IF_NULL(tensor_ptr);
  656. output_tensors[i] = tensor_ptr;
  657. } else if (utils::isa<PyObjectRef>(output)) {
  658. py::object obj = utils::cast<PyObjectRef>(output).object_;
  659. py::tuple tensor_tuple = py::cast<py::tuple>(obj);
  660. output_tensors[i] = tensor_tuple;
  661. } else {
  662. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  663. }
  664. }
  665. return output_tensors; // turn tuple to py::object and store in PyObjectRef
  666. } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
  667. return base_ref;
  668. } else {
  669. MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  670. }
  671. }
  672. } // namespace session
  673. } // namespace mindspore