You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_compiler.cc 27 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_compiler.h"
  17. #include <numeric>
  18. #include <map>
  19. #include <utility>
  20. #include <algorithm>
  21. #include "runtime/framework/graph_scheduler.h"
  22. #include "runtime/op_builder/op_lazy_builder.h"
  23. #include "runtime/device/device_address.h"
  24. #include "common/trans.h"
  25. #include "utils/convert_utils.h"
  26. #include "ir/tensor.h"
  27. #include "backend/optimizer/common/helper.h"
  28. #include "base/base_ref_utils.h"
  29. #include "debug/dump_proto.h"
  30. #ifdef ENABLE_DEBUGGER
  31. #include "debug/debugger/debugger.h"
  32. #endif
  33. #ifdef ENABLE_DUMP_IR
  34. #include "debug/anf_ir_dump.h"
  35. #include "debug/rdr/running_data_recorder.h"
  36. #endif
  37. #ifndef ENABLE_SECURITY
  38. #include "debug/data_dump/dump_json_parser.h"
  39. #endif
  40. namespace mindspore {
  41. namespace runtime {
  42. namespace {
  43. // Whether device address of anf node is valid and device address type
  44. // is consistent with device type, for example, device address type
  45. // DeviceAddressType::kGPU should be used on GPU device
  46. bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  47. MS_EXCEPTION_IF_NULL(kernel);
  48. MS_EXCEPTION_IF_NULL(device_context);
  49. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  50. const auto &address = AnfAlgo::GetOutputAddr(kernel, index, false);
  51. MS_EXCEPTION_IF_NULL(address);
  52. return address->DeviceType() == device_context->GetDeviceAddressType();
  53. }
  54. return false;
  55. }
// Create device addresses for all valid Parameter inputs of |graph| that do
// not already own an address of the matching device type.
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  // Child graph results are treated as additional inputs of this graph.
  (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  // Anf nodes which need create device address.
  std::vector<AnfNodePtr> nodes_list;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    AnfNodePtr item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    // Skip inputs explicitly marked invalid; inputs beyond the valid_inputs()
    // range (e.g. appended child graph results) are kept.
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    // A MakeTuple input is expanded into its real outputs and each Parameter
    // output without an existing address is collected.
    if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item);
      for (const auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
          continue;
        }
        nodes_list.push_back(out);
      }
    }
    // The item itself is collected only if it is a Parameter without an
    // address (a MakeTuple item falls through here and is filtered out).
    if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
      continue;
    }
    nodes_list.push_back(item);
  }
  // Create device address for anf node in nodes_list.
  for (const auto &item : nodes_list) {
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // Fall back to the inferred data type when no device type was selected.
      if (output_type_id == kTypeUnknown) {
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
                                                                AnfAlgo::GetOutputFormat(item, index), output_type_id);
      MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(item) << " addr:" << device_address;
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}
// Create device addresses for every tensor contained in |node_value| (a
// Tensor or ValueTuple) and bind them to the outputs of |value_node|,
// starting at output index |output_idx|.
void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
                                       size_t output_idx, const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    // Reuse the address the tensor already carries when its device type
    // matches this device context.
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
      bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
      bool is_graph_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
      if (is_graph_mode || is_pynative_infer) {
        AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                               value_node.get());
      }
      // NOTE(review): when neither graph mode nor pynative-infer is active,
      // output_idx is NOT advanced for this tensor — confirm this is intended.
      continue;
    }
    size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
    // Fall back to the inferred data type when no device type was selected.
    if (output_type_id == kTypeUnknown) {
      output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
    std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
    device::DeviceAddressPtr address =
      device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
    MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
    MS_EXCEPTION_IF_NULL(address);
    AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
  }
}
  138. void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  139. MS_EXCEPTION_IF_NULL(device_context);
  140. MS_EXCEPTION_IF_NULL(graph);
  141. for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
  142. MS_EXCEPTION_IF_NULL(value_node);
  143. if (NodeDeviceAddressExist(device_context, value_node, 0)) {
  144. continue;
  145. }
  146. const auto &node_value = value_node->value();
  147. MS_EXCEPTION_IF_NULL(node_value);
  148. if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
  149. CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
  150. } else if (node_value->isa<StringImm>()) {
  151. auto value = GetValue<std::string>(node_value);
  152. size_t tensor_size = value.size();
  153. auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  154. MS_EXCEPTION_IF_NULL(address);
  155. MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
  156. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  157. }
  158. }
  159. }
  160. void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  161. MS_EXCEPTION_IF_NULL(device_context);
  162. MS_EXCEPTION_IF_NULL(graph);
  163. const std::vector<CNodePtr> &kernels = graph->execution_order();
  164. for (const auto &kernel : kernels) {
  165. MS_EXCEPTION_IF_NULL(kernel);
  166. if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
  167. continue;
  168. }
  169. auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
  170. for (size_t i = 0; i < output_size; ++i) {
  171. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  172. continue;
  173. }
  174. auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
  175. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  176. auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
  177. auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type);
  178. MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
  179. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  180. }
  181. }
  182. }
  183. void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  184. MS_EXCEPTION_IF_NULL(device_context);
  185. MS_EXCEPTION_IF_NULL(graph);
  186. const std::vector<CNodePtr> &kernels = graph->execution_order();
  187. for (const auto &kernel : kernels) {
  188. MS_EXCEPTION_IF_NULL(kernel);
  189. if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
  190. continue;
  191. }
  192. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  193. MS_EXCEPTION_IF_NULL(kernel_mod);
  194. auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  195. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  196. auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
  197. MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
  198. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  199. }
  200. }
  201. }
