You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_build_ascend.cc 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "device/ascend/kernel_build_ascend.h"

#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "device/ascend/kernel_select_ascend.h"
#include "device/kernel_info.h"
#include "kernel/kernel.h"
#include "kernel/tbe/tbe_kernel_build.h"
#include "kernel/tbe/tbe_kernel_parallel_build.h"
#include "kernel/aicpu/aicpu_kernel_build.h"
#include "kernel/hccl/hccl_kernel_build.h"
#include "kernel/mng/rt_kernel_build.h"
#include "kernel/tbe/tbe_utils.h"
#include "operator/ops.h"
#include "session/anf_runtime_algorithm.h"
#include "./common.h"
  33. namespace mindspore {
  34. namespace device {
  35. namespace ascend {
  36. using mindspore::kernel::tbe::TbeUtils;
  37. using std::make_shared;
  38. static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
  39. kernel::KernelModPtr kernel_mod_ptr = nullptr;
  40. KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
  41. switch (kernel_type) {
  42. case KernelType::AICPU_KERNEL: {
  43. kernel_mod_ptr = kernel::AicpuOpBuild(anf_node);
  44. break;
  45. }
  46. case KernelType::RT_KERNEL: {
  47. kernel_mod_ptr = kernel::RtOpBuild(anf_node);
  48. break;
  49. }
  50. case KernelType::HCCL_KERNEL: {
  51. kernel_mod_ptr = kernel::HcclOpBuild(anf_node);
  52. break;
  53. }
  54. default: {
  55. MS_LOG(EXCEPTION) << "node [" << anf_node->DebugString() << "] Unsupported kernel_type:" << kernel_type;
  56. }
  57. }
  58. return kernel_mod_ptr;
  59. }
  60. static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
  61. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  62. std::vector<AnfNodePtr> tbe_nodes;
  63. std::vector<AnfNodePtr> other_nodes;
  64. for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
  65. MS_EXCEPTION_IF_NULL(anf_node);
  66. if (!AnfAlgo::IsRealKernel(anf_node)) {
  67. continue;
  68. }
  69. KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
  70. switch (kernel_type) {
  71. case KernelType::TBE_KERNEL: {
  72. if (AnfAlgo::GetKernelMod(anf_node) == nullptr) {
  73. tbe_nodes.push_back(anf_node);
  74. }
  75. break;
  76. }
  77. default: {
  78. other_nodes.push_back(anf_node);
  79. break;
  80. }
  81. }
  82. }
  83. bool ret = kernel::TbeOpParallelBuild(tbe_nodes);
  84. for (const auto &anf_node : other_nodes) {
  85. kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
  86. MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  87. AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
  88. }
  89. return ret;
  90. }
  91. static std::vector<int> CalCleanZerosSize(const CNodePtr &pre_node) {
  92. MS_EXCEPTION_IF_NULL(pre_node);
  93. std::vector<int> clean_size_list;
  94. // clean output
  95. if (AnfAlgo::HasNodeAttr(kAttrAutomicOutputIndexs, pre_node)) {
  96. auto clean_output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAutomicOutputIndexs);
  97. for (auto index : clean_output_indexs) {
  98. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(pre_node, index);
  99. size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
  100. std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(pre_node, index);
  101. auto size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
  102. clean_size_list.push_back((size + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize);
  103. }
  104. }
  105. // clean workspace
  106. auto workspaces_size = 0;
  107. if (AnfAlgo::HasNodeAttr(kAttrAutomicWorkspaceSize, pre_node)) {
  108. workspaces_size = AnfAlgo::GetNodeAttr<int>(pre_node, kAttrAutomicWorkspaceSize);
  109. clean_size_list.push_back(workspaces_size);
  110. }
  111. MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ", workspace size:" << workspaces_size
  112. << ",pre_node:" << pre_node->fullname_with_scope();
  113. return clean_size_list;
  114. }
  115. static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
  116. const mindspore::CNodePtr &pre_node, std::vector<mindspore::CNodePtr> *new_nodes) {
  117. MS_EXCEPTION_IF_NULL(kernel_graph);
  118. MS_EXCEPTION_IF_NULL(pre_node);
  119. MS_EXCEPTION_IF_NULL(new_nodes);
  120. auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
  121. MS_EXCEPTION_IF_NULL(clear_zero_prim);
  122. auto new_value_node = NewValueNode(clear_zero_prim);
  123. MS_EXCEPTION_IF_NULL(new_value_node);
  124. std::vector<AnfNodePtr> inputs = {new_value_node};
  125. inputs.push_back(pre_node);
  126. CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
  127. MS_EXCEPTION_IF_NULL(clear_zero);
  128. AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
  129. MS_EXCEPTION_IF_NULL(abstract);
  130. clear_zero->set_abstract(abstract);
  131. auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  132. builder->SetKernelType(KernelType::TBE_KERNEL);
  133. AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
  134. auto clean_size = CalCleanZerosSize(pre_node);
  135. AnfAlgo::SetNodeAttr(kAttrAutomicAddMemSize, MakeValue(clean_size), clear_zero);
  136. AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get());
  137. new_nodes->push_back(clear_zero);
  138. }
  139. bool IsAtomicNode(const CNodePtr &kernel_node) {
  140. MS_EXCEPTION_IF_NULL(kernel_node);
  141. auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
  142. MS_EXCEPTION_IF_NULL(kernel_mod);
  143. auto parameters_indexs = kernel_mod->GenParameters();
  144. if (parameters_indexs.empty()) {
  145. return false;
  146. }
  147. auto atomic_flag = false;
  148. size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  149. size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  150. auto workspace_size_list = kernel_mod->GetWorkspaceSizeList();
  151. size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
  152. if (input_num + workspace_num + output_num > parameters_indexs.size()) {
  153. size_t lossNum = (input_num + workspace_num + output_num) - parameters_indexs.size();
  154. for (size_t i = 0; i < lossNum; i++) {
  155. parameters_indexs.push_back(0);
  156. }
  157. }
  158. std::vector<size_t> clean_output_indexs;
  159. // in parameters data sort as input->workspace->output
  160. size_t index = 0;
  161. while (index < output_num) {
  162. if (parameters_indexs[input_num + workspace_num + index] == 1) {
  163. atomic_flag = true;
  164. clean_output_indexs.push_back(index);
  165. }
  166. index++;
  167. }
  168. if (atomic_flag) {
  169. AnfAlgo::SetNodeAttr(kAttrAutomicOutputIndexs, MakeValue(clean_output_indexs), kernel_node);
  170. }
  171. for (size_t i = 0; i < workspace_num; ++i) {
  172. if (parameters_indexs[input_num + i] == 1) {
  173. atomic_flag = true;
  174. AnfAlgo::SetNodeAttr(kAttrAutomicWorkspaceSize,
  175. MakeValue(std::accumulate(workspace_size_list.begin(), workspace_size_list.end(), 0)),
  176. kernel_node);
  177. break;
  178. }
  179. }
  180. return atomic_flag;
  181. }
  182. bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
  183. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  184. TbeUtils::LoadCache();
  185. bool ret;
  186. ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr);
  187. return ret;
  188. }
  189. void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
  190. MS_EXCEPTION_IF_NULL(kernel_graph);
  191. std::vector<CNodePtr> new_nodes;
  192. for (const auto &anf_node : kernel_graph->execution_order()) {
  193. std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
  194. if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
  195. AnfAlgo::GetKernelType(anf_node) == KernelType::AUTO_DIFF_KERNEL) {
  196. auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
  197. MS_EXCEPTION_IF_NULL(clear_zero_prim);
  198. auto new_value_node = NewValueNode(clear_zero_prim);
  199. MS_EXCEPTION_IF_NULL(new_value_node);
  200. std::vector<AnfNodePtr> inputs = {new_value_node};
  201. inputs.push_back(anf_node);
  202. CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
  203. MS_EXCEPTION_IF_NULL(clear_zero);
  204. auto kernel_info = std::make_shared<device::KernelInfo>();
  205. MS_EXCEPTION_IF_NULL(kernel_info);
  206. clear_zero->set_kernel_info(kernel_info);
  207. AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
  208. MS_EXCEPTION_IF_NULL(abstract);
  209. AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
  210. SelectKernelInfo(clear_zero);
  211. // set the distinction label of clear same with anf
  212. AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
  213. new_nodes.push_back(clear_zero);
  214. } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
  215. if (IsAtomicNode(anf_node)) {
  216. AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
  217. }
  218. }
  219. new_nodes.push_back(anf_node);
  220. }
  221. kernel_graph->set_execution_order(new_nodes);
  222. }
  223. } // namespace ascend
  224. } // namespace device
  225. } // namespace mindspore