You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_compiler.cc 33 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/graph_scheduler/graph_compiler.h"
  17. #include <numeric>
  18. #include <map>
  19. #include <utility>
  20. #include <algorithm>
  21. #include "runtime/graph_scheduler/graph_scheduler.h"
  22. #include "runtime/pynative/op_executor.h"
  23. #include "runtime/device/device_address.h"
  24. #include "runtime/device/ms_device_shape_transfer.h"
  25. #include "runtime/pynative/op_runtime_info.h"
  26. #include "include/common/utils/convert_utils.h"
  27. #include "include/common/utils/context/graph_kernel_flags.h"
  28. #include "utils/ms_context.h"
  29. #include "ir/tensor.h"
  30. #include "backend/common/optimizer/helper.h"
  31. #include "base/base_ref_utils.h"
  32. #include "include/common/debug/dump_proto.h"
  33. #ifdef ENABLE_DEBUGGER
  34. #include "debug/debugger/debugger.h"
  35. #endif
  36. #ifdef ENABLE_DUMP_IR
  37. #include "include/common/debug/anf_ir_dump.h"
  38. #endif
  39. #ifndef ENABLE_SECURITY
  40. #include "debug/data_dump/dump_json_parser.h"
  41. #endif
  42. namespace mindspore {
  43. namespace runtime {
  44. namespace {
  45. // Whether device address of anf node is valid and device address type
  46. // is consistent with device type, for example, device address type
  47. // DeviceAddressType::kGPU should be used on GPU device
  48. bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  49. MS_EXCEPTION_IF_NULL(kernel);
  50. MS_EXCEPTION_IF_NULL(device_context);
  51. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  52. const auto &address = AnfAlgo::GetOutputAddr(kernel, index, false);
  53. MS_EXCEPTION_IF_NULL(address);
  54. return address->DeviceType() == device_context->GetDeviceAddressType();
  55. }
  56. return false;
  57. }
  58. void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  59. MS_EXCEPTION_IF_NULL(device_context);
  60. MS_EXCEPTION_IF_NULL(graph);
  61. std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  62. const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  63. (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  64. // Anf nodes which need create device address.
  65. std::vector<AnfNodePtr> nodes_list;
  66. for (size_t i = 0; i < graph_inputs.size(); ++i) {
  67. AnfNodePtr item = graph_inputs[i];
  68. MS_EXCEPTION_IF_NULL(item);
  69. if (i < graph_valid_input.size() && !graph_valid_input[i]) {
  70. continue;
  71. }
  72. if (common::AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
  73. std::vector<AnfNodePtr> outs = common::AnfAlgo::GetAllOutput(item);
  74. for (const auto &out : outs) {
  75. MS_EXCEPTION_IF_NULL(out);
  76. if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
  77. continue;
  78. }
  79. nodes_list.push_back(out);
  80. }
  81. }
  82. if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
  83. continue;
  84. }
  85. nodes_list.push_back(item);
  86. }
  87. // Create device address for anf node in nodes_list
  88. for (const auto &item : nodes_list) {
  89. auto output_size = common::AnfAlgo::GetOutputTensorNum(item);
  90. for (size_t index = 0; index < output_size; index++) {
  91. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  92. if (output_type_id == kTypeUnknown) {
  93. output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
  94. }
  95. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
  96. auto device_address =
  97. device_context->CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
  98. trans::GetRuntimePaddingShape(item, index));
  99. device_address->set_from_persistent_mem(item->isa<Parameter>());
  100. MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
  101. << " addr:" << device_address;
  102. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  103. }
  104. }
  105. }
  106. void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
  107. size_t output_idx, const ValueNodePtr &value_node, const KernelGraphPtr &graph) {
  108. MS_EXCEPTION_IF_NULL(device_context);
  109. MS_EXCEPTION_IF_NULL(node_value);
  110. MS_EXCEPTION_IF_NULL(value_node);
  111. const auto &ms_context = MsContext::GetInstance();
  112. MS_EXCEPTION_IF_NULL(ms_context);
  113. std::vector<TensorPtr> tensors;
  114. TensorValueToTensor(node_value, &tensors);
  115. for (const auto &tensor : tensors) {
  116. if (tensor == nullptr) {
  117. MS_LOG(WARNING) << "Tensor is null";
  118. return;
  119. }
  120. auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  121. if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
  122. // The input of PyNative bprop graph is ValueNode.
  123. // Setting the address to the ValueNode will lead to memory leak.
  124. if (!graph->is_bprop()) {
  125. AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
  126. value_node.get());
  127. }
  128. continue;
  129. }
  130. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  131. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  132. if (output_type_id == kTypeUnknown) {
  133. output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  134. }
  135. std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  136. device::DeviceAddressPtr address = device_context->CreateDeviceAddress(
  137. nullptr, tensor_size, output_format, output_type_id, trans::GetRuntimePaddingShape(value_node, output_idx));
  138. MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
  139. MS_EXCEPTION_IF_NULL(address);
  140. address->set_from_persistent_mem(true);
  141. AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
  142. }
  143. }
  144. void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  145. MS_EXCEPTION_IF_NULL(device_context);
  146. MS_EXCEPTION_IF_NULL(graph);
  147. for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
  148. MS_EXCEPTION_IF_NULL(value_node);
  149. if (NodeDeviceAddressExist(device_context, value_node, 0)) {
  150. continue;
  151. }
  152. const auto &node_value = value_node->value();
  153. MS_EXCEPTION_IF_NULL(node_value);
  154. if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
  155. CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node, graph);
  156. } else if (node_value->isa<StringImm>()) {
  157. auto value = GetValue<std::string>(node_value);
  158. size_t tensor_size = value.size();
  159. auto address =
  160. device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8, ShapeVector());
  161. MS_EXCEPTION_IF_NULL(address);
  162. address->set_from_persistent_mem(true);
  163. MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)
  164. << " addr:" << address;
  165. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  166. }
  167. }
  168. }
  169. void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
  170. bool is_gradient_out) {
  171. MS_EXCEPTION_IF_NULL(device_context);
  172. MS_EXCEPTION_IF_NULL(graph);
  173. const std::vector<CNodePtr> &kernels = graph->execution_order();
  174. for (const auto &kernel : kernels) {
  175. MS_EXCEPTION_IF_NULL(kernel);
  176. if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
  177. continue;
  178. }
  179. auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
  180. for (size_t i = 0; i < output_size; ++i) {
  181. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  182. continue;
  183. }
  184. auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
  185. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  186. auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
  187. auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type,
  188. trans::GetRuntimePaddingShape(kernel, i));
  189. if (is_gradient_out) {
  190. device_address->set_from_persistent_mem(true);
  191. }
  192. MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
  193. << " addr:" << device_address;
  194. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  195. }
  196. }
  197. }
  198. void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  199. MS_EXCEPTION_IF_NULL(device_context);
  200. MS_EXCEPTION_IF_NULL(graph);
  201. const std::vector<CNodePtr> &kernels = graph->execution_order();
  202. for (const auto &kernel : kernels) {
  203. MS_EXCEPTION_IF_NULL(kernel);
  204. if (common::AnfAlgo::IsControlOpExecInBackend(kernel)) {
  205. continue;
  206. }
  207. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  208. MS_EXCEPTION_IF_NULL(kernel_mod);
  209. auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  210. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  211. if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
  212. break;
  213. }
  214. auto device_address =
  215. device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown, ShapeVector());
  216. MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
  217. << " addr:" << device_address;
  218. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  219. }
  220. }
  221. }
  222. void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
  223. MS_EXCEPTION_IF_NULL(graph);
  224. // Collect the inplace groups.
  225. std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  226. const std::vector<CNodePtr> &kernels = graph->execution_order();
  227. for (const auto &kernel : kernels) {
  228. if (!common::AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
  229. continue;
  230. }
  231. auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
  232. MS_EXCEPTION_IF_NULL(primitive);
  233. auto inplace_group_attr = primitive->GetAttr("inplace_group");
  234. MS_EXCEPTION_IF_NULL(inplace_group_attr);
  235. auto group_id = GetValue<uint32_t>(inplace_group_attr);
  236. (void)inplace_groups[group_id].emplace_back(kernel);
  237. }
  238. const size_t kMinInplaceGroupSize = 2;
  239. for (const auto &inplace_group : inplace_groups) {
  240. auto &group_nodes = inplace_group.second;
  241. if (group_nodes.size() < kMinInplaceGroupSize) {
  242. continue;
  243. }
  244. // Get the device address of the first node in the inplace group.
  245. auto node_primitive = common::AnfAlgo::GetCNodePrimitive(group_nodes[0]);
  246. MS_EXCEPTION_IF_NULL(node_primitive);
  247. auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
  248. auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
  249. MS_EXCEPTION_IF_NULL(device_address);
  250. // Update the device address of other nodes using device address of the first node in the inplace group.
  251. for (size_t i = 1; i < group_nodes.size(); ++i) {
  252. auto &group_node = group_nodes[i];
  253. auto prim = common::AnfAlgo::GetCNodePrimitive(group_node);
  254. MS_EXCEPTION_IF_NULL(prim);
  255. auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
  256. AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
  257. // Update the reference count of device address.
  258. device_address->IncreaseOriginalRefCount();
  259. device_address->ResetRefCount();
  260. }
  261. }
  262. }
  263. void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
  264. MS_EXCEPTION_IF_NULL(graph);
  265. auto &kernels = graph->execution_order();
  266. for (auto &kernel : kernels) {
  267. MS_EXCEPTION_IF_NULL(kernel);
  268. auto output_num = common::AnfAlgo::GetOutputTensorNum(kernel);
  269. if (output_num == 0) {
  270. MS_LOG(DEBUG) << "This kernel has no output size.";
  271. continue;
  272. }
  273. for (size_t i = 0; i < output_num; ++i) {
  274. session::AnfWithOutIndex out_pair(kernel, i);
  275. if (graph->IsInRefOutputMap(out_pair)) {
  276. auto origin_pair = graph->GetRefCorrespondOutput(out_pair);
  277. MS_EXCEPTION_IF_NULL(origin_pair.first);
  278. auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second, false);
  279. MS_EXCEPTION_IF_NULL(origin_node_output_addr);
  280. auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
  281. if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
  282. MS_LOG(DEBUG) << "REF address is not same, ref node output need address update";
  283. MS_LOG(DEBUG) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
  284. << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
  285. AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
  286. // Update the reference count of device address.
  287. cur_node_output_addr->DecreaseOriginalRefCount();
  288. cur_node_output_addr->ResetRefCount();
  289. origin_node_output_addr->IncreaseOriginalRefCount();
  290. origin_node_output_addr->ResetRefCount();
  291. }
  292. }
  293. }
  294. }
  295. }
  296. void SetSummaryNodesRefCount(const KernelGraph *graph) {
  297. MS_EXCEPTION_IF_NULL(graph);
  298. if (!graph->summary_node_exist()) {
  299. return;
  300. }
  301. const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
  302. if (summary_nodes.empty()) {
  303. return;
  304. }
  305. for (const auto &item : summary_nodes) {
  306. const AnfNodePtr &node = item.second.first;
  307. size_t index = IntToSize(item.second.second);
  308. auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
  309. MS_EXCEPTION_IF_NULL(device_address);
  310. device_address->set_original_ref_count(SIZE_MAX);
  311. device_address->ResetRefCount();
  312. }
  313. }
  314. void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
  315. for (const auto &item_with_index : output_with_index) {
  316. if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
  317. continue;
  318. }
  319. auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
  320. MS_EXCEPTION_IF_NULL(device_address);
  321. device_address->set_original_ref_count(SIZE_MAX);
  322. device_address->ResetRefCount();
  323. }
  324. }
  325. } // namespace
