You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

kernel_select_ascend.cc 28 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
6 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/ascend/kernel_select_ascend.h"
  17. #include <string>
  18. #include <vector>
  19. #include <memory>
  20. #include <utility>
  21. #include <algorithm>
  22. #include <map>
  23. #include <unordered_map>
  24. #include <unordered_set>
  25. #include "common/utils.h"
  26. #include "debug/anf_ir_dump.h"
  27. #include "operator/ops.h"
  28. #include "ir/func_graph.h"
  29. #include "utils/context/ms_context.h"
  30. #include "session/anf_runtime_algorithm.h"
  31. #include "device/kernel_info.h"
  32. #include "kernel/common_utils.h"
  33. #include "kernel/kernel_query.h"
  34. #include "kernel/oplib/oplib.h"
  35. #include "kernel/kernel_build_info.h"
  36. namespace mindspore {
  37. namespace device {
  38. namespace ascend {
  39. namespace {
// Base scores used when counting format/dtype matches: inputs coming from
// feature maps weigh 10x more than inputs coming from weights.
// NOTE(review): "Wegiht" is a typo, kept because the name is referenced by
// other functions in this file.
const float kWegihtBaseScore = 1;
const float kFeatureMapBaseScore = 10;
// Node attribute under which GetPriorityMatchFormat caches its decision.
constexpr auto kPriChoosenFormat = "pri_format";
// Slots of the per-kernel match-count vector; the slot order is the
// lexicographic comparison order used by PriorityChooseItem.
enum MatchCountPriority : int {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  MATCH_FORMAT_COUNT,
  MATCH_SPECIAL_FORMAT_COUNT,
  MATCH_DEFAULT_FORMAT_COUNT,
  MATCH_OUTPUT_DTYPE_COUNT,
  MATCH_COUNT_PRIORITY_END
};
// Sentinel rank: dtype is outside the fp16/fp32 mixed-precision set.
const int kUnSupportMixedDataTypeIndex = -1;
  53. bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  54. MS_EXCEPTION_IF_NULL(cnode);
  55. // Check input data type
  56. for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
  57. TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
  58. if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
  59. return false;
  60. }
  61. }
  62. // Check output data type
  63. for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
  64. if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
  65. return false;
  66. }
  67. }
  68. return true;
  69. }
// Pick the preferred input format for this node's kernel selection.
// Starts from NC1HWC0, adopts the first special format seen on a feature-map
// input, and falls back to the default format when inputs carry conflicting
// formats or when a non-4D (rank > 1) input shape would make the special
// format unsuitable (unless FRAC_NZ was chosen). The result is cached on the
// node as the "pri_format" attribute.
string GetPriorityMatchFormat(const CNodePtr &cnode) {
  string priority_matched_format = kOpFormat_NC1HWC0;
  bool is_init = false;         // becomes true once a special format is adopted
  bool need_change_nd = false;  // true if any input rank is >1 but not 4
  for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
    auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
    if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
        kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) {
      // Only the first special format seen is kept.
      priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
      is_init = true;
    }
    // feature map has two or more special format;
    if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
      priority_matched_format = kOpFormat_DEFAULT;
    }
    auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
    need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
  }
  if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
    priority_matched_format = kOpFormat_DEFAULT;
  }
  // Cache the decision on the node for later passes.
  AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
  return priority_matched_format;
}
  94. /**
  95. * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
  96. * if equal then next num location
  97. * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
  98. */
  99. bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
  100. MS_EXCEPTION_IF_NULL(best_item);
  101. if (cur_item.size() != best_item->size()) {
  102. MS_LOG(ERROR) << "Item size should be same!";
  103. return false;
  104. }
  105. // Update the best_item by comparing the cur_item and best_item
  106. for (size_t i = 0; i < cur_item.size(); i++) {
  107. if (cur_item[i] > best_item->at(i)) {
  108. *best_item = cur_item;
  109. return true;
  110. } else if (cur_item[i] == best_item->at(i)) {
  111. continue;
  112. } else {
  113. return false;
  114. }
  115. }
  116. return false;
  117. }
