You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_compiler.cc 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_compiler.h"
  17. #include <numeric>
  18. #include <map>
  19. #include <utility>
  20. #include "runtime/device/device_address.h"
  21. #include "common/trans.h"
  22. #include "utils/convert_utils.h"
  23. #include "ir/tensor.h"
  24. #include "backend/optimizer/common/helper.h"
  25. #include "base/base_ref_utils.h"
  26. #include "debug/dump_proto.h"
  27. #ifdef ENABLE_DEBUGGER
  28. #include "debug/debugger/debugger.h"
  29. #endif
  30. #ifdef ENABLE_DUMP_IR
  31. #include "debug/anf_ir_dump.h"
  32. #include "debug/rdr/running_data_recorder.h"
  33. #endif
  34. #include "debug/data_dump/dump_json_parser.h"
  35. namespace mindspore {
  36. namespace runtime {
  37. namespace {
  38. // Whether device address of anf node is valid and device address type
  39. // is consistent with device type, for example, device address type
  40. // DeviceAddressType::kGPU should be used on GPU device
  41. bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  42. MS_EXCEPTION_IF_NULL(kernel);
  43. MS_EXCEPTION_IF_NULL(device_context);
  44. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  45. const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
  46. MS_EXCEPTION_IF_NULL(address);
  47. return address->DeviceType() == device_context->GetDeviceAddressType();
  48. }
  49. return false;
  50. }
  51. void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  52. MS_EXCEPTION_IF_NULL(device_context);
  53. MS_EXCEPTION_IF_NULL(graph);
  54. std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  55. const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  56. (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  57. // Anf nodes which need create device address.
  58. std::vector<AnfNodePtr> nodes_list;
  59. for (size_t i = 0; i < graph_inputs.size(); ++i) {
  60. AnfNodePtr item = graph_inputs[i];
  61. MS_EXCEPTION_IF_NULL(item);
  62. if (i < graph_valid_input.size() && !graph_valid_input[i]) {
  63. continue;
  64. }
  65. if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
  66. std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item);
  67. for (const auto &out : outs) {
  68. MS_EXCEPTION_IF_NULL(out);
  69. if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
  70. continue;
  71. }
  72. nodes_list.push_back(out);
  73. }
  74. }
  75. if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
  76. continue;
  77. }
  78. nodes_list.push_back(item);
  79. }
  80. // Create device address for anf node in nodes_list
  81. for (const auto &item : nodes_list) {
  82. auto output_size = AnfAlgo::GetOutputTensorNum(item);
  83. for (size_t index = 0; index < output_size; index++) {
  84. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  85. if (output_type_id == kTypeUnknown) {
  86. output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
  87. }
  88. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
  89. auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
  90. AnfAlgo::GetOutputFormat(item, index), output_type_id);
  91. MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(item) << " addr:" << device_address;
  92. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  93. }
  94. }
  95. }