// Make every kernel in the same inplace group share one device address, so
// all members of the group write into the same output buffer.
void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Collect the inplace groups.
  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
      continue;
    }
    auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
    MS_EXCEPTION_IF_NULL(primitive);
    // Kernels carrying the same "inplace_group" attribute value belong to
    // one group.
    auto inplace_group_attr = primitive->GetAttr("inplace_group");
    MS_EXCEPTION_IF_NULL(inplace_group_attr);
    auto group_id = GetValue<uint32_t>(inplace_group_attr);
    (void)inplace_groups[group_id].emplace_back(kernel);
  }
  // A group with fewer than two members has nothing to share.
  const size_t kMinInplaceGroupSize = 2;
  for (const auto &inplace_group : inplace_groups) {
    auto &group_nodes = inplace_group.second;
    if (group_nodes.size() < kMinInplaceGroupSize) {
      continue;
    }
    // Get the device address of the first node in the inplace group.
    auto node_primitive = AnfAlgo::GetCNodePrimitive(group_nodes[0]);
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
    MS_EXCEPTION_IF_NULL(device_address);
    // Update the device address of other nodes using device address of the first node in the inplace group.
    for (size_t i = 1; i < group_nodes.size(); ++i) {
      auto &group_node = group_nodes[i];
      auto prim = AnfAlgo::GetCNodePrimitive(group_node);
      MS_EXCEPTION_IF_NULL(prim);
      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
      AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
      // Update the reference count of device address: the shared address now
      // has one more user.
      device_address->IncreaseOriginalRefCount();
      device_address->ResetRefCount();
    }
  }
}
// For each output registered in the graph's ref-output map, make the kernel
// output share the device address of its corresponding origin output, and
// transfer the reference count from the replaced address to the origin one.
void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto &kernels = graph->execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto output_num = AnfAlgo::GetOutputTensorNum(kernel);
    if (output_num == 0) {
      MS_LOG(DEBUG) << "This kernel has no output size.";
      continue;
    }
    for (size_t i = 0; i < output_num; ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph->IsInRefOutputMap(out_pair)) {
        auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        // Rebind only when the two addresses actually differ.
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(DEBUG) << "REF address is not same, ref node output need address update";
          MS_LOG(DEBUG) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                        << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
          // Update the reference count of device address: the old address
          // loses a user, the origin address gains one.
          cur_node_output_addr->DecreaseOriginalRefCount();
          cur_node_output_addr->ResetRefCount();
          origin_node_output_addr->IncreaseOriginalRefCount();
          origin_node_output_addr->ResetRefCount();
        }
      }
    }
  }
}
  276. void SetSummaryNodesRefCount(const KernelGraph *graph) {
  277. MS_EXCEPTION_IF_NULL(graph);
  278. if (!graph->summary_node_exist()) {
  279. return;
  280. }
  281. const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
  282. if (summary_nodes.empty()) {
  283. return;
  284. }
  285. for (const auto &item : summary_nodes) {
  286. const AnfNodePtr &node = item.second.first;
  287. size_t index = IntToSize(item.second.second);
  288. auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
  289. MS_EXCEPTION_IF_NULL(device_address);
  290. device_address->set_original_ref_count(SIZE_MAX);
  291. device_address->ResetRefCount();
  292. }
  293. }
  294. void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
  295. for (const auto &item_with_index : output_with_index) {
  296. if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
  297. continue;
  298. }
  299. auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
  300. MS_EXCEPTION_IF_NULL(device_address);
  301. device_address->set_original_ref_count(SIZE_MAX);
  302. device_address->ResetRefCount();
  303. }
  304. }
  305. } // namespace