// Releases the scheduler-side state registered for this compilation:
// asks the GraphScheduler singleton to clear whatever it holds under this
// info's name, graphs, original parameter order and control-node parser.
GraphCompilerInfo::~GraphCompilerInfo() {
  GraphScheduler::GetInstance().Clear(name_, graphs_, origin_parameters_order_, control_node_parser_);
}
  329. GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
  330. const DeviceContext *device_context, bool run_in_pynative) {
  331. MS_EXCEPTION_IF_NULL(session_);
  332. MS_EXCEPTION_IF_NULL(segment);
  333. MS_LOG(INFO) << "Status record: start compile graph.";
  334. auto nodes = segment->nodes_;
  335. auto device_terget = device_context->GetDeviceAddressType();
  336. // Generate kernel graph.
  337. KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs, device_terget);
  338. MS_EXCEPTION_IF_NULL(graph);
  339. opt::EliminateIllegalDataTypePass(graph);
  340. SetGraphDependency(graph, segment);
  341. // Unify the MindIR, must be before of the graph optimization.
  342. device_context->UnifyMindIR(graph);
  343. // The graph common optimization.
  344. graph->UpdateGraphAquireGilAttr();
  345. opt::BackendCommonOptimization(graph);
  346. graph->SetInputNodes();
  347. auto manager = MakeManager({graph});
  348. if (manager) {
  349. manager->AddFuncGraph(graph);
  350. graph->set_manager(manager);
  351. }
  352. session_->SetInputNodeUsage(graph, manager);
  353. graph->SetOptimizerFlag();
  354. GraphId graph_id;
  355. if (run_in_pynative) {
  356. MS_EXCEPTION_IF_NULL(session_);
  357. // Graphkernel does not support pynative mode now, print a warning here.
  358. graphkernel::GraphKernelFlags::GetInstance().CheckSupport();
  359. session_->InitAllBucket(graph, device_context);
  360. graph_id = graph->graph_id();
  361. } else {
  362. graph_id = CompileGraphImpl(graph, device_context);
  363. }
  364. session_->DumpGraphs({graph});
  365. // Cache the backend graph output nodes to front nodes with output index.
  366. auto backend_node = graph->output();
  367. MS_EXCEPTION_IF_NULL(backend_node);
  368. graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
  369. auto ms_context = MsContext::GetInstance();
  370. MS_EXCEPTION_IF_NULL(ms_context);
  371. std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  372. if (device_target == kGPUDevice) {
  373. graph->set_root_graph_id(graph_id);
  374. }
  375. AnfAlgo::UpdateGraphValidRefPair(graph);
  376. MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
  377. return graph_id;
  378. }
  379. GraphId GraphCompiler::CompileGraph(const FuncGraphPtr &func_graph, const DeviceContext *device_context) {
  380. MS_EXCEPTION_IF_NULL(session_);
  381. MS_EXCEPTION_IF_NULL(func_graph);
  382. MS_LOG(INFO) << "Status record: start compile graph.";
  383. // Generate kernel graph.
  384. std::vector<KernelGraphPtr> all_graphs;
  385. auto device_target = device_context->GetDeviceAddressType();
  386. KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, &all_graphs, device_target);
  387. MS_EXCEPTION_IF_NULL(root_graph);
  388. for (const auto &graph : all_graphs) {
  389. MS_EXCEPTION_IF_NULL(graph);
  390. graph->set_root_graph_id(root_graph->graph_id());
  391. }
  392. // Unify the MindIR, must be before of the graph optimization.
  393. device_context->UnifyMindIR(root_graph);
  394. // The graph common optimization.
  395. opt::BackendCommonOptimization(root_graph);
  396. auto graph_id = CompileGraphImpl(root_graph, device_context);
  397. // dump all graphs.
  398. // for ascend mindRT.
  399. session_->DumpGraphs(all_graphs);
  400. // Cache the backend graph output nodes to front nodes with output index.
  401. auto output = func_graph->output();
  402. MS_EXCEPTION_IF_NULL(output);
  403. auto backend_node = root_graph->output();
  404. MS_EXCEPTION_IF_NULL(backend_node);
  405. root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
  406. AnfAlgo::UpdateGraphValidRefPair(root_graph);
  407. MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
  408. return graph_id;
  409. }
  410. GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  411. MS_EXCEPTION_IF_NULL(graph);
  412. MS_EXCEPTION_IF_NULL(device_context);
  413. const auto &ms_context = MsContext::GetInstance();
  414. MS_EXCEPTION_IF_NULL(ms_context);
  415. #ifdef ENABLE_DUMP_IR
  416. bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  417. // Dump .pb graph before graph optimization.
  418. if (save_graphs) {
  419. DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
  420. }
  421. #endif
  422. // Set the graph sink flag.
  423. auto is_executing_sink = device_context->IsExecutingSink(graph);
  424. auto is_loop_count_sink = device_context->IsLoopCountSink(graph);
  425. graph->set_is_executing_sink(is_executing_sink);
  426. graph->set_is_loop_count_sink(is_loop_count_sink);
  427. // Execute optimization pass.
  428. device_context->OptimizeGraph(graph);
  429. // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
  430. // 'KernelMod' is real executive object of kernel.
  431. device_context->CreateKernel(graph->execution_order());
  432. // Read the output and input ref map and set to the kernel graph.
  433. AddOutInRefToGraph(graph);
  434. #ifndef ENABLE_SECURITY
  435. session_->SetSummaryNodes(graph.get());
  436. // Update needed dump kernels for mindRT.
  437. DumpJsonParser::GetInstance().UpdateNeedDumpKernels(*graph.get());
  438. #endif
  439. // Adjust kernel graph before run graph.
  440. device_context->PreprocessBeforeRunGraph(graph);
  441. // Create device address for all anf nodes of graph.
  442. CreateDeviceAddress(graph, device_context, false);
  443. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  444. MS_EXCEPTION_IF_NULL(session_);
  445. session_->InitAllBucket(graph, device_context);
  446. SetSummaryNodesRefCount(graph.get());
  447. #ifdef ENABLE_DUMP_IR
  448. // Dump .pb graph after graph optimization.
  449. if (save_graphs) {
  450. DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
  451. }
  452. #endif
  453. #ifdef ENABLE_DEBUGGER
  454. auto debugger = Debugger::GetInstance();
  455. // Dump graph for GPU mindRT if dump is enabled.
  456. debugger->DumpInGraphCompiler(graph);
  457. if (debugger && debugger->DebuggerBackendEnabled()) {
  458. // Load graphs for GPU and Ascend mindRT.
  459. debugger->LoadGraphs(graph);
  460. }
  461. #endif
  462. graph->EnableRuntimeCache();
  463. return graph->graph_id();
  464. }
  465. GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, bool *single_op_cache_hit,
  466. const DeviceContext *device_context) {
  467. // Check if the graph cache exists.
  468. auto iter = run_op_graphs_.find(op_run_info.graph_info);
  469. auto &op_executor = runtime::OpExecutor::GetInstance();
  470. if (iter != run_op_graphs_.end() && op_executor.BuildQueueEmpty()) {
  471. const auto &graph = iter->second;
  472. MS_EXCEPTION_IF_NULL(graph);
  473. *single_op_cache_hit = true;
  474. return graph->graph_id();
  475. }
  476. *single_op_cache_hit = false;
  477. // Generate kernel graph.
  478. MS_EXCEPTION_IF_NULL(session_);
  479. KernelGraphPtr graph =
  480. session_->ConstructSingleOpGraph(op_run_info, op_run_info.input_tensors, op_run_info.tensor_mask,
  481. device_context->GetDeviceAddressType() == device::DeviceAddressType::kAscend);
  482. MS_EXCEPTION_IF_NULL(graph);
  483. // session_ is SessionBasic, AscendUnifyMindIR has not been executed.
  484. device_context->UnifyMindIR(graph);
  485. MS_EXCEPTION_IF_NULL(device_context);
  486. device_context->OptimizeSingleOpGraph(graph);
  487. // Create device address for all anf nodes of graph.
  488. CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info.is_gradient_out);
  489. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  490. run_op_graphs_[op_run_info.graph_info] = graph;
  491. auto output_nodes = graph->outputs();
  492. auto &outputs_with_index = run_op_graph_output_nodes_[graph->graph_id()];
  493. for (auto &node : output_nodes) {
  494. MS_EXCEPTION_IF_NULL(node);
  495. (void)outputs_with_index.emplace_back(common::AnfAlgo::VisitKernelWithReturnType(node, 0, false));
  496. }
  497. UpdateRefCountForGraphOutput(outputs_with_index);
  498. AnfAlgo::UpdateGraphValidRefPair(graph);
  499. const std::vector<CNodePtr> &kernels = graph->execution_order();
  500. for (const auto &kernel : kernels) {
  501. common::AnfAlgo::SetNodeAttr(kAttrSingleOpCompile, MakeValue(true), kernel);
  502. }
  503. return graph->graph_id();
  504. }
  505. void GraphCompiler::BuildSingleOpGraphs(const std::vector<KernelGraphPtr> &graphs,
  506. const DeviceContext *device_context) const {
  507. MS_EXCEPTION_IF_NULL(device_context);
  508. std::vector<CNodePtr> node_to_build;
  509. for (const auto &graph : graphs) {
  510. const auto &nodes = graph->execution_order();
  511. std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
  512. }
  513. device_context->CreateKernel(node_to_build);
  514. for (const auto &graph : graphs) {
  515. device_context->PreprocessBeforeRunSingleOpGraph(graph);
  516. CreateKernelWorkspaceDeviceAddress(device_context, graph);
  517. // Need to execute after PreprocessBeforeRunSingleOpGraph
  518. runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);
  519. }
  520. }