// Create device addresses on 'value_node' for the tensor(s) stored in 'node_value'.
// 'output_idx' is the output position of 'value_node' to bind the first tensor to; it is
// advanced once per tensor that gets an address attached.
void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
                                       size_t output_idx, const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // Flatten the value (single tensor or ValueTuple) into a tensor list.
  std::vector<TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      // Best-effort: a null tensor aborts processing of the remaining tensors.
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
      // The tensor already owns an address on this device type: reuse it instead of allocating.
      bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
      bool is_graph_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
      if (is_graph_mode || is_pynative_infer) {
        AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                               value_node.get());
      }
      // NOTE(review): when neither graph mode nor pynative-infer is active, output_idx is NOT
      // advanced for this tensor — confirm this is intended for multi-tensor values.
      continue;
    }
    // No reusable address: allocate a fresh one sized/typed from the value node's output info.
    size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
    if (output_type_id == kTypeUnknown) {
      // Fall back to the inferred type when no device data type has been selected.
      output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
    std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
    device::DeviceAddressPtr address =
      device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
    MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
    MS_EXCEPTION_IF_NULL(address);
    AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
  }
}
  133. void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  134. MS_EXCEPTION_IF_NULL(device_context);
  135. MS_EXCEPTION_IF_NULL(graph);
  136. for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
  137. MS_EXCEPTION_IF_NULL(value_node);
  138. if (NodeDeviceAddressExist(device_context, value_node, 0)) {
  139. continue;
  140. }
  141. const auto &node_value = value_node->value();
  142. MS_EXCEPTION_IF_NULL(node_value);
  143. if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
  144. CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
  145. } else if (node_value->isa<StringImm>()) {
  146. auto value = GetValue<std::string>(node_value);
  147. size_t tensor_size = value.size();
  148. auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  149. MS_EXCEPTION_IF_NULL(address);
  150. MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address;
  151. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  152. }
  153. }
  154. }
  155. void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  156. MS_EXCEPTION_IF_NULL(device_context);
  157. MS_EXCEPTION_IF_NULL(graph);
  158. const std::vector<CNodePtr> &kernels = graph->execution_order();
  159. for (const auto &kernel : kernels) {
  160. MS_EXCEPTION_IF_NULL(kernel);
  161. if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
  162. continue;
  163. }
  164. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  165. MS_EXCEPTION_IF_NULL(kernel_mod);
  166. auto output_sizes = kernel_mod->GetOutputSizeList();
  167. for (size_t i = 0; i < output_sizes.size(); ++i) {
  168. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  169. continue;
  170. }
  171. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  172. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  173. auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  174. MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
  175. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  176. }
  177. }
  178. }
  179. void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  180. MS_EXCEPTION_IF_NULL(device_context);
  181. MS_EXCEPTION_IF_NULL(graph);
  182. const std::vector<CNodePtr> &kernels = graph->execution_order();
  183. for (const auto &kernel : kernels) {
  184. MS_EXCEPTION_IF_NULL(kernel);
  185. if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
  186. continue;
  187. }
  188. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  189. MS_EXCEPTION_IF_NULL(kernel_mod);
  190. auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  191. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  192. auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
  193. MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address;
  194. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  195. }
  196. }
  197. }
  198. void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
  199. MS_EXCEPTION_IF_NULL(graph);
  200. // Collect the inplace groups.
  201. std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  202. const std::vector<CNodePtr> &kernels = graph->execution_order();
  203. for (const auto &kernel : kernels) {
  204. if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
  205. continue;
  206. }
  207. auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
  208. MS_EXCEPTION_IF_NULL(primitive);
  209. auto inplace_group_attr = primitive->GetAttr("inplace_group");
  210. MS_EXCEPTION_IF_NULL(inplace_group_attr);
  211. auto group_id = GetValue<uint32_t>(inplace_group_attr);
  212. (void)inplace_groups[group_id].emplace_back(kernel);
  213. }
  214. const size_t kMinInplaceGroupSize = 2;
  215. for (const auto &inplace_group : inplace_groups) {
  216. auto &group_nodes = inplace_group.second;
  217. if (group_nodes.size() < kMinInplaceGroupSize) {
  218. continue;
  219. }
  220. // Get the device address of the first node in the inplace group.
  221. auto node_primitive = AnfAlgo::GetCNodePrimitive(group_nodes[0]);
  222. MS_EXCEPTION_IF_NULL(node_primitive);
  223. auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
  224. auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
  225. MS_EXCEPTION_IF_NULL(device_address);
  226. // Update the device address of other nodes using device address of the first node in the inplace group.
  227. for (size_t i = 1; i < group_nodes.size(); ++i) {
  228. auto &group_node = group_nodes[i];
  229. auto prim = AnfAlgo::GetCNodePrimitive(group_node);
  230. MS_EXCEPTION_IF_NULL(prim);
  231. auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
  232. AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
  233. // Update the reference count of device address.
  234. device_address->IncreaseOriginalRefCount();
  235. device_address->ResetRefCount();
  236. }
  237. }
  238. }
  239. void SetSummaryNodesRefCount(const KernelGraph *graph) {
  240. if (!graph->summary_node_exist()) {
  241. return;
  242. }
  243. const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
  244. if (summary_nodes.empty()) {
  245. return;
  246. }
  247. for (const auto &item : summary_nodes) {
  248. const AnfNodePtr &node = item.second.first;
  249. size_t index = IntToSize(item.second.second);
  250. auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
  251. MS_EXCEPTION_IF_NULL(device_address);
  252. device_address->set_original_ref_count(SIZE_MAX);
  253. device_address->ResetRefCount();
  254. }
  255. }
  256. void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
  257. for (const auto &item_with_index : output_with_index) {
  258. if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
  259. continue;
  260. }
  261. auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
  262. MS_EXCEPTION_IF_NULL(device_address);
  263. device_address->set_original_ref_count(SIZE_MAX);
  264. device_address->ResetRefCount();
  265. }
  266. }
  267. } // namespace
  268. GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs,
  269. const DeviceContext *device_context) {
  270. MS_EXCEPTION_IF_NULL(session_);
  271. // Generate kernel graph.
  272. KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs);
  273. MS_EXCEPTION_IF_NULL(graph);
  274. // Cache the backend graph output nodes to front nodes with output index.
  275. for (auto &output : outputs) {
  276. auto backend_node = graph->GetBackendAnfByFrontAnf(output);
  277. if (backend_node != nullptr) {
  278. graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output);
  279. }
  280. }
  281. return CompileGraphImpl(graph, device_context);
  282. }
  283. GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  284. MS_EXCEPTION_IF_NULL(graph);
  285. MS_EXCEPTION_IF_NULL(device_context);
  286. const auto &ms_context = MsContext::GetInstance();
  287. MS_EXCEPTION_IF_NULL(ms_context);
  288. bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  289. // Dump .pb graph before graph optimization.
  290. if (save_graphs) {
  291. DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
  292. }
  293. MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id();
  294. auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
  295. // Execute optimization pass.
  296. device_context->OptimizeGraph(graph);
  297. // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
  298. // 'KernelMod' is real executive object of kernel.
  299. device_context->CreateKernel(graph->execution_order());
  300. // Adjust kernel graph before run graph.
  301. device_context->PreprocessBeforeRunGraph(graph);
  302. MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id();
  303. auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
  304. // Update the output map of kernel graph by modified output nodes.
  305. graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);
  306. if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
  307. // Create device address for all anf nodes of graph.
  308. CreateDeviceAddress(graph, device_context);
  309. }
  310. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  311. MS_EXCEPTION_IF_NULL(session_);
  312. session_->InitAllBucket(graph, device_context);
  313. session_->SetSummaryNodes(graph.get());
  314. SetSummaryNodesRefCount(graph.get());
  315. // Dump .pb graph after graph optimization.
  316. if (save_graphs) {
  317. DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
  318. }
  319. #ifdef ENABLE_DEBUGGER
  320. auto debugger = Debugger::GetInstance();
  321. debugger->DumpInGraphCompiler(graph);
  322. if (debugger && debugger->DebuggerBackendEnabled()) {
  323. debugger->LoadGraphs(graph);
  324. }
  325. #endif
  326. #ifdef ENABLE_DUMP_IR
  327. std::string name = "graph_build";
  328. DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
  329. (void)mindspore::RDR::RecordAnfGraph(SubModuleId::SM_SESSION, name, graph, dump_params, ".ir,.pb");
  330. auto &kernels = graph->execution_order();
  331. std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
  332. (void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
  333. #endif
  334. session_->DumpGraph(graph);
  335. return graph->graph_id();
  336. }
  337. GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, const GraphInfo &graph_info,
  338. const std::vector<int64_t> *tensors_mask,
  339. std::vector<TensorPtr> *const input_tensors, bool *single_op_cache_hit,
  340. const DeviceContext *device_context) {
  341. // Check if the graph cache exists.
  342. auto iter = run_op_graphs_.find(graph_info);
  343. if (iter != run_op_graphs_.end()) {
  344. const auto &graph = iter->second;
  345. MS_EXCEPTION_IF_NULL(graph);
  346. *single_op_cache_hit = true;
  347. return graph->graph_id();
  348. }
  349. *single_op_cache_hit = false;
  350. // Generate kernel graph.
  351. MS_EXCEPTION_IF_NULL(session_);
  352. KernelGraphPtr graph = session_->ConstructSingleOpGraph(op_run_info, *input_tensors, *tensors_mask);
  353. MS_EXCEPTION_IF_NULL(graph);
  354. MS_EXCEPTION_IF_NULL(device_context);
  355. device_context->OptimizeSingleOpGraph(graph);
  356. // Generate 'KernelMod' for kernel in graph.
  357. device_context->CreateKernel(graph->execution_order());
  358. device_context->PreprocessBeforeRunSingleOpGraph(graph);
  359. // Create device address for all anf nodes of graph.
  360. CreateDeviceAddress(graph, device_context);
  361. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  362. run_op_graphs_[graph_info] = graph;
  363. auto output_nodes = graph->outputs();
  364. auto &outputs_with_index = run_op_graph_output_nodes_[graph->graph_id()];
  365. for (auto &node : output_nodes) {
  366. MS_EXCEPTION_IF_NULL(node);
  367. (void)outputs_with_index.emplace_back(AnfAlgo::VisitKernelWithReturnType(node, 0, false));
  368. }
  369. UpdateRefCountForGraphOutput(outputs_with_index);
  370. return graph->graph_id();
  371. }
