You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_compiler.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_compiler.h"
  17. #include <numeric>
  18. #include <map>
  19. #include "runtime/framework/graph_scheduler.h"
  20. #include "runtime/device/device_address.h"
  21. #include "common/trans.h"
  22. #include "utils/convert_utils.h"
  23. #include "ir/tensor.h"
  24. #include "backend/optimizer/common/helper.h"
  25. namespace mindspore {
  26. namespace runtime {
  27. namespace {
  28. // Whether device address of anf node is valid and device address type
  29. // is consistent with device type, for example, device address type
  30. // DeviceAddressType::kGPU should be used on GPU device
  31. bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  32. MS_EXCEPTION_IF_NULL(kernel);
  33. MS_EXCEPTION_IF_NULL(device_context);
  34. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  35. const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
  36. MS_EXCEPTION_IF_NULL(address);
  37. return address->DeviceType() == device_context->GetDeviceAddressType();
  38. }
  39. return false;
  40. }
  41. void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  42. MS_EXCEPTION_IF_NULL(device_context);
  43. MS_EXCEPTION_IF_NULL(graph);
  44. std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  45. const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  46. graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  47. // Anf nodes which need create device address.
  48. std::vector<AnfNodePtr> nodes_list;
  49. for (size_t i = 0; i < graph_inputs.size(); ++i) {
  50. AnfNodePtr item = graph_inputs[i];
  51. MS_EXCEPTION_IF_NULL(item);
  52. if (i < graph_valid_input.size() && !graph_valid_input[i]) {
  53. continue;
  54. }
  55. if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
  56. std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item);
  57. for (const auto &out : outs) {
  58. MS_EXCEPTION_IF_NULL(out);
  59. if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
  60. continue;
  61. }
  62. nodes_list.push_back(out);
  63. }
  64. }
  65. if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
  66. continue;
  67. }
  68. nodes_list.push_back(item);
  69. }
  70. // Create device address for anf node in nodes_list
  71. for (const auto &item : nodes_list) {
  72. auto output_size = AnfAlgo::GetOutputTensorNum(item);
  73. for (size_t index = 0; index < output_size; index++) {
  74. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  75. if (output_type_id == kTypeUnknown) {
  76. output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
  77. }
  78. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
  79. auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
  80. AnfAlgo::GetOutputFormat(item, index), output_type_id);
  81. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  82. }
  83. }
  84. }
  85. void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
  86. size_t output_idx, const ValueNodePtr &value_node) {
  87. MS_EXCEPTION_IF_NULL(device_context);
  88. MS_EXCEPTION_IF_NULL(node_value);
  89. MS_EXCEPTION_IF_NULL(value_node);
  90. const auto &ms_context = MsContext::GetInstance();
  91. MS_EXCEPTION_IF_NULL(ms_context);
  92. std::vector<tensor::TensorPtr> tensors;
  93. TensorValueToTensor(node_value, &tensors);
  94. for (const auto &tensor : tensors) {
  95. if (tensor == nullptr) {
  96. MS_LOG(WARNING) << "Tensor is null";
  97. return;
  98. }
  99. auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  100. if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
  101. AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
  102. value_node.get());
  103. continue;
  104. }
  105. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  106. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  107. if (output_type_id == kTypeUnknown) {
  108. output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  109. }
  110. std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  111. device::DeviceAddressPtr address =
  112. device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
  113. MS_EXCEPTION_IF_NULL(address);
  114. AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  115. }
  116. }
  117. void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  118. MS_EXCEPTION_IF_NULL(device_context);
  119. MS_EXCEPTION_IF_NULL(graph);
  120. for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
  121. MS_EXCEPTION_IF_NULL(value_node);
  122. if (NodeDeviceAddressExist(device_context, value_node, 0)) {
  123. continue;
  124. }
  125. const auto &node_value = value_node->value();
  126. MS_EXCEPTION_IF_NULL(node_value);
  127. if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
  128. CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
  129. } else if (node_value->isa<StringImm>()) {
  130. auto value = GetValue<std::string>(node_value);
  131. size_t tensor_size = value.size();
  132. auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  133. MS_EXCEPTION_IF_NULL(address);
  134. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  135. }
  136. }
  137. }
  138. void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  139. MS_EXCEPTION_IF_NULL(device_context);
  140. MS_EXCEPTION_IF_NULL(graph);
  141. const std::vector<CNodePtr> &kernels = graph->execution_order();
  142. for (const auto &kernel : kernels) {
  143. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  144. MS_EXCEPTION_IF_NULL(kernel_mod);
  145. auto output_sizes = kernel_mod->GetOutputSizeList();
  146. for (size_t i = 0; i < output_sizes.size(); ++i) {
  147. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  148. continue;
  149. }
  150. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  151. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  152. auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  153. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  154. }
  155. }
  156. }
  157. void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  158. MS_EXCEPTION_IF_NULL(device_context);
  159. MS_EXCEPTION_IF_NULL(graph);
  160. const std::vector<CNodePtr> &kernels = graph->execution_order();
  161. for (const auto &kernel : kernels) {
  162. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  163. MS_EXCEPTION_IF_NULL(kernel_mod);
  164. auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  165. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  166. auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
  167. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  168. }
  169. }
  170. }
  171. } // namespace
  172. void GraphCompiler::set_device_context(DeviceContext *device_context) {
  173. MS_EXCEPTION_IF_NULL(device_context);
  174. device_context_ = device_context;
  175. // The member variable 'session_' will be removed after removing session module.
  176. if (session_ == nullptr) {
  177. session_ = std::make_shared<session::SessionBasic>();
  178. const device::DeviceContextKey &device_context_key = device_context->device_context_key();
  179. session_->InitExecutor(device_context_key.device_name_, device_context_key.device_id_);
  180. }
  181. }
  182. GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) {
  183. MS_EXCEPTION_IF_NULL(session_);
  184. // Generate kernel graph.
  185. KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs);
  186. MS_EXCEPTION_IF_NULL(graph);
  187. return CompileGraphImpl(graph);
  188. }
  189. GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const {
  190. MS_EXCEPTION_IF_NULL(graph);
  191. MS_EXCEPTION_IF_NULL(device_context_);
  192. // Execute optimization pass.
  193. device_context_->OptimizeGraph(graph);
  194. // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
  195. // 'KernelMod' is real executive object of kernel.
  196. device_context_->CreateKernel(graph->execution_order());
  197. // Create device address for all anf nodes of graph.
  198. CreateDeviceAddress(graph);
  199. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  200. return graph->graph_id();
  201. }
  202. GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
  203. std::vector<tensor::TensorPtr> *input_tensors,
  204. const std::vector<int64_t> &tensors_mask) {
  205. // Check if the graph cache exists.
  206. auto iter = run_op_graphs_.find(graph_info);
  207. if (iter != run_op_graphs_.end()) {
  208. const auto &graph = iter->second;
  209. MS_EXCEPTION_IF_NULL(graph);
  210. return graph->graph_id();
  211. }
  212. // Generate kernel graph.
  213. MS_EXCEPTION_IF_NULL(session_);
  214. KernelGraphPtr graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
  215. MS_EXCEPTION_IF_NULL(graph);
  216. MS_EXCEPTION_IF_NULL(device_context_);
  217. device_context_->OptimizeSingleOpGraph(graph);
  218. MS_EXCEPTION_IF_NULL(session_);
  219. session_->RunOpHideNopNode(graph);
  220. session_->RunOpRemoveNopNode(graph);
  221. // Generate 'KernelMod' for kernel in graph.
  222. device_context_->CreateKernel(graph->execution_order());
  223. // Create device address for all anf nodes of graph.
  224. CreateDeviceAddress(graph);
  225. graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  226. // Transform graph to actor DAG, contains build and link.
  227. GraphScheduler::GetInstance().Transform({graph}, {device_context_}, input_tensors, nullptr,
  228. GraphExecutionStrategy::kStep);
  229. run_op_graphs_[graph_info] = graph;
  230. return graph->graph_id();
  231. }
  232. KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  233. MS_EXCEPTION_IF_NULL(session_);
  234. return session_->GetGraph(graph_id);
  235. }
  236. KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
  237. auto iter = run_op_graphs_.find(graph_info);
  238. if (iter == run_op_graphs_.end()) {
  239. MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
  240. return nullptr;
  241. }
  242. return iter->second;
  243. }
// Create device addresses for every kind of anf node in the graph:
// parameters, value nodes, kernel outputs and kernel workspaces.
// Null checks on device_context_/graph are performed inside each helper.
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const {
  CreateParameterDeviceAddress(device_context_, graph);
  CreateValueNodeDeviceAddress(device_context_, graph);
  CreateKernelOutputDeviceAddress(device_context_, graph);
  CreateKernelWorkspaceDeviceAddress(device_context_, graph);
}
  250. } // namespace runtime
  251. } // namespace mindspore