You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ascend_kernel_select_test.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "mindspore/ccsrc/device/ascend/kernel_select_ascend.h"
  17. #include "common/common_test.h"
  18. #include "session/kernel_graph.h"
  19. #include "kernel/kernel.h"
  20. #include "session/anf_runtime_algorithm.h"
  21. #include "utils/utils.h"
  22. #include "operator/ops.h"
  23. #include "mindspore/ccsrc/device/kernel_info.h"
  24. #include "mindspore/ccsrc/kernel/kernel_build_info.h"
  25. #include <vector>
  26. namespace mindspore {
  27. namespace device {
  28. namespace ascend {
  29. namespace {
// Local aliases to shorten the heavily nested project type names used below.
using KernelInfo = device::KernelInfo;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
using KernelBuildInfo = kernel::KernelBuildInfo;
using KernelGraph = session::KernelGraph;
using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
using KernelBuilderPtr = std::shared_ptr<KernelBuildInfoBuilder>;
using Shape = std::vector<size_t>;
using ShapeList = std::vector<Shape>;
// Indices into the match-count vector used to rank candidate kernel infos.
// PriorityChooseItem compares these vectors lexicographically, so an earlier
// entry has strictly higher priority than any later one.
enum MatchCountPriority {
  MATCH_COUNT_PRIORITY_BEGIN = 0,
  MATCH_FORMAT_COUNT = MATCH_COUNT_PRIORITY_BEGIN,  // input format == producing node's output format
  MATCH_DTYPE_COUNT,                                // input device dtype == producing node's device dtype
  MATCH_NZ_FORMAT_COUNT,                            // input format is FRAC_NZ
  MATCH_5D_FORMAT_COUNT,                            // input format is NC1HWC0
  MATCH_OUTPUT_DTYPE_COUNT,                         // output device dtype == inferred output dtype
  MATCH_COUNT_PRIORITY_END                          // number of counters; also the required vector size
};
// Every format a kernel build info is allowed to declare; anything else makes
// IsShapeMatchFormat throw.
const std::set<std::string> kOpFormatList = {
  kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,     kOpFormat_NCHW,      kOpFormat_NHWC,
  kOpFormat_HWCN,    kOpFormat_NC1HWC0,     kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ};
  50. bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  51. // if format is default,it remarkes support all format
  52. if (kOpFormatList.find(format) == kOpFormatList.end()) {
  53. MS_EXCEPTION(ArgumentError) << "got the unknow format " << format;
  54. }
  55. if (format == kOpFormat_DEFAULT) {
  56. return true;
  57. }
  58. // if shape size is 0,the shape will be a scalar
  59. if (shape.empty()) {
  60. return true;
  61. }
  62. if (shape.size() > kShapeSupportFormatMap.size()) {
  63. return false;
  64. }
  65. if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
  66. return shape[shape.size() - 1] % 16 != 0 && shape[shape.size() - 2] % 16 != 0;
  67. }
  68. return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
  69. }
  70. bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  71. MS_EXCEPTION_IF_NULL(kernel_node);
  72. auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
  73. if (!IsShapeMatchFormat(shape, format)) {
  74. return false;
  75. }
  76. for (auto shape_value : shape) {
  77. if (shape_value == 0) {
  78. MS_EXCEPTION(ArgumentError) << "dimension size of the tensor shape should be a positive integer, but got ["
  79. << shape_value << "]";
  80. }
  81. }
  82. return true;
  83. };
  84. for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
  85. auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
  86. if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
  87. return false;
  88. }
  89. }
  90. for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
  91. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
  92. if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
  93. return false;
  94. }
  95. }
  96. return true;
  97. }
  98. bool MatchInferOutputDataType(const CNodePtr &cnode, const kernel::KernelBuildInfo &kernel_build_info) {
  99. MS_EXCEPTION_IF_NULL(cnode);
  100. // Check input data type
  101. for (size_t input_index = 0; input_index < kernel_build_info.GetInputNum(); ++input_index) {
  102. AnfNodePtr cur_input = cnode->input(input_index + 1);
  103. MS_EXCEPTION_IF_NULL(cur_input);
  104. TypeId input_origin_type;
  105. if (cur_input->isa<Parameter>() && AnfAlgo::IsParameterWeight(cur_input->cast<ParameterPtr>())) {
  106. // weight
  107. input_origin_type = AnfAlgo::GetOutputDeviceDataType(cur_input, 0);
  108. } else {
  109. // feature map
  110. input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index);
  111. }
  112. if (input_origin_type == kTypeUnknown) {
  113. continue;
  114. }
  115. if (kernel_build_info.GetInputDeviceType(input_index) != input_origin_type) {
  116. return false;
  117. }
  118. }
  119. // Check output data type
  120. for (size_t output_index = 0; output_index < kernel_build_info.GetOutputNum(); ++output_index) {
  121. if (kernel_build_info.GetOutputDeviceType(output_index) != AnfAlgo::GetOutputInferDataType(cnode, output_index)) {
  122. return false;
  123. }
  124. }
  125. return true;
  126. }
  127. /**
  128. * compare too vector by priority,select a better vector,like compare too num,first compare highest num location,if
  129. * equal then next num location
  130. * example:[3,1,1,1] > [2,2,2,2] > [2,2,1,2] > [2,1,1,3]
  131. */
  132. bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best_item) {
  133. MS_EXCEPTION_IF_NULL(best_item);
  134. if (cur_item.size() != best_item->size()) {
  135. MS_LOG(ERROR) << "item size should be same!";
  136. return false;
  137. }
  138. // Update the best_item by comparing the cur_item and best_item
  139. for (size_t i = 0; i < cur_item.size(); i++) {
  140. if (cur_item[i] > best_item->at(i)) {
  141. *best_item = cur_item;
  142. return true;
  143. } else if (cur_item[i] == best_item->at(i)) {
  144. continue;
  145. } else {
  146. return false;
  147. }
  148. }
  149. return false;
  150. }