// Unregister all graphs belonging to this compiler info from the scheduler.
GraphCompilerInfo::~GraphCompilerInfo() { GraphScheduler::GetInstance().Clear(name_, graphs_); }
  307. GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs,
  308. const DeviceContext *device_context) {
  309. MS_EXCEPTION_IF_NULL(session_);
  310. // Generate kernel graph.
  311. KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs);
  312. MS_EXCEPTION_IF_NULL(graph);
  313. // Unify the MindIR, must be before of the graph optimization.
  314. device_context->UnifyMindIR(graph);
  315. // The graph common optimization.
  316. graph->UpdateGraphAquireGilAttr();
  317. opt::BackendCommonOptimization(graph);
  318. graph->SetInputNodes();
  319. auto manager = MakeManager({graph});
  320. if (manager) {
  321. manager->AddFuncGraph(graph);
  322. graph->set_manager(manager);
  323. }
  324. session_->SetInputNodeUsage(graph, manager);
  325. graph->SetOptimizerFlag();
  326. auto graph_id = CompileGraphImpl(graph, device_context);
  327. // Cache the backend graph output nodes to front nodes with output index.
  328. auto backend_node = graph->output();
  329. MS_EXCEPTION_IF_NULL(backend_node);
  330. graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
  331. graph->set_root_graph_id(graph_id);
  332. return graph_id;
  333. }
  334. GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const DeviceContext *device_context) {
  335. MS_EXCEPTION_IF_NULL(session_);
  336. MS_EXCEPTION_IF_NULL(func_graph);
  337. // Generate kernel graph.
  338. std::vector<KernelGraphPtr> all_graphs;
  339. KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, &all_graphs);
  340. MS_EXCEPTION_IF_NULL(root_graph);
  341. for (const auto &graph : all_graphs) {
  342. MS_EXCEPTION_IF_NULL(graph);
  343. graph->set_root_graph_id(root_graph->graph_id());
  344. }
  345. // Unify the MindIR, must be before of the graph optimization.
  346. device_context->UnifyMindIR(root_graph);
  347. // The graph common optimization.
  348. opt::BackendCommonOptimization(root_graph);
  349. auto graph_id = CompileGraphImpl(root_graph, device_context);
  350. // dump all graphs.
  351. device_context->DumpAllGraphs(all_graphs);
  352. // Cache the backend graph output nodes to front nodes with output index.
  353. auto output = func_graph->output();
  354. MS_EXCEPTION_IF_NULL(output);
  355. auto backend_node = root_graph->output();
  356. MS_EXCEPTION_IF_NULL(backend_node);
  357. root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
  358. return graph_id;
  359. }
  360. GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  361. MS_EXCEPTION_IF_NULL(graph);
  362. MS_EXCEPTION_IF_NULL(device_context);
  363. const auto &ms_context = MsContext::GetInstance();
  364. MS_EXCEPTION_IF_NULL(ms_context);
  365. #ifdef ENABLE_DUMP_IR
  366. bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  367. // Dump .pb graph before graph optimization.
  368. if (save_graphs) {
  369. DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
  370. }
  371. #endif
  372. // Set the graph sink flag.
  373. auto is_executing_sink = device_context->IsExecutingSink(graph);
  374. auto is_loop_count_sink = device_context->IsLoopCountSink(graph);
  375. graph->set_is_executing_sink(is_executing_sink);
  376. graph->set_is_loop_count_sink(is_loop_count_sink);
  377. // Execute optimization pass.
  378. device_context->OptimizeGraph(graph);
  379. // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
  380. // 'KernelMod' is real executive object of kernel.
  381. device_context->CreateKernel(graph->execution_order());
  382. // Adjust kernel graph before run graph.
  383. device_context->PreprocessBeforeRunGraph(graph);
  384. if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  385. // Create device address for all anf nodes of graph.
  386. CreateDeviceAddress(graph, device_context);
  387. }
  388. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  389. MS_EXCEPTION_IF_NULL(session_);
  390. session_->InitAllBucket(graph, device_context);
  391. #ifndef ENABLE_SECURITY
  392. session_->SetSummaryNodes(graph.get());
  393. #endif
  394. SetSummaryNodesRefCount(graph.get());
  395. #ifdef ENABLE_DUMP_IR
  396. // Dump .pb graph after graph optimization.
  397. if (save_graphs) {
  398. DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
  399. }
  400. #endif
  401. #ifdef ENABLE_DEBUGGER
  402. auto debugger = Debugger::GetInstance();
  403. debugger->DumpInGraphCompiler(graph);
  404. if (debugger && debugger->DebuggerBackendEnabled()) {
  405. debugger->LoadGraphs(graph);
  406. }
  407. #endif
  408. #ifdef ENABLE_DUMP_IR
  409. std::string name = "graph_build";
  410. DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
  411. (void)mindspore::RDR::RecordAnfGraph(SubModuleId::SM_SESSION, name, graph, dump_params, ".ir,.pb");
  412. auto &kernels = graph->execution_order();
  413. std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
  414. (void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
  415. #endif
  416. session_->DumpGraph(graph);
  417. return graph->graph_id();
  418. }
  419. GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool *single_op_cache_hit,
  420. const DeviceContext *device_context) {
  421. // Check if the graph cache exists.
  422. auto iter = run_op_graphs_.find(op_run_info.graph_info);
  423. auto &op_lazy_builder = runtime::OpLazyBuilder::GetInstance();
  424. if (iter != run_op_graphs_.end() && op_lazy_builder.QueueEmpty()) {
  425. const auto &graph = iter->second;
  426. MS_EXCEPTION_IF_NULL(graph);
  427. *single_op_cache_hit = true;
  428. return graph->graph_id();
  429. }
  430. *single_op_cache_hit = false;
  431. // Generate kernel graph.
  432. MS_EXCEPTION_IF_NULL(session_);
  433. KernelGraphPtr graph =
  434. session_->ConstructSingleOpGraph(op_run_info, op_run_info.input_tensors, op_run_info.tensor_mask);
  435. MS_EXCEPTION_IF_NULL(graph);
  436. MS_EXCEPTION_IF_NULL(device_context);
  437. device_context->OptimizeSingleOpGraph(graph);
  438. device_context->PreprocessBeforeRunSingleOpGraph(graph);
  439. // Create device address for all anf nodes of graph.
  440. CreateDeviceAddressWithoutWorkspace(graph, device_context);
  441. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  442. run_op_graphs_[op_run_info.graph_info] = graph;
  443. auto output_nodes = graph->outputs();
  444. auto &outputs_with_index = run_op_graph_output_nodes_[graph->graph_id()];
  445. for (auto &node : output_nodes) {
  446. MS_EXCEPTION_IF_NULL(node);
  447. (void)outputs_with_index.emplace_back(AnfAlgo::VisitKernelWithReturnType(node, 0, false));
  448. }
  449. UpdateRefCountForGraphOutput(outputs_with_index);
  450. return graph->graph_id();
  451. }
  452. void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graphs,
  453. const DeviceContext *device_context) const {
  454. MS_EXCEPTION_IF_NULL(device_context);
  455. std::vector<CNodePtr> node_to_build;
  456. for (const auto &graph : graphs) {
  457. const auto &nodes = graph->execution_order();
  458. std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
  459. }
  460. device_context->CreateKernel(node_to_build);
  461. for (const auto &graph : graphs) {
  462. CreateKernelWorkspaceDeviceAddress(device_context, graph);
  463. }
  464. }