// Score kernel_build_info against kernel_node, accumulating one count per
// criterion into cur_kernelinfo_match_counts (indexed by MatchCountPriority).
// Input-side criteria are weighted: feature-map inputs add 10 per match,
// weight inputs add 1 (see kFeatureMapBaseScore / kWegihtBaseScore).
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
                          std::vector<int> *const cur_kernelinfo_match_counts) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
    MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
  }
  auto pri_match_format = GetPriorityMatchFormat(kernel_node);
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
    auto input_anf_node = kernel_node->input(input_index + 1);  // input 0 is the primitive
    // we do not take ValueNode into consideration in graph kernel.
    if (kernel_build_info.kernel_type() == KernelType::AKG_KERNEL) {
      if (input_anf_node->isa<ValueNode>() && AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
        continue;
      }
    }
    auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
    // Criterion: kernel input format equals the producer's output format.
    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
    }
    // we match output fix precision first.
    auto prev_device_type = AnfAlgo::GetPrevNodeOutputPrecision(kernel_node, input_index);
    if (prev_device_type == kTypeUnknown) {
      prev_device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
    }
    // Criterion: kernel input dtype equals the producer's (fixed or device) dtype.
    if (kernel_build_info.GetInputDeviceType(input_index) == prev_device_type) {
      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
    }
    // Criterion: kernel input format equals the node's priority format.
    if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
      (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
    }
    // Criterion: kernel input uses the default format.
    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) {
      (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
    }
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
    // cal count of same output dtype between abstract and kernel info
    if (kernel_build_info.GetOutputDeviceType(output_index) ==
        AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1;
    }
  }
}
  161. void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
  162. MS_EXCEPTION_IF_NULL(support_index);
  163. int index = kUnSupportMixedDataTypeIndex;
  164. switch (data_type) {
  165. case kNumberTypeFloat16:
  166. index = 0;
  167. break;
  168. case kNumberTypeFloat32:
  169. case kNumberTypeFloat:
  170. index = 1;
  171. break;
  172. default:
  173. break;
  174. }
  175. support_index->push_back(index);
  176. }
  177. void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index,
  178. std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  179. MS_EXCEPTION_IF_NULL(support_datatype);
  180. auto data_type = kernel_build_info.GetInputDeviceType(input_index);
  181. support_datatype->push_back(data_type);
  182. AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
  183. }
  184. void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index,
  185. std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  186. MS_EXCEPTION_IF_NULL(support_datatype);
  187. auto data_type = kernel_build_info.GetOutputDeviceType(output_index);
  188. support_datatype->push_back(data_type);
  189. AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
  190. }
  191. void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
  192. std::vector<int> *node_mix_precision_datatype_index,
  193. std::vector<TypeId> *node_mix_precision_datatype) {
  194. AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
  195. MS_EXCEPTION_IF_NULL(cur_input);
  196. MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
  197. TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
  198. AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
  199. node_mix_precision_datatype->push_back(input_origin_type);
  200. }
  201. void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
  202. std::vector<int> *node_mix_precision_datatype_index,
  203. std::vector<TypeId> *node_mix_precision_datatype) {
  204. MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
  205. auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
  206. AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index);
  207. node_mix_precision_datatype->push_back(output_origin_type);
  208. }
  209. void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index,
  210. const std::vector<TypeId> &node_mix_precision_datatype,
  211. const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  212. std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  213. if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
  214. MS_LOG(EXCEPTION) << "Node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
  215. << node_mix_precision_datatype.size();
  216. }
  217. MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  218. if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) {
  219. MS_LOG(EXCEPTION) << "Kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
  220. << kernel_support_datatypes.size();
  221. }
  222. }
  223. bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
  224. const std::vector<TypeId> &node_mix_precision_datatype,
  225. const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  226. std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  227. MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  228. CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
  229. kernel_match_datatype_idx);
  230. for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
  231. if (node_mix_precision_datatype[i] == kTypeUnknown) {
  232. continue;
  233. }
  234. auto iter = kernel_match_datatype_idx->begin();
  235. while (iter != kernel_match_datatype_idx->end()) {
  236. if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
  237. auto find_iter = kernel_support_datatypes.find(iter->first);
  238. if (find_iter == kernel_support_datatypes.end()) {
  239. MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first;
  240. }
  241. if (i >= find_iter->second.size()) {
  242. MS_LOG(EXCEPTION) << "Node index " << i << "kernel datatype size " << find_iter->second.size();
  243. }
  244. if (node_mix_precision_datatype[i] != find_iter->second[i]) {
  245. iter = kernel_match_datatype_idx->erase(iter);
  246. } else {
  247. ++iter;
  248. }
  249. continue;
  250. }
  251. auto datatype_indexes = iter->second;
  252. if (i >= datatype_indexes.size()) {
  253. MS_LOG(EXCEPTION) << "Node datatype index: " << i << " kernel support size " << datatype_indexes.size();
  254. }
  255. if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
  256. iter = kernel_match_datatype_idx->erase(iter);
  257. } else {
  258. ++iter;
  259. }
  260. }
  261. }
  262. return !kernel_match_datatype_idx->empty();
  263. }
  264. bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
  265. const std::vector<int> &node_mix_precision_datatype_index) {
  266. auto check_index_tmp = IntToSize(check_index);
  267. if (check_index_tmp < datatype_indexes.size() && check_index_tmp < node_mix_precision_datatype_index.size()) {
  268. return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
  269. datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
  270. }
  271. MS_LOG(EXCEPTION) << "Check index " << check_index << "is outof range";
  272. }
  273. bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
  274. const std::vector<TypeId> &node_mix_precision_datatype,
  275. const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  276. std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  277. MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  278. CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
  279. kernel_match_datatype_idx);
  280. for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
  281. if (node_mix_precision_datatype[i] == kTypeUnknown) {
  282. continue;
  283. }
  284. auto iter = kernel_match_datatype_idx->begin();
  285. while (iter != kernel_match_datatype_idx->end()) {
  286. if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
  287. auto find_iter = kernel_support_datatypes.find(iter->first);
  288. if (find_iter == kernel_support_datatypes.end()) {
  289. MS_LOG(EXCEPTION) << "Kernel datatype index:%lu can not be found " << iter->first;
  290. }
  291. if (i >= find_iter->second.size()) {
  292. MS_LOG(EXCEPTION) << "Node index " << i << " >= kernel datatype size " << find_iter->second.size();
  293. }
  294. if (node_mix_precision_datatype[i] != find_iter->second[i]) {
  295. iter = kernel_match_datatype_idx->erase(iter);
  296. } else {
  297. ++iter;
  298. }
  299. continue;
  300. }
  301. auto datatype_indexes = iter->second;
  302. if (i >= datatype_indexes.size()) {
  303. MS_LOG(EXCEPTION) << "Index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
  304. }
  305. if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
  306. iter = kernel_match_datatype_idx->erase(iter);
  307. } else {
  308. ++iter;
  309. }
  310. }
  311. }
  312. return !kernel_match_datatype_idx->empty();
  313. }
  314. void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info,
  315. std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype,
  316. std::vector<TypeId> *support_datatypes,
  317. std::vector<int> *node_mix_precision_datatype_index) {
  318. MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
  319. bool add_node_datatype_flag = false;
  320. if (node_mix_precision_datatype->empty()) {
  321. add_node_datatype_flag = true;
  322. }
  323. for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
  324. AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes);
  325. if (add_node_datatype_flag) {
  326. AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
  327. }
  328. }
  329. // Check output data type
  330. for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
  331. AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes);
  332. if (add_node_datatype_flag) {
  333. AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
  334. }
  335. }
  336. }