// Return the kernel graph registered under 'graph_id' (delegated to the session).
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetGraph(graph_id);
}
  376. KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
  377. auto iter = run_op_graphs_.find(graph_info);
  378. if (iter == run_op_graphs_.end()) {
  379. MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
  380. return nullptr;
  381. }
  382. return iter->second;
  383. }
  384. void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  385. CreateParameterDeviceAddress(device_context, graph);
  386. CreateValueNodeDeviceAddress(device_context, graph);
  387. CreateKernelOutputDeviceAddress(device_context, graph);
  388. CreateKernelWorkspaceDeviceAddress(device_context, graph);
  389. UpdateDeviceAddressForInplaceNode(graph);
  390. }
// Fill 'parameter_index' (graph parameter -> position in 'inputs') and build the output
// placeholders in 'outputs' together with 'output_indexes' (output node -> index paths),
// both delegated to the session.
void GraphCompiler::GetParamAndOutputIndex(
  const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
  std::map<AnfNodePtr, size_t> *parameter_index,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetParameterIndex(graph.get(), inputs, parameter_index);
  session_->CreateOutputPlaceholder(graph, inputs, outputs, output_indexes);
}
// Collect the input tensors of one kernel into 'input_tensor_info' (delegated to
// session::GetOpInputTensors, which draws on prior op outputs, the parameter index map
// and the graph inputs).
void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
                                            const std::map<KernelWithIndex, TensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<TensorPtr> &graph_inputs,
                                            InputTensorInfo *const input_tensor_info) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info);
}
// Fetch the tensor for a single input slot ('input_index') of 'kernel' (delegated to
// session::GetOpInputTensorByIndex).
TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
                                                       const std::map<KernelWithIndex, TensorPtr> &op_output,
                                                       const std::map<AnfNodePtr, size_t> &parameter_index,
                                                       const std::vector<TensorPtr> &graph_inputs,
                                                       InputTensorInfo *const input_tensor_info, size_t input_index) {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info,
                                           input_index);
}
// Fill 'run_info' for the kernel and derive 'graph_info' — the key used for the single-op
// graph cache (run_op_graphs_) — from the kernel and its input tensors.
void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
                                                   OpRunInfo *const run_info, GraphInfo *const graph_info) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetSingleOpRunInfo(kernel, run_info);
  *graph_info = session_->GetSingleOpGraphInfo(kernel, input_tensors);
}
// Compute the reference count for each kernel output of the graph (delegated to
// session::GetRefCount).
void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetRefCount(graph.get(), ref_count);
}
// Bookkeeping after an op's inputs were consumed: updates 'ref_count' and 'op_output_map'
// (delegated to session::HandleOpInputs).
void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                                   std::map<KernelWithIndex, size_t> *ref_count,
                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
}
// Record the just-executed op's outputs into 'op_output_map' and the aggregated graph
// output structure (delegated to session::HandleOpOutputs).
void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                       const std::map<KernelWithIndex, size_t> &ref_count,
                                       std::map<KernelWithIndex, TensorPtr> *op_output_map,
                                       GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
}
// Forward gradient tensors into the bucket of 'graph_id' (delegated to the session).
void GraphCompiler::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->AddGradAddrToBucket(graph_id, grad_tensor);
}
// Clear all buckets associated with 'graph_id' (delegated to the session).
void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->ClearAllBucket(graph_id);
}
  447. const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId graph_id) const {
  448. const auto &iter = run_op_graph_output_nodes_.find(graph_id);
  449. if (iter == run_op_graph_output_nodes_.end()) {
  450. MS_LOG(EXCEPTION) << "Can not find output nodes for graph id: " << graph_id;
  451. }
  452. return iter->second;
  453. }
// Register the callback used for summary data (delegated to the session).
void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RegisterSummaryCallBackFunc(callback);
}
  458. void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
  459. MS_EXCEPTION_IF_NULL(session_);
  460. for (const auto &graph : graphs) {
  461. session_->Summary(graph.get());
  462. }
  463. }
  464. void GraphCompiler::EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id) {
  465. run_op_graphs_.erase(graph_info);
  466. run_op_graph_output_nodes_.erase(graph_id);
  467. }
  468. } // namespace runtime
  469. } // namespace mindspore