You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

kernel_select_ascend.cc 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "device/ascend/kernel_select_ascend.h"
#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "kernel/oplib/oplib.h"
#include "kernel/kernel_query.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
#include "debug/anf_ir_dump.h"
namespace mindspore {
namespace device {
namespace ascend {
namespace {
// Base scores used when counting format/dtype matches in UpdateCurMatchCounts:
// a match on a feature-map input is worth 10x a match on a weight input.
// NOTE(review): "Wegiht" is a typo for "Weight"; kept as-is because the name
// is referenced elsewhere in this file.
const float kWegihtBaseScore = 1;
const float kFeatureMapBaseScore = 10;
// Slots of the match-count vector compared lexicographically (slot 0 first)
// by PriorityChooseItem, i.e. input dtype matches outrank format matches,
// which outrank special-format matches, which outrank output dtype matches.
enum MatchCountPriority : int {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  MATCH_FORMAT_COUNT,
  MATCH_SPECIAL_FORMAT_COUNT,
  MATCH_OUTPUT_DTYPE_COUNT,
  MATCH_COUNT_PRIORITY_END
};
// NOTE(review): kMaxCount appears unused in this file — confirm before removing.
const size_t kMaxCount = 0xffffffff;
// Sentinel index for dtypes that take no part in fp16/fp32 mixed precision.
const int kUnSupportMixedDataTypeIndex = -1;
  45. bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  46. MS_EXCEPTION_IF_NULL(cnode);
  47. // Check input data type
  48. for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
  49. TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
  50. if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
  51. return false;
  52. }
  53. }
  54. // Check output data type
  55. for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
  56. if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
  57. return false;
  58. }
  59. }
  60. return true;
  61. }
// Chooses the input format candidate kernels should preferentially match for
// this node. Starts from NC1HWC0 and falls back to the default format when the
// inputs carry conflicting special formats or have ranks other than 4 (and >1).
string GetPriorityMatchFormat(const CNodePtr &cnode) {
  string priority_matched_format = kOpFormat_NC1HWC0;
  bool is_init = false;
  bool need_change_nd = false;
  for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
    auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
    if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
        kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) {
      // NOTE(review): on the first qualifying input (!is_init) this keeps the
      // NC1HWC0 default and only adopts pre_output_format on later inputs —
      // confirm the ternary arms are not accidentally swapped.
      priority_matched_format = !is_init ? priority_matched_format : pre_output_format;
      is_init = true;
    }
    // feature map has two or more special format;
    if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
      priority_matched_format = kOpFormat_DEFAULT;
    }
    auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
    // Any input whose rank is neither <=1 nor exactly 4 forces the default format.
    need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
  }
  if (need_change_nd) {
    priority_matched_format = kOpFormat_DEFAULT;
  }
  return priority_matched_format;
}
  85. /**
  86. * compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
  87. * if equal then next num location
  88. * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
  89. */
  90. bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
  91. MS_EXCEPTION_IF_NULL(best_item);
  92. if (cur_item.size() != best_item->size()) {
  93. MS_LOG(ERROR) << "item size should be same!";
  94. return false;
  95. }
  96. // Update the best_item by comparing the cur_item and best_item
  97. for (size_t i = 0; i < cur_item.size(); i++) {
  98. if (cur_item[i] > best_item->at(i)) {
  99. *best_item = cur_item;
  100. return true;
  101. } else if (cur_item[i] == best_item->at(i)) {
  102. continue;
  103. } else {
  104. return false;
  105. }
  106. }
  107. return false;
  108. }