// Fills cur_kernelinfo_match_counts with how well kernel_build_info matches
// kernel_node: per-input format/dtype agreement with the producing node, the
// number of FRAC_NZ and NC1HWC0 inputs, and per-output dtype agreement with
// the inferred type. The vector is indexed by MatchCountPriority.
// Throws ArgumentError when the vector is too small to hold all counters.
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
                          std::vector<int> *const cur_kernelinfo_match_counts) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(cur_kernelinfo_match_counts);
  if (cur_kernelinfo_match_counts->size() < MATCH_COUNT_PRIORITY_END) {
    MS_EXCEPTION(ArgumentError) << "Out of range cur_kernelinfo_match_counts " << MATCH_COUNT_PRIORITY_END;
  }
  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
    AnfNodePtr input_anf_node = kernel_node->input(input_index + 1);
    MS_EXCEPTION_IF_NULL(input_anf_node);
    // if a input parameter is a weight with default format, the input shouldn't participate the judge
    if (input_anf_node->isa<Parameter>()) {
      auto para = input_anf_node->cast<ParameterPtr>();
      if (AnfAlgo::IsParameterWeight(para) && AnfAlgo::GetOutputDeviceDataType(para, 0) == kTypeUnknown) {
        continue;
      }
    }
    // Input format agrees with the format the producing node will emit.
    if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
    }
    // Input device dtype agrees with the producing node's device dtype.
    if (kernel_build_info.GetInputDeviceType(input_index) ==
        AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_DTYPE_COUNT]++;
    }
    // FRAC_NZ / NC1HWC0 inputs act as lower-priority tie-breakers that favour
    // candidates declaring those formats.
    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_FRAC_NZ) {
      (*cur_kernelinfo_match_counts)[MATCH_NZ_FORMAT_COUNT]++;
    }
    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_NC1HWC0) {
      (*cur_kernelinfo_match_counts)[MATCH_5D_FORMAT_COUNT]++;
    }
  }
  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
    // cal count of same output dtype between abstract and kernel info
    if (kernel_build_info.GetOutputDeviceType(output_index) ==
        AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) {
      (*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
    }
  }
}
  190. void SetKernelBuildInfo(KernelBuilderPtr builder) {
  191. builder->SetFusionType(kernel::OPAQUE);
  192. builder->SetKernelType(AUTO_DIFF_KERNEL);
  193. builder->SetProcessor(kernel::AICORE);
  194. }
  195. void test_select(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list) {
  196. std::vector<int> most_match_counts = {-1, -1, -1, -1, -1};
  197. int selected_index = -1;
  198. for (size_t info_index = 0; info_index < kernel_info_list.size(); ++info_index) {
  199. std::vector<int> cur_kernel_info_match_counts = {0, 0, 0, 0, 0};
  200. if (!IsValidKernelInfo(kernel_node, *(kernel_info_list[info_index]))) {
  201. continue;
  202. }
  203. if (!MatchInferOutputDataType(kernel_node, *(kernel_info_list[info_index]))) {
  204. continue;
  205. }
  206. std::shared_ptr<kernel::KernelBuildInfo> kernel_info_ptr = kernel_info_list[info_index];
  207. UpdateCurMatchCounts(*kernel_info_ptr, kernel_node, &cur_kernel_info_match_counts);
  208. // Currently the selection policy is the match format count first, and then is datatype counts.
  209. if (PriorityChooseItem(cur_kernel_info_match_counts, &most_match_counts)) {
  210. selected_index = SizeToInt(info_index);
  211. }
  212. }
  213. if (selected_index == -1) {
  214. MS_EXCEPTION(NotExistsError) << "" << kernel_node->DebugString() << " Cannot find valid kernel Info !";
  215. }
  216. auto index = IntToSize(selected_index);
  217. if (index >= kernel_info_list.size()) {
  218. MS_EXCEPTION(ArgumentError) << "index outof range";
  219. }
  220. std::shared_ptr<kernel::KernelBuildInfo> selected_kernel_info_ptr = kernel_info_list[index];
  221. MS_EXCEPTION_IF_NULL(selected_kernel_info_ptr);
  222. AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_info_ptr, kernel_node.get());
  223. }
  224. void SetParentAbstract(std::vector<AnfNodePtr> parent_list, std::vector<std::vector<size_t>> shapes,
  225. std::vector<TypeId> types) {
  226. for (const auto &node : parent_list) {
  227. AnfAlgo::SetOutputInferTypeAndShape(types, shapes, node.get());
  228. }
  229. }
  230. } // namespace