// Narrow *kernel_match_datatype_idx to precision-compatible candidates.
// First tries raise-only selection in place; if that leaves no candidate and
// reduce-precision is enabled in the context, retries raise-or-reduce on a
// pristine copy taken before the first pass, writing the copy back on
// success. *precision_reduce reports which pass selected the survivors.
// If both passes fail, the map is left as the (emptied) raise pass left it.
void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
                     const std::vector<TypeId> &node_mix_precision_datatype,
                     const std::map<size_t, std::vector<TypeId>> &kernel_support_datatype,
                     std::map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) {
  MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(precision_reduce);
  // Snapshot before the raise pass erases entries from the original map.
  std::map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx;
  // raise precision
  bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
                                                   kernel_support_datatype, kernel_match_datatype_idx);
  if (selected_ret) {
    *precision_reduce = false;
    return;
  }
  // Raise failed: optionally retry on the snapshot, allowing reduction too.
  if (context_ptr->enable_reduce_precision()) {
    selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
                                                        kernel_support_datatype, &kernel_match_datatype_idx_copy);
  }
  if (selected_ret) {
    *precision_reduce = true;
    *kernel_match_datatype_idx = kernel_match_datatype_idx_copy;
  }
}
  362. void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
  363. const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
  364. bool precision_reduce) {
  365. MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
  366. MS_EXCEPTION_IF_NULL(cnode);
  367. std::ostringstream buffer;
  368. buffer << cnode->DebugString();
  369. if (precision_reduce) {
  370. buffer << " Reduce precision, node datatype: \n";
  371. } else {
  372. buffer << " Raise precision, node datatype: \n";
  373. }
  374. PrintInputAndOutputInferType(buffer, cnode);
  375. buffer << ", select kernel:" << selected_kernel_build_info->ToString();
  376. MS_LOG(INFO) << buffer.str();
  377. }
  378. std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
  379. const CNodePtr &kernel_node, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  380. if (kernel_info_list.empty()) {
  381. return nullptr;
  382. }
  383. std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
  384. size_t selected_index = 0;
  385. for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
  386. std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
  387. auto kernel_info_ptr = kernel_info_list[info_index];
  388. MS_EXCEPTION_IF_NULL(kernel_info_ptr);
  389. UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
  390. // Currently the selection policy is the match format count first, and then is datatype counts.
  391. if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
  392. selected_index = SizeToInt(info_index);
  393. }
  394. }
  395. return kernel_info_list[selected_index];
  396. }
  397. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilteredKernelInfoByDtype(
  398. const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  399. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
  400. for (const auto &kernel_build_info : kernel_info_list) {
  401. MS_EXCEPTION_IF_NULL(kernel_build_info);
  402. if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
  403. continue;
  404. }
  405. result.push_back(kernel_build_info);
  406. }
  407. return result;
  408. }
  409. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
  410. const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
  411. bool *precision_reduce) {
  412. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list;
  413. std::map<size_t, std::vector<int>> kernel_match_datatype_idx;
  414. std::map<size_t, std::vector<TypeId>> kernel_support_datatype;
  415. std::vector<int> node_mix_precision_datatype_index;
  416. std::vector<TypeId> node_mix_precision_datatype;
  417. for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
  418. std::vector<int> support_indexes;
  419. std::vector<TypeId> support_datatypes;
  420. MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
  421. AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype,
  422. &support_datatypes, &node_mix_precision_datatype_index);
  423. kernel_match_datatype_idx[info_index] = support_indexes;
  424. kernel_support_datatype[info_index] = support_datatypes;
  425. }
  426. PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype,
  427. &kernel_match_datatype_idx, precision_reduce);
  428. std::transform(
  429. kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list),
  430. [&](const std::pair<size_t, std::vector<int>> &matched_idx) -> std::shared_ptr<kernel::KernelBuildInfo> {
  431. return kernel_info_list[matched_idx.first];
  432. });
  433. return filtered_kernel_info_list;
  434. }
  435. } // namespace