// Scores |kernel_build_info| against |kernel_node|, accumulating into the
// MatchCountPriority slots of |cur_kernelinfo_match_counts|:
//  - per input: format match, device dtype match, and priority-format match,
//    each weighted by kFeatureMapBaseScore (feature maps) or kWegihtBaseScore
//    (weights);
//  - per output: +1 when the device dtype equals the inferred dtype.
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
                          std::vector<int> *const cur_kernelinfo_match_counts) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
    MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
  }
  auto pri_match_format = GetPriorityMatchFormat(kernel_node);
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
    // NOTE(review): base_score is a float added into int counters; the current
    // values (1 and 10) are integral so nothing is lost today.
    auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
    }
    if (kernel_build_info.GetInputDeviceType(input_index) ==
        AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
    }
    if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
      (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
    }
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
    // cal count of same output dtype between abstract and kernel info
    if (kernel_build_info.GetOutputDeviceType(output_index) ==
        AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1;
    }
  }
}
// Propagates the selected kernel's input formats/dtypes onto the real input
// nodes that do not yet carry device info (or whose op is a ref op), so the
// data they produce is laid out the way the selected kernel expects.
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
    auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
    MS_EXCEPTION_IF_NULL(input_kernel_node);
    // Resolve to the node that actually produces this input's value
    // (presumably VisitKernel skips wrapper nodes — see its definition).
    auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
    MS_EXCEPTION_IF_NULL(input_with_index.first);
    auto real_input_node = input_with_index.first;
    // CNode producers carry their own selected build info — leave them alone.
    if (real_input_node->isa<CNode>()) {
      continue;
    }
    // Skip parameters that are not weights.
    if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
      continue;
    }
    std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
      std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
    // we set special device info of a input tensor.
    bool is_ref = false;
    auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
    if (op_info != nullptr) {
      is_ref = op_info->is_ref();
    }
    MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
    // In PyNative mode, an input that already has device info keeps it.
    if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
        AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
      continue;
    }
    // Set (or, for ref ops, overwrite) the producer's output layout to match
    // what the selected kernel consumes at this input.
    if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
      std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
      builder->SetOutputsFormat(output_format);
      std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
      builder->SetOutputsDeviceType(output_type);
      AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
    }
  }
}
  174. void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
  175. MS_EXCEPTION_IF_NULL(support_index);
  176. int index = kUnSupportMixedDataTypeIndex;
  177. switch (data_type) {
  178. case kNumberTypeFloat16:
  179. index = 0;
  180. break;
  181. case kNumberTypeFloat32:
  182. case kNumberTypeFloat:
  183. index = 1;
  184. break;
  185. default:
  186. break;
  187. }
  188. support_index->push_back(index);
  189. }
  190. void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index,
  191. std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  192. MS_EXCEPTION_IF_NULL(support_datatype);
  193. auto data_type = kernel_build_info.GetInputDeviceType(input_index);
  194. support_datatype->push_back(data_type);
  195. AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
  196. }
  197. void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index,
  198. std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  199. MS_EXCEPTION_IF_NULL(support_datatype);
  200. auto data_type = kernel_build_info.GetOutputDeviceType(output_index);
  201. support_datatype->push_back(data_type);
  202. AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
  203. }
  204. void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
  205. std::vector<int> *node_mix_precision_datatype_index,
  206. std::vector<TypeId> *node_mix_precision_datatype) {
  207. AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
  208. MS_EXCEPTION_IF_NULL(cur_input);
  209. TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
  210. AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
  211. node_mix_precision_datatype->push_back(input_origin_type);
  212. }
  213. void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
  214. std::vector<int> *node_mix_precision_datatype_index,
  215. std::vector<TypeId> *node_mix_precision_datatype) {
  216. auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
  217. AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index);
  218. node_mix_precision_datatype->push_back(output_origin_type);
  219. }
  220. void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index,
  221. const std::vector<TypeId> &node_mix_precision_datatype,
  222. const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  223. std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  224. if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
  225. MS_LOG(EXCEPTION) << "node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
  226. << node_mix_precision_datatype.size();
  227. }
  228. MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  229. if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) {
  230. MS_LOG(EXCEPTION) << "kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
  231. << kernel_support_datatypes.size();
  232. }
  233. }
  234. bool RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
  235. const std::vector<TypeId> &node_mix_precision_datatype,
  236. const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  237. std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  238. MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  239. CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
  240. kernel_match_datatype_idx);
  241. for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
  242. if (node_mix_precision_datatype[i] == kTypeUnknown) {
  243. continue;
  244. }
  245. auto iter = kernel_match_datatype_idx->begin();
  246. while (iter != kernel_match_datatype_idx->end()) {
  247. if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
  248. auto find_iter = kernel_support_datatypes.find(iter->first);
  249. if (find_iter == kernel_support_datatypes.end()) {
  250. MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first;
  251. }
  252. if (i >= find_iter->second.size()) {
  253. MS_LOG(EXCEPTION) << "node index " << i << "kernel datatype size " << find_iter->second.size();
  254. }
  255. if (node_mix_precision_datatype[i] != find_iter->second[i]) {
  256. iter = kernel_match_datatype_idx->erase(iter);
  257. } else {
  258. ++iter;
  259. }
  260. continue;
  261. }
  262. auto datatype_indexes = iter->second;
  263. if (i >= datatype_indexes.size()) {
  264. MS_LOG(EXCEPTION) << "node datatype index: " << i << " kernel support size " << datatype_indexes.size();
  265. }
  266. if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
  267. iter = kernel_match_datatype_idx->erase(iter);
  268. } else {
  269. ++iter;
  270. }
  271. }
  272. }
  273. return !kernel_match_datatype_idx->empty();
  274. }