// Test fixture for the Ascend kernel-selection helpers above. No per-test
// state is needed, so SetUp/TearDown are intentionally empty.
// NOTE(review): "Selct" looks like a typo for "Select", but TEST_F below
// depends on this exact class name, so renaming would have to touch both
// places in one change.
class AscendKernelSelctTest : public UT::Common {
 public:
  AscendKernelSelctTest() = default;
  void SetUp() override {}
  void TearDown() override {}
};
  237. TEST_F(AscendKernelSelctTest, TestSelect) {
  238. std::vector<KernelBuilderPtr> build_list;
  239. std::vector<TypeId> type_list = {kNumberTypeFloat32};
  240. for (size_t i = 0; i <= 4; ++i) {
  241. build_list.push_back(std::make_shared<KernelBuildInfoBuilder>());
  242. SetKernelBuildInfo(build_list[i]);
  243. build_list[i]->SetInputsDeviceType(type_list);
  244. build_list[i]->SetOutputsDeviceType(type_list);
  245. }
  246. std::vector<std::string> nd_fmt = {kOpFormat_DEFAULT};
  247. std::vector<std::string> nz_fmt = {kOpFormat_FRAC_NZ};
  248. auto anf_graph = std::make_shared<KernelGraph>();
  249. // 16's multiple should not chose format NZ
  250. Shape nd_shapes = {2, 32, 224, 224};
  251. Shape nz_shapes = {3, 3, 5, 5};
  252. auto add_value = NewValueNode(prim::kPrimTensorAdd);
  253. auto a_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
  254. auto b_node = anf_graph->NewCNode(std::vector<AnfNodePtr>{add_value});
  255. std::vector<AnfNodePtr> parent_list = {add_value, a_node, b_node};
  256. auto c_node = anf_graph->NewCNode(parent_list);
  257. // a b
  258. // \ /
  259. // c
  260. // a & b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
  261. // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  262. // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3,224, 224}}
  263. // set a & b's info
  264. SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
  265. // set abstract c
  266. AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nd_shapes}, c_node.get());
  267. // set format of kernel info
  268. build_list[0]->SetOutputsFormat(nz_fmt);
  269. build_list[1]->SetOutputsFormat(nz_fmt);
  270. build_list[2]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nd_fmt[0]});
  271. build_list[3]->SetInputsFormat(std::vector<std::string>{nz_fmt[0], nz_fmt[0]});
  272. build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  273. build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  274. build_list[2]->SetOutputsFormat(nd_fmt);
  275. build_list[3]->SetOutputsFormat(nz_fmt);
  276. std::vector<KernelBuildInfoPtr> select_info_list;
  277. // set select info list
  278. select_info_list.emplace_back(build_list[2]->Build());
  279. select_info_list.emplace_back(build_list[3]->Build());
  280. // set device info for a & b
  281. AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
  282. AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
  283. test_select(c_node, select_info_list);
  284. EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
  285. EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_DEFAULT);
  286. // set a & b's info
  287. // a b
  288. // \ /
  289. // c
  290. // a: kernel_info:{output_format:{5d},dtype:{kNumberTypeFloat32}}
  291. // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  292. // b: kernel_info:{output_format:{nz},dtype:{kNumberTypeFloat32}}
  293. // infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  294. // c: infer_dtype:{kNumberTypeFloat32},infer_shape:{{3, 3, 5, 5}}
  295. // set a & b's info
  296. SetParentAbstract(parent_list, ShapeList{nz_shapes}, type_list);
  297. // set abstract c
  298. AnfAlgo::SetOutputInferTypeAndShape(type_list, ShapeList{nz_shapes}, c_node.get());
  299. // set format of kernel info
  300. build_list[0]->SetOutputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0});
  301. build_list[1]->SetOutputsFormat(nz_fmt);
  302. build_list[2]->SetInputsFormat(std::vector<std::string>{kOpFormat_NC1HWC0, nd_fmt[0]});
  303. build_list[3]->SetInputsFormat(std::vector<std::string>{nd_fmt[0], nz_fmt[0]});
  304. build_list[2]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  305. build_list[3]->SetInputsDeviceType(std::vector<TypeId>{kNumberTypeFloat32, kNumberTypeFloat32});
  306. build_list[2]->SetOutputsFormat(nd_fmt);
  307. build_list[3]->SetOutputsFormat(nz_fmt);
  308. // set select info list
  309. select_info_list.emplace_back(build_list[2]->Build());
  310. select_info_list.emplace_back(build_list[3]->Build());
  311. // set device info for a & b
  312. AnfAlgo::SetSelectKernelBuildInfo(build_list[0]->Build(), a_node.get());
  313. AnfAlgo::SetSelectKernelBuildInfo(build_list[1]->Build(), b_node.get());
  314. test_select(c_node, select_info_list);
  315. EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 0), kOpFormat_DEFAULT);
  316. EXPECT_EQ(AnfAlgo::GetInputFormat(c_node, 1), kOpFormat_FRAC_NZ);
  317. }
  318. } // namespace ascend
  319. } // namespace device
  320. } // namespace mindspore