You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

kernel_select_ascend.cc 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "device/ascend/kernel_select_ascend.h"
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "kernel/oplib/oplib.h"
#include "kernel/kernel_query.h"
#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "utils/context/ms_context.h"
#include "operator/ops.h"
  28. namespace mindspore {
  29. namespace device {
  30. namespace ascend {
  31. namespace {
  32. const float kWegihtBaseScore = 1;
  33. const float kFeatureMapBaseScore = 10;
  34. enum MatchCountPriority : int {
  35. MATCH_COUNT_PRIORITY_BEGIN = 0,
  36. MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
  37. MATCH_FORMAT_COUNT,
  38. MATCH_SPECIAL_FORMAT_COUNT,
  39. MATCH_OUTPUT_DTYPE_COUNT,
  40. MATCH_COUNT_PRIORITY_END
  41. };
  42. const size_t kMaxCount = 0xffffffff;
  43. const int kUnSupportMixedDataTypeIndex = -1;
  44. bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  45. MS_EXCEPTION_IF_NULL(cnode);
  46. // Check input data type
  47. for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
  48. TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
  49. if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
  50. return false;
  51. }
  52. }
  53. // Check output data type
  54. for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
  55. if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
  56. return false;
  57. }
  58. }
  59. return true;
  60. }
  61. string GetPriorityMatchFormat(const CNodePtr &cnode) {
  62. string priority_matched_format = kOpFormat_NC1HWC0;
  63. bool is_init = false;
  64. bool need_change_nd = false;
  65. for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
  66. auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
  67. if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
  68. kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) {
  69. priority_matched_format = !is_init ? priority_matched_format : pre_output_format;
  70. is_init = true;
  71. }
  72. // feature map has two or more special format;
  73. if (priority_matched_format != pre_output_format && pre_output_format != kOpFormat_DEFAULT) {
  74. priority_matched_format = kOpFormat_DEFAULT;
  75. }
  76. auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
  77. need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
  78. }
  79. if (need_change_nd) {
  80. priority_matched_format = kOpFormat_DEFAULT;
  81. }
  82. return priority_matched_format;
  83. }
  84. /**
  85. * compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
  86. * if equal then next num location
  87. * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
  88. */
  89. bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
  90. MS_EXCEPTION_IF_NULL(best_item);
  91. if (cur_item.size() != best_item->size()) {
  92. MS_LOG(ERROR) << "item size should be same!";
  93. return false;
  94. }
  95. // Update the best_item by comparing the cur_item and best_item
  96. for (size_t i = 0; i < cur_item.size(); i++) {
  97. if (cur_item[i] > best_item->at(i)) {
  98. *best_item = cur_item;
  99. return true;
  100. } else if (cur_item[i] == best_item->at(i)) {
  101. continue;
  102. } else {
  103. return false;
  104. }
  105. }
  106. return false;
  107. }
  108. void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
  109. std::vector<int> *const cur_kernelinfo_match_counts) {
  110. MS_EXCEPTION_IF_NULL(kernel_node);
  111. MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  112. if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
  113. MS_LOG(EXCEPTION) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
  114. }
  115. auto pri_match_format = GetPriorityMatchFormat(kernel_node);
  116. for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
  117. auto base_score = AnfAlgo::IsFeatureMapInput(kernel_node, input_index) ? kFeatureMapBaseScore : kWegihtBaseScore;
  118. if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
  119. (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT] += base_score;
  120. }
  121. if (kernel_build_info.GetInputDeviceType(input_index) ==
  122. AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
  123. (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT] += base_score;
  124. }
  125. if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
  126. (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
  127. }
  128. }
  129. for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
  130. // cal count of same output dtype between abstract and kernel info
  131. if (kernel_build_info.GetOutputDeviceType(output_index) ==
  132. AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
  133. (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT] += 1;
  134. }
  135. }
  136. }
  137. void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
  138. MS_EXCEPTION_IF_NULL(kernel_node);
  139. for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
  140. auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index);
  141. MS_EXCEPTION_IF_NULL(input_kernel_node);
  142. auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0);
  143. MS_EXCEPTION_IF_NULL(input_with_index.first);
  144. auto real_input_node = input_with_index.first;
  145. if (real_input_node->isa<CNode>()) {
  146. continue;
  147. }
  148. if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
  149. continue;
  150. }
  151. std::shared_ptr<kernel::KernelBuildInfo::KernelBuildInfoBuilder> builder =
  152. std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  153. // we set special device info of a input tensor.
  154. bool is_ref = false;
  155. auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
  156. if (op_info != nullptr) {
  157. is_ref = op_info->is_ref();
  158. }
  159. MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  160. if (MsContext::GetInstance()->execution_mode() == kPynativeMode &&
  161. AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) != kTypeUnknown) {
  162. continue;
  163. }
  164. if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
  165. std::vector<std::string> output_format = {selected_kernel_info.GetInputFormat(input_index)};
  166. builder->SetOutputsFormat(output_format);
  167. std::vector<TypeId> output_type = {selected_kernel_info.GetInputDeviceType(input_index)};
  168. builder->SetOutputsDeviceType(output_type);
  169. AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), real_input_node.get());
  170. }
  171. }
  172. }
  173. void AddSupportMixedPrecisionDataTypeIndex(TypeId data_type, std::vector<int> *support_index) {
  174. int index = kUnSupportMixedDataTypeIndex;
  175. switch (data_type) {
  176. case kNumberTypeFloat16:
  177. index = 0;
  178. break;
  179. case kNumberTypeFloat32:
  180. case kNumberTypeFloat:
  181. index = 1;
  182. break;
  183. default:
  184. break;
  185. }
  186. support_index->push_back(index);
  187. }
  188. void AddKernelInputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t input_index,
  189. std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  190. auto data_type = kernel_build_info.GetInputDeviceType(input_index);
  191. support_datatype->push_back(data_type);
  192. AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
  193. }
  194. void AddKernelOutputSupportDataType(const kernel::KernelBuildInfo &kernel_build_info, size_t output_index,
  195. std::vector<int> *support_datatype_index, std::vector<TypeId> *support_datatype) {
  196. auto data_type = kernel_build_info.GetOutputDeviceType(output_index);
  197. support_datatype->push_back(data_type);
  198. AddSupportMixedPrecisionDataTypeIndex(data_type, support_datatype_index);
  199. }
  200. void AddNodeInputDataType(const CNodePtr &kernel_node, size_t input_index,
  201. std::vector<int> *node_mix_precision_datatype_index,
  202. std::vector<TypeId> *node_mix_precision_datatype) {
  203. AnfNodePtr cur_input = AnfAlgo::GetInputNode(kernel_node, input_index);
  204. MS_EXCEPTION_IF_NULL(cur_input);
  205. TypeId input_origin_type;
  206. if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
  207. // weight
  208. input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
  209. } else if (cur_input->isa<ValueNode>()) {
  210. input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
  211. } else {
  212. // feature map
  213. input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
  214. }
  215. AddSupportMixedPrecisionDataTypeIndex(input_origin_type, node_mix_precision_datatype_index);
  216. node_mix_precision_datatype->push_back(input_origin_type);
  217. }
  218. void AddNodeOutputDataType(const CNodePtr &kernel_node, size_t output_index,
  219. std::vector<int> *node_mix_precision_datatype_index,
  220. std::vector<TypeId> *node_mix_precision_datatype) {
  221. auto output_origin_type = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
  222. AddSupportMixedPrecisionDataTypeIndex(output_origin_type, node_mix_precision_datatype_index);
  223. node_mix_precision_datatype->push_back(output_origin_type);
  224. }
  225. void CheckDataTypeInputs(const std::vector<int> &node_mix_precision_datatype_index,
  226. const std::vector<TypeId> &node_mix_precision_datatype,
  227. const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  228. std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  229. if (node_mix_precision_datatype_index.size() != node_mix_precision_datatype.size()) {
  230. MS_LOG(EXCEPTION) << "node datatype index size " << node_mix_precision_datatype_index.size() << " != datatype size "
  231. << node_mix_precision_datatype.size();
  232. }
  233. MS_EXCEPTION_IF_NULL(kernel_match_datatype_idx);
  234. if (kernel_support_datatypes.size() != kernel_match_datatype_idx->size()) {
  235. MS_LOG(EXCEPTION) << "kernel datatype index size " << kernel_match_datatype_idx->size() << " != datatype size "
  236. << kernel_support_datatypes.size();
  237. }
  238. }
  239. int RaiseDataTypePrecisionSelect(const std::vector<int> &node_mix_precision_datatype_index,
  240. const std::vector<TypeId> &node_mix_precision_datatype,
  241. const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  242. std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  243. CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
  244. kernel_match_datatype_idx);
  245. for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
  246. if (node_mix_precision_datatype[i] == kTypeUnknown) {
  247. continue;
  248. }
  249. auto iter = kernel_match_datatype_idx->begin();
  250. while (iter != kernel_match_datatype_idx->end()) {
  251. if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
  252. auto find_iter = kernel_support_datatypes.find(iter->first);
  253. if (find_iter == kernel_support_datatypes.end()) {
  254. MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first;
  255. }
  256. if (i >= find_iter->second.size()) {
  257. MS_LOG(EXCEPTION) << "node index " << i << "kernel datatype size " << find_iter->second.size();
  258. }
  259. if (node_mix_precision_datatype[i] != find_iter->second[i]) {
  260. iter = kernel_match_datatype_idx->erase(iter);
  261. } else {
  262. ++iter;
  263. }
  264. continue;
  265. }
  266. auto datatype_indexes = iter->second;
  267. if (i >= datatype_indexes.size()) {
  268. MS_LOG(EXCEPTION) << "node datatype index: " << i << " kernel support size " << datatype_indexes.size();
  269. }
  270. if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
  271. iter = kernel_match_datatype_idx->erase(iter);
  272. } else {
  273. ++iter;
  274. }
  275. }
  276. }
  277. if (kernel_match_datatype_idx->size() >= 1) {
  278. return SizeToInt(kernel_match_datatype_idx->begin()->first);
  279. }
  280. return -1;
  281. }
  282. int GetMinReducePrecisionCountIndex(std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx,
  283. const std::unordered_map<size_t, size_t> &precision_reduce_count) {
  284. int selected_index = -1;
  285. size_t min_reduce_precision_count = kMaxCount;
  286. auto iter = kernel_match_datatype_idx->begin();
  287. while (iter != kernel_match_datatype_idx->end()) {
  288. auto find_iter = precision_reduce_count.find(iter->first);
  289. if (find_iter == precision_reduce_count.end()) {
  290. continue;
  291. }
  292. if (min_reduce_precision_count > find_iter->second) {
  293. selected_index = SizeToInt(iter->first);
  294. min_reduce_precision_count = find_iter->second;
  295. }
  296. ++iter;
  297. }
  298. return selected_index;
  299. }
  300. int RaiseOrReduceDataTypePrecisionSelect(
  301. const std::vector<int> &node_mix_precision_datatype_index, const std::vector<TypeId> &node_mix_precision_datatype,
  302. const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatypes,
  303. std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx) {
  304. CheckDataTypeInputs(node_mix_precision_datatype_index, node_mix_precision_datatype, kernel_support_datatypes,
  305. kernel_match_datatype_idx);
  306. // reduce / raise
  307. std::unordered_map<size_t, size_t> precision_reduce_count;
  308. for (size_t i = 0; i < node_mix_precision_datatype_index.size(); ++i) {
  309. if (node_mix_precision_datatype[i] == kTypeUnknown) {
  310. continue;
  311. }
  312. auto iter = kernel_match_datatype_idx->begin();
  313. while (iter != kernel_match_datatype_idx->end()) {
  314. if (node_mix_precision_datatype_index[i] == kUnSupportMixedDataTypeIndex) {
  315. auto find_iter = kernel_support_datatypes.find(iter->first);
  316. if (find_iter == kernel_support_datatypes.end()) {
  317. MS_LOG(EXCEPTION) << "kernel datatype index:%lu can not be found " << iter->first;
  318. }
  319. if (i >= find_iter->second.size()) {
  320. MS_LOG(EXCEPTION) << "node index " << i << " >= kernel datatype size " << find_iter->second.size();
  321. }
  322. if (node_mix_precision_datatype[i] != find_iter->second[i]) {
  323. iter = kernel_match_datatype_idx->erase(iter);
  324. } else {
  325. ++iter;
  326. }
  327. continue;
  328. }
  329. auto datatype_indexes = iter->second;
  330. if (i >= datatype_indexes.size()) {
  331. MS_LOG(EXCEPTION) << "index " << i << "> kernel datatype indexes size " << datatype_indexes.size();
  332. }
  333. if (datatype_indexes[i] == kUnSupportMixedDataTypeIndex) {
  334. iter = kernel_match_datatype_idx->erase(iter);
  335. } else {
  336. if (datatype_indexes[i] < node_mix_precision_datatype_index[i]) {
  337. auto count_iter = precision_reduce_count.find(iter->first);
  338. if (count_iter != precision_reduce_count.end()) {
  339. count_iter->second++;
  340. } else {
  341. precision_reduce_count[iter->first] = 1;
  342. }
  343. }
  344. ++iter;
  345. }
  346. }
  347. }
  348. return GetMinReducePrecisionCountIndex(kernel_match_datatype_idx, precision_reduce_count);
  349. }
  350. void AddNodeAndKernelDataType(const CNodePtr &kernel_node, const kernel::KernelBuildInfo &kernel_build_info,
  351. std::vector<int> *support_indexes, std::vector<TypeId> *node_mix_precision_datatype,
  352. std::vector<TypeId> *support_datatypes,
  353. std::vector<int> *node_mix_precision_datatype_index) {
  354. bool add_node_datatype_flag = false;
  355. if (node_mix_precision_datatype->size() == 0) {
  356. add_node_datatype_flag = true;
  357. }
  358. for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
  359. AddKernelInputSupportDataType(kernel_build_info, input_index, support_indexes, support_datatypes);
  360. if (add_node_datatype_flag) {
  361. AddNodeInputDataType(kernel_node, input_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
  362. }
  363. }
  364. // Check output data type
  365. for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
  366. AddKernelOutputSupportDataType(kernel_build_info, output_index, support_indexes, support_datatypes);
  367. if (add_node_datatype_flag) {
  368. AddNodeOutputDataType(kernel_node, output_index, node_mix_precision_datatype_index, node_mix_precision_datatype);
  369. }
  370. }
  371. }
  372. int PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
  373. const std::vector<TypeId> &node_mix_precision_datatype,
  374. const std::unordered_map<size_t, std::vector<TypeId>> &kernel_support_datatype,
  375. std::unordered_map<size_t, std::vector<int>> *kernel_match_datatype_idx, bool *precision_reduce) {
  376. auto context_ptr = MsContext::GetInstance();
  377. MS_EXCEPTION_IF_NULL(context_ptr);
  378. MS_EXCEPTION_IF_NULL(precision_reduce);
  379. std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx_copy = *kernel_match_datatype_idx;
  380. // raise precision
  381. int selected_index = RaiseDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
  382. kernel_support_datatype, kernel_match_datatype_idx);
  383. if (selected_index != -1) {
  384. int max_match = 0;
  385. auto iter = kernel_match_datatype_idx->begin();
  386. int match_count = 0;
  387. while (iter != kernel_match_datatype_idx->end()) {
  388. auto kernel_datatypes = kernel_support_datatype.find(iter->first);
  389. if (kernel_datatypes == kernel_support_datatype.end()) {
  390. MS_LOG(EXCEPTION) << "Can not find kernel index" << iter->first << "'s datatype.";
  391. }
  392. if (kernel_datatypes->second.size() < node_mix_precision_datatype.size()) {
  393. MS_LOG(EXCEPTION) << "Kernel datatype size is not equal to node datatype size!";
  394. }
  395. for (size_t i = 0; i < node_mix_precision_datatype.size(); ++i) {
  396. if (node_mix_precision_datatype[i] == kernel_datatypes->second[i]) {
  397. ++match_count;
  398. }
  399. }
  400. if (match_count > max_match) {
  401. selected_index = SizeToInt(iter->first);
  402. }
  403. ++iter;
  404. }
  405. }
  406. if (selected_index == -1 && context_ptr->enable_reduce_precision()) {
  407. selected_index =
  408. RaiseOrReduceDataTypePrecisionSelect(node_mix_precision_datatype_index, node_mix_precision_datatype,
  409. kernel_support_datatype, &kernel_match_datatype_idx_copy);
  410. if (selected_index != -1) {
  411. *precision_reduce = true;
  412. }
  413. }
  414. return selected_index;
  415. }
  416. void SelectKernel(const CNodePtr &kernel_node, bool precision_reduce, const std::vector<TypeId> &node_datatype,
  417. const std::shared_ptr<kernel::KernelBuildInfo> &selected_kernel_info_ptr) {
  418. MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr);
  419. if (precision_reduce) {
  420. std::ostringstream datatype;
  421. size_t input_num = selected_kernel_info_ptr->GetInputNum();
  422. size_t i = 0;
  423. datatype << "(";
  424. for (; i < input_num && i < node_datatype.size(); ++i) {
  425. datatype << static_cast<int>(node_datatype[i]);
  426. if (i < input_num - 1) {
  427. datatype << ", ";
  428. }
  429. }
  430. datatype << ") -> (";
  431. for (; i < node_datatype.size(); ++i) {
  432. datatype << static_cast<int>(node_datatype[i]);
  433. if (i < node_datatype.size() - 1) {
  434. datatype << ", ";
  435. }
  436. }
  437. datatype << ")";
  438. MS_LOG(WARNING) << kernel_node->DebugString() << " reduce precision, node datatype: " << datatype.str()
  439. << ", select kernel: %s" << selected_kernel_info_ptr->ToString();
  440. }
  441. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get());
  442. // Set format and data type for input tensor.
  443. SetTensorDeviceInfo(*selected_kernel_info_ptr, kernel_node);
  444. }
  445. } // namespace
  446. void SelectKernelInfo(const CNodePtr &kernel_node) {
  447. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  448. MS_EXCEPTION_IF_NULL(kernel_node);
  449. kernel::KernelQuery(kernel_node, &kernel_info_list);
  450. std::vector<int> most_match_counts = {-1, -1, -1, -1};
  451. int selected_index = -1;
  452. std::unordered_map<size_t, std::vector<int>> kernel_match_datatype_idx;
  453. std::unordered_map<size_t, std::vector<TypeId>> kernel_support_datatype;
  454. std::vector<int> node_mix_precision_datatype_index;
  455. std::vector<TypeId> node_mix_precision_datatype;
  456. for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
  457. std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0};
  458. auto kernel_build_info = *(kernel_info_list[info_index]);
  459. std::vector<int> support_indexes;
  460. std::vector<TypeId> support_datatypes;
  461. AddNodeAndKernelDataType(kernel_node, kernel_build_info, &support_indexes, &node_mix_precision_datatype,
  462. &support_datatypes, &node_mix_precision_datatype_index);
  463. kernel_match_datatype_idx[info_index] = support_indexes;
  464. kernel_support_datatype[info_index] = support_datatypes;
  465. if (!MatchInferOutputDataType(kernel_node, kernel_build_info)) {
  466. continue;
  467. }
  468. std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
  469. UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
  470. // Currently the selection policy is the match format count first, and then is datatype counts.
  471. if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
  472. selected_index = SizeToInt(info_index);
  473. }
  474. }
  475. bool precision_reduce = false;
  476. if (selected_index == -1) {
  477. selected_index = PrecisionReduce(node_mix_precision_datatype_index, node_mix_precision_datatype,
  478. kernel_support_datatype, &kernel_match_datatype_idx, &precision_reduce);
  479. }
  480. if (selected_index == -1) {
  481. MS_LOG(EXCEPTION) << kernel_node->DebugString() << "Cannot find valid kernel Info !";
  482. }
  483. auto index = IntToSize(selected_index);
  484. if (index >= kernel_info_list.size()) {
  485. MS_LOG(EXCEPTION) << "index outof range";
  486. }
  487. std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index];
  488. MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr);
  489. SelectKernel(kernel_node, precision_reduce, node_mix_precision_datatype, selected_kernel_info_ptr);
  490. }
  491. bool CheckKernelAccuracySupported(const CNodePtr &kernel_node,
  492. const kernel::KernelBuildInfoPtr &new_kernel_build_info) {
  493. MS_EXCEPTION_IF_NULL(kernel_node);
  494. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
  495. kernel::KernelQuery(kernel_node, &kernel_info_list);
  496. auto result = std::find_if(kernel_info_list.begin(), kernel_info_list.end(),
  497. [&new_kernel_build_info](const kernel::KernelBuildInfoPtr item) {
  498. MS_EXCEPTION_IF_NULL(item);
  499. return *item == *new_kernel_build_info;
  500. });
  501. return result != kernel_info_list.end();
  502. }
  503. } // namespace ascend
  504. } // namespace device
  505. } // namespace mindspore