// For each input of kernel_node, stamp the selected kernel's expected input
// format/dtype onto the input's own kernel build info — but only for tensor
// value nodes and weight parameters that have no device dtype yet (or when
// the op is a ref op). CNode inputs and non-weight parameters are skipped.
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
    MS_EXCEPTION_IF_NULL(input_with_index.first);
    auto real_input_node = input_with_index.first;
    // Outputs of other CNodes carry their producer's device info; skip.
    if (real_input_node->isa<CNode>()) {
      continue;
    }
    // Non-weight parameters are left untouched.
    if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
      continue;
    }
    auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
    // Tensor value node without a device dtype: mirror the kernel's input spec
    // onto the value node itself.
    if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
        AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
      builder->SetOutputsFormat(output_format);
      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
      builder->SetOutputsDeviceType(output_type);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), input_kernel_node.get());
      continue;
    }
    // we set special device info of a input tensor.
    bool is_ref = false;
    auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
    if (op_info != nullptr) {
      is_ref = op_info->is_ref();
    }
    MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
    // In PyNative mode, an input that already has a device dtype keeps it.
    if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
        AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
      continue;
    }
    // Stamp the kernel's input spec onto the real input when it has no device
    // dtype yet, or unconditionally for ref ops.
    if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
      builder->SetOutputsFormat(output_format);
      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
      builder->SetOutputsDeviceType(output_type);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
    }
  }
}
  480. KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
  481. const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  482. MS_EXCEPTION_IF_NULL(kernel_node);
  483. KernelSelectStatus select_status = kNoMatched;
  484. bool precision_reduce = false;
  485. std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
  486. // Matched kernel info
  487. // Filter kernel info matched with me infered type
  488. auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list);
  489. if (!filtered_kernel_info_list.empty()) {
  490. selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
  491. select_status = kStatusAllMatched;
  492. } else {
  493. // selected kernel info using raised precision or reduce precision
  494. filtered_kernel_info_list =
  495. FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
  496. selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
  497. if (selected_kernel_info == nullptr) {
  498. return select_status;
  499. } else {
  500. PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
  501. select_status = precision_reduce ? kStatusReducePrecision : kStatusRaisePrecision;
  502. }
  503. }
  504. // Set kernel info to the anfnode
  505. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
  506. // Set format and data type for input tensor.
  507. SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
  508. return select_status;
  509. }
  510. KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
  511. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  512. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
  513. MS_EXCEPTION_IF_NULL(kernel_node);
  514. if (AnfAlgo::IsGraphKernel(kernel_node)) {
  515. auto func_graph = GetValueNode<FuncGraphPtr>(kernel_node->input(kAnfPrimitiveIndex));
  516. MS_EXCEPTION_IF_NULL(func_graph);
  517. SelectGraphKernelInfo(kernel_node, func_graph);
  518. return kStatusAllMatched;
  519. }
  520. kernel::KernelQuery(kernel_node, &kernel_info_list, kernel_type);
  521. auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
  522. // If aicore not find valid kernel info reloading aicpu kernel info list to find it
  523. if (select_status == kNoMatched) {
  524. MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
  525. << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
  526. kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
  527. select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
  528. AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
  529. }
  530. // The kernel info not finded both in the aicpu kernel list & aicore kernel list
  531. if (select_status == kNoMatched) {
  532. std::ostringstream buffer;
  533. PrintInputAndOutputInferType(buffer, kernel_node);
  534. MS_LOG(WARNING) << ">>> Candidates kernel info list:";
  535. for (size_t index = 0; index < kernel_info_list.size(); ++index) {
  536. MS_LOG(WARNING) << "Kernel [" << index << "] :" << kernel_info_list[index]->ToString();
  537. }
  538. for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
  539. MS_LOG(WARNING) << "Kernel [" << (kernel_info_list.size() + index)
  540. << "] :" << aicpu_kernel_info_list[index]->ToString();
  541. }
  542. if (IsPrimitiveCNode(kernel_node, prim::kPrimLabelSwitch)) {
  543. auto selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, kernel_info_list);
  544. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
  545. // Set format and data type for input tensor.
  546. SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
  547. } else {
  548. MS_LOG(WARNING) << " <<<";
  549. MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
  550. << "] cannot find valid kernel info, not supported the type:" << buffer.str()
  551. << ", please refer to the supported dtypes in candidates kernel info list";
  552. }
  553. }
  554. return select_status;
  555. }
  556. } // namespace ascend
  557. } // namespace device
  558. } // namespace mindspore