// Fetch a compiled kernel graph from the session by its graph id.
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetGraph(graph_id);
}
  469. KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
  470. auto iter = run_op_graphs_.find(graph_info);
  471. if (iter == run_op_graphs_.end()) {
  472. MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
  473. return nullptr;
  474. }
  475. return iter->second;
  476. }
// Create all device addresses for a graph-mode kernel graph. The pass order
// matters: parameter/value/output/workspace addresses are created first, then
// the inplace and ref-node passes rebind the already-created addresses.
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  CreateParameterDeviceAddress(device_context, graph);
  CreateValueNodeDeviceAddress(device_context, graph);
  CreateKernelOutputDeviceAddress(device_context, graph);
  CreateKernelWorkspaceDeviceAddress(device_context, graph);
  UpdateDeviceAddressForInplaceNode(graph);
  UpdateDeviceAddressForRefNode(graph);
}
// Same as CreateDeviceAddress but skips workspace addresses and the ref-node
// pass; used for single-op graphs whose workspaces are created later in
// BuildSingleOpGraphs.
void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph,
                                                        const DeviceContext *device_context) const {
  CreateParameterDeviceAddress(device_context, graph);
  CreateValueNodeDeviceAddress(device_context, graph);
  CreateKernelOutputDeviceAddress(device_context, graph);
  UpdateDeviceAddressForInplaceNode(graph);
}
// Fill |parameter_index| with the mapping from graph parameters to input
// tensor positions, and create the output placeholders / |output_indexes|
// mapping for the given graph. Thin wrapper over the session.
void GraphCompiler::GetParamAndOutputIndex(
  const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
  std::map<AnfNodePtr, size_t> *parameter_index,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetParameterIndex(graph.get(), inputs, parameter_index);
  session_->CreateOutputPlaceholder(graph, inputs, outputs, output_indexes);
}
// Collect the input tensors of a single-op kernel from previous op outputs,
// graph inputs and the parameter index map. Thin wrapper over the session.
void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
                                            const std::map<KernelWithIndex, TensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<TensorPtr> &graph_inputs,
                                            InputTensorInfo *const input_tensor_info) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info);
}
// Fetch one input tensor of a single-op kernel by its input index. Thin
// wrapper over the session.
TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
                                                       const std::map<KernelWithIndex, TensorPtr> &op_output,
                                                       const std::map<AnfNodePtr, size_t> &parameter_index,
                                                       const std::vector<TensorPtr> &graph_inputs,
                                                       InputTensorInfo *const input_tensor_info, size_t input_index) {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info,
                                           input_index);
}
  517. void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
  518. OpRunInfo *run_info, GraphInfo *graph_info) {
  519. MS_EXCEPTION_IF_NULL(session_);
  520. MS_EXCEPTION_IF_NULL(graph_info);
  521. *graph_info = session_->GetSingleOpGraphInfo(kernel, tensor_info.input_tensors);
  522. *run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info);
  523. }
