You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_compiler.cc 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_compiler.h"
  17. #include <numeric>
  18. #include <map>
  19. #include "runtime/framework/graph_scheduler.h"
  20. #include "runtime/device/device_address.h"
  21. #include "common/trans.h"
  22. #include "utils/convert_utils.h"
  23. #include "ir/tensor.h"
  24. namespace mindspore {
  25. namespace runtime {
  26. namespace {
  27. // Whether device address of anf node is valid and device address type
  28. // is consistent with device type, for example, device address type
  29. // DeviceAddressType::kGPU should be used on GPU device
  30. bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  31. MS_EXCEPTION_IF_NULL(kernel);
  32. MS_EXCEPTION_IF_NULL(device_context);
  33. if (AnfAlgo::OutputAddrExist(kernel, index)) {
  34. const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
  35. MS_EXCEPTION_IF_NULL(address);
  36. return address->DeviceType() == device_context->GetDeviceAddressType();
  37. }
  38. return false;
  39. }
  40. void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  41. MS_EXCEPTION_IF_NULL(device_context);
  42. MS_EXCEPTION_IF_NULL(graph);
  43. std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  44. const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  45. graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
  46. // Anf nodes which need create device address.
  47. std::vector<AnfNodePtr> nodes_list;
  48. for (size_t i = 0; i < graph_inputs.size(); ++i) {
  49. AnfNodePtr item = graph_inputs[i];
  50. MS_EXCEPTION_IF_NULL(item);
  51. if (i < graph_valid_input.size() && !graph_valid_input[i]) {
  52. continue;
  53. }
  54. if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
  55. std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item);
  56. for (const auto &out : outs) {
  57. MS_EXCEPTION_IF_NULL(out);
  58. if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
  59. continue;
  60. }
  61. nodes_list.push_back(out);
  62. }
  63. }
  64. if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
  65. continue;
  66. }
  67. nodes_list.push_back(item);
  68. }
  69. // Create device address for anf node in nodes_list
  70. for (const auto &item : nodes_list) {
  71. auto output_size = AnfAlgo::GetOutputTensorNum(item);
  72. for (size_t index = 0; index < output_size; index++) {
  73. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
  74. // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
  75. if (output_type_id == kTypeUnknown) {
  76. MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
  77. continue;
  78. }
  79. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
  80. auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
  81. AnfAlgo::GetOutputFormat(item, index), output_type_id);
  82. AnfAlgo::SetOutputAddr(device_address, index, item.get());
  83. }
  84. }
  85. }
  86. void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
  87. size_t output_idx, const ValueNodePtr &value_node) {
  88. MS_EXCEPTION_IF_NULL(device_context);
  89. MS_EXCEPTION_IF_NULL(node_value);
  90. MS_EXCEPTION_IF_NULL(value_node);
  91. const auto &ms_context = MsContext::GetInstance();
  92. MS_EXCEPTION_IF_NULL(ms_context);
  93. std::vector<tensor::TensorPtr> tensors;
  94. TensorValueToTensor(node_value, &tensors);
  95. for (const auto &tensor : tensors) {
  96. if (tensor == nullptr) {
  97. MS_LOG(WARNING) << "Tensor is null";
  98. return;
  99. }
  100. auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  101. if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
  102. AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
  103. value_node.get());
  104. continue;
  105. }
  106. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  107. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  108. if (output_type_id == kTypeUnknown) {
  109. output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
  110. }
  111. std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
  112. device::DeviceAddressPtr address =
  113. device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
  114. MS_EXCEPTION_IF_NULL(address);
  115. AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
  116. }
  117. }
  118. void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  119. MS_EXCEPTION_IF_NULL(device_context);
  120. MS_EXCEPTION_IF_NULL(graph);
  121. for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
  122. MS_EXCEPTION_IF_NULL(value_node);
  123. if (NodeDeviceAddressExist(device_context, value_node, 0)) {
  124. continue;
  125. }
  126. const auto &node_value = value_node->value();
  127. MS_EXCEPTION_IF_NULL(node_value);
  128. if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
  129. CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
  130. } else if (node_value->isa<StringImm>()) {
  131. auto value = GetValue<std::string>(node_value);
  132. size_t tensor_size = value.size();
  133. auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  134. MS_EXCEPTION_IF_NULL(address);
  135. AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  136. }
  137. }
  138. }
  139. void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  140. MS_EXCEPTION_IF_NULL(device_context);
  141. MS_EXCEPTION_IF_NULL(graph);
  142. const std::vector<CNodePtr> &kernels = graph->execution_order();
  143. for (const auto &kernel : kernels) {
  144. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  145. MS_EXCEPTION_IF_NULL(kernel_mod);
  146. auto output_sizes = kernel_mod->GetOutputSizeList();
  147. for (size_t i = 0; i < output_sizes.size(); ++i) {
  148. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  149. continue;
  150. }
  151. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  152. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  153. auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  154. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  155. }
  156. }
  157. }
  158. void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  159. MS_EXCEPTION_IF_NULL(device_context);
  160. MS_EXCEPTION_IF_NULL(graph);
  161. const std::vector<CNodePtr> &kernels = graph->execution_order();
  162. for (const auto &kernel : kernels) {
  163. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  164. MS_EXCEPTION_IF_NULL(kernel_mod);
  165. auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
  166. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  167. auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
  168. AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
  169. }
  170. }
  171. }
  172. } // namespace
  173. void GraphCompiler::set_device_context(DeviceContext *device_context) {
  174. MS_EXCEPTION_IF_NULL(device_context);
  175. device_context_ = device_context;
  176. // The member variable 'session_' will be removed after removing session module.
  177. if (session_ == nullptr) {
  178. session_ = std::make_shared<session::SessionBasic>();
  179. const device::DeviceContextKey &device_context_key = device_context->device_context_key();
  180. session_->InitExecutor(device_context_key.device_name_, device_context_key.device_id_);
  181. }
  182. }
  183. GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) {
  184. MS_EXCEPTION_IF_NULL(session_);
  185. // Generate kernel graph.
  186. KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs);
  187. MS_EXCEPTION_IF_NULL(graph);
  188. return CompileGraphImpl(graph);
  189. }