// A kernel dtype is usable in raise-or-reduce mode when it participates in
// mixed precision (not the sentinel) and its precision index does not exceed
// the node's — i.e. equal precision, or a reduction to a lower-precision slot.
bool CanDataTypeReduce(const std::vector<int> &datatype_indexes, int check_index,
                       const std::vector<int> &node_mix_precision_datatype_index) {
  return datatype_indexes[check_index] != kUnSupportMixedDataTypeIndex &&
         datatype_indexes[check_index] <= node_mix_precision_datatype_index[check_index];
}
  280. bool RaiseOrReduceDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
  281. const std::vector<TypeId> &node_mix_precision_datatype,
  282. const std::map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  283. std::map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  284. MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  285. CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
  286. kernel_match_datatype_idx);
  287. for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
  288. if (node_mix_precision_datatype[i] == kTypeUnknown) {
  289. continue;
  290. }
  291. auto iter = kernel_match_datatype_idx->begin();
  292. while (iter != kernel_match_datatype_idx->end()) {
  293. if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
  294. auto find_iter = kernel_support_datatypes.find(iter->first);
  295. if (find_iter == kernel_support_datatypes.end()) {
  296. MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first;
  297. }
  298. if (i >= find_iter->second.size()) {
  299. MS_LOG(EXCEPTION) << "node index " << i << " >= kernel datatype size " << find_iter->second.size();
  300. }
  301. if (node_mix_precision_datatype[i] != find_iter->second[i]) {
  302. iter = kernel_match_datatype_idx->erase(iter);
  303. } else {
  304. ++iter;
  305. }
  306. continue;
  307. }
  308. auto datatype_indexes = iter->second;
  309. if (i >= datatype_indexes.size()) {
  310. MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
  311. }
  312. if (!CanDataTypeReduce(datatype_indexes, i, node_mix_precision_datatype_index)) {
  313. iter = kernel_match_datatype_idx->erase(iter);
  314. } else {
  315. ++iter;
  316. }
  317. }
  318. }
  319. return !kernel_match_datatype_idx->empty();
  320. }
  321. void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info,
  322. std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype,
  323. std::vector<TypeId> *support_datatypes,
  324. std::vector<int> *node_mix_precision_datatype_index) {
  325. MS_EXCEPTION_IF_NULL(node_mix_precision_datatype);
  326. bool add_node_datatype_flag = false;
  327. if (node_mix_precision_datatype->size() == 0) {
  328. add_node_datatype_flag = true;
  329. }
  330. for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
  331. AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes);
  332. if (add_node_datatype_flag) {
  333. AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
  334. }
  335. }
  336. // Check output data type
  337. for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
  338. AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes);
  339. if (add_node_datatype_flag) {
  340. AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
  341. }
  342. }
  343. }
