You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel_build_ascend.cc 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/ascend/kernel_build_ascend.h"
  17. #include <vector>
  18. #include <string>
  19. #include <memory>
  20. #include <functional>
  21. #include "device/ascend/kernel_select_ascend.h"
  22. #include "device/kernel_info.h"
  23. #include "kernel/kernel.h"
  24. #include "kernel/tbe/tbe_kernel_build.h"
  25. #include "kernel/tbe/tbe_kernel_parallel_build.h"
  26. #include "kernel/akg/ascend/akg_ascend_kernel_build.h"
  27. #include "kernel/aicpu/aicpu_kernel_build.h"
  28. #include "kernel/hccl/hccl_kernel_build.h"
  29. #include "kernel/rts/rt_kernel_build.h"
  30. #include "kernel/tbe/tbe_utils.h"
  31. #include "kernel/common_utils.h"
  32. #include "operator/ops.h"
  33. #include "session/anf_runtime_algorithm.h"
  34. #include "./common.h"
  35. namespace mindspore {
  36. namespace device {
  37. namespace ascend {
  38. using mindspore::kernel::tbe::TbeUtils;
  39. using std::make_shared;
  40. static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
  41. kernel::KernelModPtr kernel_mod_ptr = nullptr;
  42. KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
  43. switch (kernel_type) {
  44. case KernelType::AICPU_KERNEL: {
  45. kernel_mod_ptr = kernel::AicpuOpBuild(anf_node);
  46. break;
  47. }
  48. case KernelType::RT_KERNEL: {
  49. kernel_mod_ptr = kernel::RtOpBuild(anf_node);
  50. break;
  51. }
  52. case KernelType::HCCL_KERNEL: {
  53. kernel_mod_ptr = kernel::HcclOpBuild(anf_node);
  54. break;
  55. }
  56. default: {
  57. MS_LOG(EXCEPTION) << "node [" << anf_node->DebugString() << "] Unsupported kernel_type:" << kernel_type;
  58. }
  59. }
  60. return kernel_mod_ptr;
  61. }
  62. static bool KernelPreBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
  63. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  64. std::vector<AnfNodePtr> tbe_nodes;
  65. for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
  66. MS_EXCEPTION_IF_NULL(anf_node);
  67. if (!AnfAlgo::IsRealKernel(anf_node)) {
  68. continue;
  69. }
  70. KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
  71. switch (kernel_type) {
  72. case KernelType::TBE_KERNEL: {
  73. if (AnfAlgo::GetKernelMod(anf_node) == nullptr &&
  74. AnfAlgo::GetFusionType(anf_node) == kernel::FusionType::DYNAMIC) {
  75. tbe_nodes.push_back(anf_node);
  76. }
  77. break;
  78. }
  79. default: {
  80. break;
  81. }
  82. }
  83. }
  84. bool ret = kernel::TbeOpParallelPreBuild(tbe_nodes);
  85. return ret;
  86. }
  87. static bool KernelBuildParallelCompile(const mindspore::session::KernelGraph *kernel_graph_ptr) {
  88. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  89. std::vector<AnfNodePtr> tbe_nodes;
  90. std::vector<AnfNodePtr> akg_nodes;
  91. std::vector<AnfNodePtr> other_nodes;
  92. for (const auto &anf_node : kernel_graph_ptr->execution_order()) {
  93. MS_EXCEPTION_IF_NULL(anf_node);
  94. if (!AnfAlgo::IsRealKernel(anf_node)) {
  95. continue;
  96. }
  97. KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);
  98. switch (kernel_type) {
  99. case KernelType::TBE_KERNEL: {
  100. if (AnfAlgo::GetKernelMod(anf_node) == nullptr) {
  101. tbe_nodes.push_back(anf_node);
  102. }
  103. break;
  104. }
  105. case KernelType::AKG_KERNEL: {
  106. akg_nodes.push_back(anf_node);
  107. break;
  108. }
  109. default: {
  110. other_nodes.push_back(anf_node);
  111. break;
  112. }
  113. }
  114. }
  115. bool tbe_ret = kernel::TbeOpParallelBuild(tbe_nodes);
  116. bool akg_ret = kernel::AkgAscendKernelParallelBuild(akg_nodes);
  117. auto bin_map = kernel::tbe::KernelMeta::GetInstance();
  118. (void)bin_map->ReadIndex(kernel::kCceKernelMeta);
  119. for (const auto &anf_node : other_nodes) {
  120. kernel::KernelModPtr kernel_mod_ptr = SerialCompileImpl(anf_node);
  121. MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
  122. AnfAlgo::SetKernelMod(kernel_mod_ptr, anf_node.get());
  123. }
  124. return tbe_ret && akg_ret;
  125. }
  126. static std::vector<size_t> CalCleanZerosSize(const CNodePtr &pre_node) {
  127. MS_EXCEPTION_IF_NULL(pre_node);
  128. auto kernel_mod = AnfAlgo::GetKernelMod(pre_node);
  129. MS_EXCEPTION_IF_NULL(kernel_mod);
  130. std::vector<size_t> clean_size_list;
  131. // clean output
  132. if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
  133. auto output_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
  134. auto output_men_size = kernel_mod->GetOutputSizeList();
  135. for (auto index : output_indexs) {
  136. auto clean_item = (output_men_size.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
  137. clean_size_list.emplace_back(clean_item);
  138. }
  139. }
  140. // clean workspace
  141. if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
  142. auto workspace_indexs = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
  143. auto workspace_men_sizes = kernel_mod->GetWorkspaceSizeList();
  144. for (const auto &index : workspace_indexs) {
  145. auto clean_item = (workspace_men_sizes.at(index) + kMemAlignSize + 31) / kMemAlignSize * kMemAlignSize;
  146. clean_size_list.emplace_back(clean_item);
  147. }
  148. }
  149. MS_LOG(INFO) << "clear output size:" << clean_size_list.size() << ",pre_node:" << pre_node->fullname_with_scope();
  150. return clean_size_list;
  151. }
  152. static void AddTbeClearZeroNode(mindspore::session::KernelGraph *const kernel_graph,
  153. const mindspore::CNodePtr &pre_node, std::vector<mindspore::CNodePtr> *new_nodes) {
  154. MS_EXCEPTION_IF_NULL(kernel_graph);
  155. MS_EXCEPTION_IF_NULL(pre_node);
  156. MS_EXCEPTION_IF_NULL(new_nodes);
  157. auto clear_zero_prim = std::make_shared<Primitive>(kAtomicAddrCleanOpName);
  158. MS_EXCEPTION_IF_NULL(clear_zero_prim);
  159. auto new_value_node = NewValueNode(clear_zero_prim);
  160. MS_EXCEPTION_IF_NULL(new_value_node);
  161. std::vector<AnfNodePtr> inputs = {new_value_node};
  162. inputs.push_back(pre_node);
  163. CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
  164. MS_EXCEPTION_IF_NULL(clear_zero);
  165. AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
  166. MS_EXCEPTION_IF_NULL(abstract);
  167. clear_zero->set_abstract(abstract);
  168. auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  169. builder->SetKernelType(KernelType::TBE_KERNEL);
  170. AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), clear_zero.get());
  171. auto clean_size = CalCleanZerosSize(pre_node);
  172. AnfAlgo::SetNodeAttr(kAttrAtomicAddMemSize, MakeValue(clean_size), clear_zero);
  173. AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(pre_node.get()), clear_zero.get());
  174. new_nodes->push_back(clear_zero);
  175. }
  176. static bool IsAtomicNode(const CNodePtr &kernel_node) {
  177. MS_EXCEPTION_IF_NULL(kernel_node);
  178. auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
  179. MS_EXCEPTION_IF_NULL(kernel_mod);
  180. auto parameters_indexs = kernel_mod->GenParameters();
  181. if (parameters_indexs.empty()) {
  182. return false;
  183. }
  184. size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  185. size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  186. size_t workspace_num = kernel_mod->GetWorkspaceSizeList().size();
  187. size_t param_num = parameters_indexs.size();
  188. size_t total_num = input_num + workspace_num + output_num;
  189. MS_LOG(INFO) << "parameters size: " << param_num << ", input & workspace & output num: " << total_num;
  190. size_t pad_index = param_num;
  191. for (; pad_index < total_num; ++pad_index) {
  192. parameters_indexs.emplace_back(0);
  193. }
  194. // process input
  195. for (size_t j = 0; j < input_num; ++j) {
  196. if (parameters_indexs.at(j) == 1) {
  197. MS_LOG(EXCEPTION) << "Atomic addr clean does't support clean input address, input index: " << j;
  198. }
  199. }
  200. // process output
  201. std::vector<size_t> output_indexs = {};
  202. for (size_t i = 0; i < output_num; ++i) {
  203. auto param_output = parameters_indexs.at(input_num + workspace_num + i);
  204. if (param_output == 1) {
  205. output_indexs.emplace_back(i);
  206. MS_LOG(INFO) << "Atomic clear output index: " << i;
  207. }
  208. }
  209. if (!output_indexs.empty()) {
  210. AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(output_indexs), kernel_node);
  211. }
  212. // process workspace
  213. std::vector<size_t> workspace_indexs = {};
  214. for (size_t k = 0; k < workspace_num; ++k) {
  215. auto param_workspace = parameters_indexs.at(input_num + k);
  216. if (param_workspace == 1) {
  217. workspace_indexs.emplace_back(k);
  218. MS_LOG(INFO) << "Atomic clear workspace index: " << k;
  219. }
  220. }
  221. if (!workspace_indexs.empty()) {
  222. AnfAlgo::SetNodeAttr(kAttrAtomicWorkspaceIndexs, MakeValue(workspace_indexs), kernel_node);
  223. }
  224. return !(workspace_indexs.empty() && output_indexs.empty());
  225. }
  226. bool KernelPreBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
  227. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  228. bool ret = device::ascend::KernelPreBuildParallelCompile(kernel_graph_ptr);
  229. return ret;
  230. }
  231. bool KernelBuild(const mindspore::session::KernelGraph *kernel_graph_ptr) {
  232. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  233. TbeUtils::LoadCache();
  234. bool ret;
  235. ret = device::ascend::KernelBuildParallelCompile(kernel_graph_ptr);
  236. return ret;
  237. }
  238. void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
  239. MS_EXCEPTION_IF_NULL(kernel_graph);
  240. std::vector<CNodePtr> new_nodes;
  241. for (const auto &anf_node : kernel_graph->execution_order()) {
  242. std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
  243. if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
  244. AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
  245. auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
  246. MS_EXCEPTION_IF_NULL(clear_zero_prim);
  247. auto new_value_node = NewValueNode(clear_zero_prim);
  248. MS_EXCEPTION_IF_NULL(new_value_node);
  249. std::vector<AnfNodePtr> inputs = {new_value_node};
  250. inputs.push_back(anf_node);
  251. CNodePtr clear_zero = kernel_graph->NewCNode(inputs);
  252. MS_EXCEPTION_IF_NULL(clear_zero);
  253. auto kernel_info = std::make_shared<device::KernelInfo>();
  254. MS_EXCEPTION_IF_NULL(kernel_info);
  255. clear_zero->set_kernel_info(kernel_info);
  256. AbstractBasePtr abstract = std::make_shared<abstract::AbstractNone>();
  257. MS_EXCEPTION_IF_NULL(abstract);
  258. AnfAlgo::SetNodeAttr("input_names", MakeValue(std::vector<std::string>({"x"})), clear_zero);
  259. SelectKernelInfo(clear_zero);
  260. // set the distinction label of clear same with anf
  261. AnfAlgo::SetStreamDistinctionLabel(AnfAlgo::GetStreamDistinctionLabel(anf_node.get()), clear_zero.get());
  262. new_nodes.push_back(clear_zero);
  263. } else if (AnfAlgo::GetKernelType(anf_node) == KernelType::TBE_KERNEL) {
  264. if (IsAtomicNode(anf_node)) {
  265. AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
  266. }
  267. }
  268. new_nodes.push_back(anf_node);
  269. }
  270. kernel_graph->set_execution_order(new_nodes);
  271. }
  272. } // namespace ascend
  273. } // namespace device
  274. } // namespace mindspore