// Fetches a compiled kernel graph by id; delegates to the owning session.
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetGraph(graph_id);
}
  525. KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
  526. auto iter = run_op_graphs_.find(graph_info);
  527. if (iter == run_op_graphs_.end()) {
  528. MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
  529. return nullptr;
  530. }
  531. return iter->second;
  532. }
  533. void GraphCompiler::AddOutInRefToGraph(const KernelGraphPtr &graph) const {
  534. MS_EXCEPTION_IF_NULL(graph);
  535. for (const auto &cnode : graph->execution_order()) {
  536. MS_EXCEPTION_IF_NULL(cnode);
  537. auto kernel_info = dynamic_cast<device::KernelInfo *>(cnode->kernel_info());
  538. MS_EXCEPTION_IF_NULL(kernel_info);
  539. for (const auto &ref : kernel_info->out_in_ref_map()) {
  540. size_t output_index = ref.first;
  541. size_t input_index = ref.second;
  542. auto final_pair = std::make_pair(cnode, output_index);
  543. auto origin_pair = common::AnfAlgo::VisitKernel(common::AnfAlgo::GetInputNode(cnode, input_index), 0);
  544. MS_LOG(INFO) << "The reference relation output " << final_pair.first->fullname_with_scope()
  545. << ", output index: " << final_pair.second << " to input "
  546. << origin_pair.first->fullname_with_scope() << ", output index: " << origin_pair.second;
  547. // Add to graph only if the input is not a monad.
  548. if (!HasAbstractUMonad(origin_pair.first) && !HasAbstractIOMonad(origin_pair.first)) {
  549. graph->AddRefCorrespondPairs(final_pair, origin_pair);
  550. }
  551. }
  552. }
  553. }
  554. void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context,
  555. bool is_gradient_out) const {
  556. MS_LOG(INFO) << "Status record: start create device address. graph id: " << graph->graph_id();
  557. CreateParameterDeviceAddress(device_context, graph);
  558. CreateValueNodeDeviceAddress(device_context, graph);
  559. CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
  560. CreateKernelWorkspaceDeviceAddress(device_context, graph);
  561. UpdateDeviceAddressForInplaceNode(graph);
  562. UpdateDeviceAddressForRefNode(graph);
  563. MS_LOG(INFO) << "Status record: end create device address. graph id: " << graph->graph_id();
  564. }
  565. void GraphCompiler::CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph,
  566. const DeviceContext *device_context,
  567. bool is_gradient_out) const {
  568. CreateParameterDeviceAddress(device_context, graph);
  569. CreateValueNodeDeviceAddress(device_context, graph);
  570. CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
  571. UpdateDeviceAddressForInplaceNode(graph);
  572. UpdateDeviceAddressForRefNode(graph);
  573. }