// Narrows |kernel_match_datatype_idx| to usable candidates: first tries
// raise-only selection in place; if that leaves nothing and the context allows
// reduced precision, retries on a pristine copy allowing raise-or-reduce.
// *precision_reduce records which mode succeeded; when both fail the map keeps
// the (empty) raise-filtered result and *precision_reduce is left untouched.
void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
                     const std::vector<TypeId> &node_mix_precision_datatype,
                     const std::map<size_t, std::vector<TypeId>> &kernel_support_datatype,
                     std::map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) {
  MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(precision_reduce);
  // Keep an unfiltered copy: the raise pass below erases entries in place.
  std::map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx;
  // raise precision
  bool selected_ret = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
                                                   kernel_support_datatype, kernel_match_datatype_idx);
  if (selected_ret) {
    *precision_reduce = false;
    return;
  }
  if (context_ptr->enable_reduce_precision()) {
    selected_ret = RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
                                                        kernel_support_datatype, &kernel_match_datatype_idx_copy);
  }
  if (selected_ret) {
    *precision_reduce = true;
    *kernel_match_datatype_idx = kernel_match_datatype_idx_copy;
  }
}
  369. void PrintRaiseOrReducePrecisionSelectedInfo(const CNodePtr &cnode,
  370. const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_build_info,
  371. bool precision_reduce) {
  372. MS_EXCEPTION_IF_NULL(selected_kernel_build_info);
  373. MS_EXCEPTION_IF_NULL(cnode);
  374. std::ostringstream buffer;
  375. buffer << cnode->DebugString();
  376. if (precision_reduce) {
  377. buffer << " reduce precision, node datatype: ";
  378. } else {
  379. buffer << " raise precision, node datatype: ";
  380. }
  381. PrintInputAndOutputInferType(buffer, cnode);
  382. buffer << ", select kernel:" << selected_kernel_build_info->ToString();
  383. MS_LOG(INFO) << buffer.str();
  384. }
  385. std::shared_ptr<kernel::KernelBuildInfo> ChooseMatchedKernelInfo(
  386. const CNodePtr &kernel_node, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  387. if (kernel_info_list.empty()) {
  388. return nullptr;
  389. }
  390. std::vector<int> most_match_counts = {-1, -1, -1, -1};
  391. size_t selected_index = 0;
  392. for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
  393. std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
  394. auto kernel_build_info = *(kernel_info_list[info_index]);
  395. std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
  396. UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
  397. // Currently the selection policy is the match format count first, and then is datatype counts.
  398. if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
  399. selected_index = SizeToInt(info_index);
  400. }
  401. }
  402. return kernel_info_list[selected_index];
  403. }
  404. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetAllMatchedFilteredKernelInfo(
  405. const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list) {
  406. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> result;
  407. for (const auto &kernel_build_info : kernel_info_list) {
  408. MS_EXCEPTION_IF_NULL(kernel_build_info);
  409. if (!MatchInferOutputDataType(cnode, *kernel_build_info)) {
  410. continue;
  411. }
  412. result.push_back(kernel_build_info);
  413. }
  414. return result;
  415. }