// Compute the use count of each (node, output index) in the graph; used to
// know when intermediate single-op outputs can be released.
void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetRefCount(graph.get(), ref_count);
}
// Decrease the ref counts of the consumed inputs and drop cached op outputs
// that are no longer referenced. Thin wrapper over the session.
void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                                   std::map<KernelWithIndex, size_t> *ref_count,
                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
}
// Store the outputs of one executed op into the op output map and, when they
// are graph outputs, into the graph output info. Thin wrapper over the
// session.
void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                       const std::map<KernelWithIndex, size_t> &ref_count,
                                       std::map<KernelWithIndex, TensorPtr> *op_output_map,
                                       GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
}
// Append gradient tensors to the communication bucket of the given graph.
void GraphCompiler::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->AddGradAddrToBucket(graph_id, grad_tensor);
}
// Clear all communication buckets associated with the given graph.
void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->ClearAllBucket(graph_id);
}
// Return the cached output (node, index) pairs of a single-op graph.
// Throws (via MS_LOG(EXCEPTION)) when the graph id was never compiled.
const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId graph_id) const {
  const auto &iter = run_op_graph_output_nodes_.find(graph_id);
  if (iter == run_op_graph_output_nodes_.end()) {
    MS_LOG(EXCEPTION) << "Can not find output nodes for graph id: " << graph_id;
  }
  return iter->second;
}
// Register the summary callback into the session.
// NOTE: when ENABLE_SECURITY is defined the callback is intentionally
// ignored (summary support is compiled out).
void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
  MS_EXCEPTION_IF_NULL(session_);
#ifndef ENABLE_SECURITY
  session_->RegisterSummaryCallBackFunc(callback);
#endif
}
// Run session summary for each graph (a no-op in security-hardened builds
// where ENABLE_SECURITY is defined).
void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
  MS_EXCEPTION_IF_NULL(session_);
  for (const auto &graph : graphs) {
#ifndef ENABLE_SECURITY
    session_->Summary(graph.get());
#endif
  }
}
// Drop the cached single-op graph and its recorded output nodes.
void GraphCompiler::EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id) {
  (void)run_op_graphs_.erase(graph_info);
  (void)run_op_graph_output_nodes_.erase(graph_id);
}
  574. } // namespace runtime
  575. } // namespace mindspore