// Shared compile pipeline for a constructed kernel graph. The steps below are
// order-dependent: optimization passes may rewrite the graph, so
// execution_order() is deliberately re-fetched at each step rather than
// cached. Returns the compiled graph's id.
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(device_context_);
// Optimization pass which is irrelevant to device type or format.
device_context_->OptimizeGraphWithoutDeviceInfo(graph);
// Select operator implementations for the kernels in execution order.
device_context_->SetOperatorInfo(graph->execution_order());
// Optimization pass which is relevant to device type or format.
device_context_->OptimizeGraphWithDeviceInfo(graph);
// Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
// 'KernelMod' is real executive object of kernel.
device_context_->CreateKernel(graph->execution_order());
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph);
// Transform graph to actor DAG, contains build and link.
GraphScheduler::GetInstance().Transform(graph, device_context_);
return graph->graph_id();
}
  206. GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const GraphInfo &graph_info,
  207. std::vector<tensor::TensorPtr> *input_tensors,
  208. const std::vector<int64_t> &tensors_mask) {
  209. // Check if the graph cache exists.
  210. auto iter = run_op_graphs_.find(graph_info);
  211. if (iter != run_op_graphs_.end()) {
  212. const auto &graph = iter->second;
  213. MS_EXCEPTION_IF_NULL(graph);
  214. return graph->graph_id();
  215. }
  216. // Generate kernel graph.
  217. MS_EXCEPTION_IF_NULL(session_);
  218. KernelGraphPtr graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
  219. MS_EXCEPTION_IF_NULL(graph);
  220. MS_EXCEPTION_IF_NULL(device_context_);
  221. device_context_->SetOperatorInfo(graph->execution_order());
  222. device_context_->OptimizeSingleOpGraph(graph);
  223. MS_EXCEPTION_IF_NULL(session_);
  224. session_->RunOpHideNopNode(graph);
  225. session_->RunOpRemoveNopNode(graph);
  226. // Generate 'KernelMod' for kernel in graph.
  227. device_context_->CreateKernel(graph->execution_order());
  228. // Create device address for all anf nodes of graph.
  229. CreateDeviceAddress(graph);
  230. // Transform graph to actor DAG, contains build and link.
  231. GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
  232. run_op_graphs_[graph_info] = graph;
  233. return graph->graph_id();
  234. }
// Fetch a previously compiled kernel graph by its graph id, delegating to the
// session. Return value follows session_->GetGraph's contract.
KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
MS_EXCEPTION_IF_NULL(session_);
return session_->GetGraph(graph_id);
}
  239. KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
  240. auto iter = run_op_graphs_.find(graph_info);
  241. if (iter == run_op_graphs_.end()) {
  242. MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
  243. return nullptr;
  244. }
  245. return iter->second;
  246. }
// Create device addresses for every anf node of 'graph': parameters first,
// then value nodes, kernel outputs, and finally kernel workspaces. The order
// matters only in that all node kinds must be covered before scheduling.
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const {
CreateParameterDeviceAddress(device_context_, graph);
CreateValueNodeDeviceAddress(device_context_, graph);
CreateKernelOutputDeviceAddress(device_context_, graph);
CreateKernelWorkspaceDeviceAddress(device_context_, graph);
}
  253. } // namespace runtime
  254. } // namespace mindspore