// Prepares the bookkeeping needed to run a graph op-by-op: delegates to the
// session to map parameter nodes to their positions in `inputs`, and to
// create output placeholders in `outputs` plus the `output_indexes` mapping.
// NOTE(review): exact placeholder/index semantics live in the session
// implementation — confirm there before relying on them.
void GraphCompiler::GetParamAndOutputIndex(
  const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
  std::map<AnfNodePtr, size_t> *parameter_index,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetParameterIndex(graph.get(), inputs, parameter_index);
  session_->CreateOutputPlaceholder(graph, inputs, outputs, output_indexes);
}
// Gathers the input tensors for one kernel on the single-op path; delegates
// to the session, which fills `input_tensor_info` from prior op outputs
// (`op_output`) and graph inputs (`parameter_index` into `graph_inputs`).
void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
                                            const std::map<KernelWithIndex, TensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<TensorPtr> &graph_inputs,
                                            InputTensorInfo *const input_tensor_info) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info);
}
// Same as GetSingleOpInputTensors but resolves and returns only the single
// input at `input_index`; delegates to the session.
TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
                                                       const std::map<KernelWithIndex, TensorPtr> &op_output,
                                                       const std::map<AnfNodePtr, size_t> &parameter_index,
                                                       const std::vector<TensorPtr> &graph_inputs,
                                                       InputTensorInfo *const input_tensor_info, size_t input_index) {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info,
                                           input_index);
}
  599. void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputTensorInfo &tensor_info,
  600. OpRunInfo *run_info, GraphInfo *graph_info,
  601. GraphOutputInfo *const graph_output_info) {
  602. MS_EXCEPTION_IF_NULL(session_);
  603. MS_EXCEPTION_IF_NULL(graph_info);
  604. *graph_info = session_->GetSingleOpGraphInfo(kernel, tensor_info.input_tensors);
  605. *run_info = session_->GetSingleOpRunInfo(kernel, *graph_info, tensor_info, graph_output_info);
  606. }
  607. void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  608. MS_EXCEPTION_IF_NULL(session_);
  609. session_->GetRefCount(graph.get(), ref_count);
  610. }
  611. void GraphCompiler::CalculateForwardOpOutputCount(const KernelGraphPtr &graph,
  612. const std::vector<tensor::TensorPtr> &inputs,
  613. std::map<std::string, size_t> *forward_op_output_tensor_id) const {
  614. MS_EXCEPTION_IF_NULL(session_);
  615. forward_op_output_tensor_id->clear();
  616. session_->GetForwardOpOutputRefCount(graph.get(), inputs, forward_op_output_tensor_id);
  617. }
  618. void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
  619. std::map<KernelWithIndex, size_t> *ref_count,
  620. std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
  621. MS_EXCEPTION_IF_NULL(session_);
  622. session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
  623. }
  624. void GraphCompiler::UpdateForwardOpOutputRefCount(const std::vector<tensor::TensorPtr> &input_tensor,
  625. std::map<std::string, size_t> *forward_op_output_tensor_id) const {
  626. MS_EXCEPTION_IF_NULL(session_);
  627. MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
  628. session_->ReleaseForwardOpOutput(input_tensor, forward_op_output_tensor_id);
  629. }
  630. void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
  631. const std::map<KernelWithIndex, size_t> &ref_count,
  632. std::map<KernelWithIndex, TensorPtr> *op_output_map,
  633. GraphOutputInfo *const graph_output_info) const {
  634. MS_EXCEPTION_IF_NULL(session_);
  635. session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
  636. }
  637. void GraphCompiler::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  638. MS_EXCEPTION_IF_NULL(session_);
  639. session_->AddGradAddrToBucket(graph_id, grad_tensor);
  640. }
  641. void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
  642. MS_EXCEPTION_IF_NULL(session_);
  643. session_->ClearAllBucket(graph_id);
  644. }
  645. const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId graph_id) const {
  646. const auto &iter = run_op_graph_output_nodes_.find(graph_id);
  647. if (iter == run_op_graph_output_nodes_.end()) {
  648. MS_LOG(EXCEPTION) << "Can not find output nodes for graph id: " << graph_id;
  649. }
  650. return iter->second;
  651. }
  652. void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
  653. MS_EXCEPTION_IF_NULL(session_);
  654. #ifndef ENABLE_SECURITY
  655. session_->RegisterSummaryCallBackFunc(callback);
  656. #endif
  657. }
  658. void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
  659. MS_EXCEPTION_IF_NULL(session_);
  660. for (const auto &graph : graphs) {
  661. #ifndef ENABLE_SECURITY
  662. session_->Summary(graph.get());
  663. #endif
  664. }
  665. }
  666. void GraphCompiler::EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id) {
  667. (void)run_op_graphs_.erase(graph_info);
  668. (void)run_op_graph_output_nodes_.erase(graph_id);
  669. }
// Records execution-order dependencies for a freshly compiled graph: stamps
// `segment` with the graph's id, then links `graph` bidirectionally to every
// graph previously compiled from `segment`'s predecessor segments.
// NOTE(review): assumes each pre_segment's graph_id_ was already stamped by an
// earlier call for that segment — confirm compile order at the call site.
void GraphCompiler::SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const {
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(segment);
// Remember which graph this segment was compiled into, so later segments can
// resolve their predecessors through Fetch() below.
segment->graph_id_ = graph->graph_id();
for (auto &pre_segment : segment->pre_segments_) {
MS_EXCEPTION_IF_NULL(pre_segment);
// Resolve the predecessor segment's compiled graph by its stamped id.
auto pre_graph = Fetch(pre_segment->graph_id_);
MS_EXCEPTION_IF_NULL(pre_graph);
// Link both directions: predecessor -> graph and graph -> predecessor.
pre_graph->AddPostGraph(graph);
graph->AddPreGraph(pre_graph);
MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph->graph_id();
}
}
  683. } // namespace runtime
  684. } // namespace mindspore