// Fallback selection used when no kernel matches the inferred dtypes exactly:
// builds the node/kernel dtype tables, filters candidates via PrecisionReduce
// (raise first, then optionally raise-or-reduce), and returns the surviving
// build infos. *precision_reduce tells the caller which mode succeeded.
std::vector<std::shared_ptr<kernel::KernelBuildInfo>> FilterRaisedOrReducePrecisionMatchedKernelInfo(
    const CNodePtr &cnode, const std::vector<std::shared_ptr<kernel::KernelBuildInfo>> &kernel_info_list,
    bool *precision_reduce) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> filtered_kernel_info_list;
  std::map<size_t, std::vector<int>> kernel_match_datatype_idx;   // candidate index -> mixed-precision indexes
  std::map<size_t, std::vector<TypeId>> kernel_support_datatype;  // candidate index -> supported dtypes
  std::vector<int> node_mix_precision_datatype_index;
  std::vector<TypeId> node_mix_precision_datatype;
  for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
    std::vector<int> support_indexes;
    std::vector<TypeId> support_datatypes;
    MS_EXCEPTION_IF_NULL(kernel_info_list[info_index]);
    // Also fills the node_mix_precision_* vectors, but only on the first pass.
    AddNodeAndKernelDataType(cnode, *kernel_info_list[info_index], &support_indexes, &node_mix_precision_datatype,
                             &support_datatypes, &node_mix_precision_datatype_index);
    kernel_match_datatype_idx[info_index] = support_indexes;
    kernel_support_datatype[info_index] = support_datatypes;
  }
  PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatype,
                  &kernel_match_datatype_idx, precision_reduce);
  // Map the surviving candidate indexes back to their build infos.
  std::transform(
    kernel_match_datatype_idx.begin(), kernel_match_datatype_idx.end(), std::back_inserter(filtered_kernel_info_list),
    [&](const std::pair<size_t, std::vector<int>> &matched_idx) -> std::shared_ptr<kernel::KernelBuildInfo> {
      return kernel_info_list[matched_idx.first];
    });
  return filtered_kernel_info_list;
}
  442. } // namespace
// Selects and attaches a kernel build info for |kernel_node|:
//   1. query all candidate kernels;
//   2. prefer candidates whose dtypes match the inferred types exactly,
//      ranked by ChooseMatchedKernelInfo;
//   3. otherwise fall back to precision raise/reduce filtering (logging the
//      adjustment) and throw when nothing fits.
// Finally propagates the chosen layout onto the node's input producers.
void SelectKernelInfo(const CNodePtr &kernel_node) {
  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  MS_EXCEPTION_IF_NULL(kernel_node);
  bool precision_reduce = false;
  std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info = nullptr;
  kernel::KernelQuery(kernel_node, &kernel_info_list);
  // filter kernel info matched with me infered type
  auto filtered_kernel_info_list = GetAllMatchedFilteredKernelInfo(kernel_node, kernel_info_list);
  if (!filtered_kernel_info_list.empty()) {
    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
  } else {
    // selected kernel info using raised precision or reduce precision
    filtered_kernel_info_list =
      FilterRaisedOrReducePrecisionMatchedKernelInfo(kernel_node, kernel_info_list, &precision_reduce);
    selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list);
    if (selected_kernel_info == nullptr) {
      std::ostringstream buffer;
      PrintInputAndOutputInferType(buffer, kernel_node);
      MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString()
                        << "] cannot find valid kernel info, not supported the type" << buffer.str();
    } else {
      PrintRaiseOrReducePrecisionSelectedInfo(kernel_node, selected_kernel_info, precision_reduce);
    }
  }
  // selected_kernel_info is non-null here: the matched branch had a non-empty
  // list, and the fallback branch throws on nullptr above.
  AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info, kernel_node.get());
  // Set format and data type for input tensor.
  SetTensorDeviceInfo(*selected_kernel_info, kernel_node);
}
  471. bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
  472. const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
  473. MS_EXCEPTION_IF_NULL(kernel_node);
  474. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  475. kernel::KernelQuery(kernel_node, &kernel_info_list);
  476. auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
  477. [&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  478. MS_EXCEPTION_IF_NULL(item);
  479. return *item == *new_kernel_build_info;
  480. });
  481. return result != kernel_info_list.end();
  482. }
  483. } // namespace ascend
  484. } // namespace device
  485. } // namespace mindspore