You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tbe_kernel_select.cc 25 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/tbe/tbe_kernel_select.h"
  17. #include <unordered_map>
  18. #include <memory>
  19. #include <map>
  20. #include <set>
  21. #include "session/anf_runtime_algorithm.h"
  22. #include "kernel/oplib/oplib.h"
  23. #include "kernel/tbe/tbe_kernel_build.h"
  24. #include "nlohmann/json.hpp"
  25. #include "common/utils.h"
  26. #include "utils/context/ms_context.h"
  27. #include "kernel/tbe/tbe_python_funcs.h"
  28. #include "pre_activate/common/helper.h"
  29. #include "kernel/tbe/tbe_convert_utils.h"
  30. namespace mindspore {
  31. namespace kernel {
  32. constexpr auto kName = "name";
  33. constexpr auto kDtype = "dtype";
  34. constexpr auto kFormat = "format";
  35. constexpr auto kPrefixInput = "input";
  36. constexpr auto kPrefixOutput = "output";
  37. const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
  38. {"NHWC", "DefaultFormat"},
  39. {"ND", "DefaultFormat"},
  40. {"FRACTAL_Z", "FracZ"},
  41. {"NDHWC", "DefaultFormat"}};
  42. static const std::vector<std::string> CHECK_SUPPORTED_OPTYPE{
  43. "MatMul", "BatchMatMul", "TopK", "InTopK", "Pack", "GatherNd", "UnsortedSegmentMinD", "UnsortedSegmentProdD", "Cast"};
  44. bool CheckSupported(const AnfNodePtr &anf_node, const KernelBuildInfoPtr &select_kernel_build_info) {
  45. MS_EXCEPTION_IF_NULL(anf_node);
  46. MS_EXCEPTION_IF_NULL(select_kernel_build_info);
  47. std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  48. auto iter = std::find(CHECK_SUPPORTED_OPTYPE.begin(), CHECK_SUPPORTED_OPTYPE.end(), op_name);
  49. if (iter == CHECK_SUPPORTED_OPTYPE.end()) {
  50. MS_LOG(DEBUG) << "Op " << op_name << "this op does not need to check op supported.";
  51. return true;
  52. }
  53. // replace kernel_info with current kernel info
  54. auto ori_select_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
  55. AnfAlgo::SetSelectKernelBuildInfo(select_kernel_build_info, anf_node.get());
  56. nlohmann::json kernel_json;
  57. TbeKernelJsonCreator creator(CHECK_SUPPORTED);
  58. bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
  59. if (!ret) {
  60. MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
  61. AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
  62. return false;
  63. }
  64. ret = TbePythonFuncs::CheckSupported(kernel_json);
  65. AnfAlgo::SetSelectKernelBuildInfo(ori_select_kernel_info, anf_node.get());
  66. return ret;
  67. }
  68. bool CheckJsonItemValidity(const nlohmann::json &json_obj, const std::string &key_name,
  69. const std::vector<std::string> &keys) {
  70. if (!json_obj[key_name].is_object()) {
  71. MS_LOG(DEBUG) << key_name << "is not an object!";
  72. return false;
  73. }
  74. for (auto key : keys) {
  75. if (json_obj[key_name].find(key) == json_obj[key_name].end()) {
  76. MS_LOG(DEBUG) << "Key" << key << "of " << key_name << " is not found!";
  77. return false;
  78. }
  79. }
  80. return true;
  81. }
  82. std::vector<std::string> SplitStr(const std::string &string, const std::string &sep) {
  83. std::vector<std::string> result;
  84. size_t start = 0;
  85. size_t index = string.find(sep, start);
  86. std::string substr;
  87. while (index != std::string::npos) {
  88. if (string.size() > start) {
  89. substr = string.substr(start, index - start);
  90. }
  91. (void)substr.erase(0, substr.find_first_not_of(' '));
  92. (void)substr.erase(substr.find_last_not_of(" ") + 1);
  93. auto iter = DYNAMIC_FORMAT_MAP.find(substr);
  94. if (iter != DYNAMIC_FORMAT_MAP.end()) {
  95. substr = iter->second;
  96. }
  97. result.push_back(substr);
  98. start = index + sep.size();
  99. index = string.find(sep, start);
  100. }
  101. if (string.size() > start) {
  102. substr = string.substr(start);
  103. }
  104. (void)substr.erase(0, substr.find_first_not_of(" "));
  105. (void)substr.erase(substr.find_last_not_of(" ") + 1);
  106. auto iter = DYNAMIC_FORMAT_MAP.find(substr);
  107. if (iter != DYNAMIC_FORMAT_MAP.end()) {
  108. substr = iter->second;
  109. }
  110. result.push_back(substr);
  111. return result;
  112. }
  113. void ConvertFormatDtype(const std::string &format, const std::string &dtype, const std::shared_ptr<OpIOInfo> io_info) {
  114. MS_EXCEPTION_IF_NULL(io_info);
  115. std::vector<std::string> format_vec = SplitStr(format, ",");
  116. std::vector<std::string> dtype_vec = SplitStr(dtype, ",");
  117. io_info->set_formats(format_vec);
  118. io_info->set_dtypes(dtype_vec);
  119. }
  120. bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_ptr<OpIOInfo>> *const inputs,
  121. std::vector<std::shared_ptr<OpIOInfo>> *const outputs) {
  122. nlohmann::json json_obj = nlohmann::json::parse(jsonStr);
  123. if (!json_obj.is_object()) {
  124. MS_LOG(DEBUG) << "JsonStr is not an object, the jsonStr is:" << jsonStr;
  125. return false;
  126. }
  127. std::vector<std::string> keys = {kName, kDtype, kFormat};
  128. for (const auto &item : json_obj.items()) {
  129. std::string key_name;
  130. key_name = item.key();
  131. if (key_name.empty()) {
  132. MS_LOG(DEBUG) << "Key name is empty!";
  133. return false;
  134. }
  135. if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
  136. return false;
  137. }
  138. if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
  139. std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
  140. MS_EXCEPTION_IF_NULL(input);
  141. input->set_name(json_obj[key_name].at(kName));
  142. ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
  143. inputs->emplace_back(input);
  144. } else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
  145. std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
  146. MS_EXCEPTION_IF_NULL(output);
  147. output->set_name(json_obj[key_name].at(kName));
  148. ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), output);
  149. outputs->emplace_back(output);
  150. } else {
  151. MS_LOG(DEBUG) << "Key name:" << key_name << " is undefined!";
  152. return false;
  153. }
  154. }
  155. return true;
  156. }
  157. std::string OpSelectFormat(const std::shared_ptr<AnfNode> &anf_node) {
  158. nlohmann::json kernel_json;
  159. std::string res_json_str;
  160. TbeKernelJsonCreator creator(OP_SELECT_FORMAT);
  161. bool ret = creator.GenTbeSingleKernelJson(anf_node, &kernel_json);
  162. if (!ret) {
  163. MS_LOG(DEBUG) << "GenTbeSingleKernelJson failed";
  164. return res_json_str;
  165. }
  166. res_json_str = TbePythonFuncs::OpSelectFormat(kernel_json);
  167. MS_LOG(INFO) << "Dynamic select foramt response result:" << res_json_str;
  168. return res_json_str;
  169. }
  170. void SetTidyInputsInfo(const std::shared_ptr<AnfNode> &anf_node,
  171. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
  172. const std::vector<std::shared_ptr<OpIOInfo>> &inputs) {
  173. std::vector<TypeId> inputs_type;
  174. std::vector<std::string> inputs_format;
  175. std::vector<int> dyn_input_sizes;
  176. size_t dyn_input_idx = 0;
  177. size_t kernel_info_index = 0;
  178. size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
  179. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  180. MS_EXCEPTION_IF_NULL(primitive);
  181. if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
  182. dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  183. }
  184. for (size_t i = 0; i < inputs.size(); i++) {
  185. MS_EXCEPTION_IF_NULL(inputs[i]);
  186. std::string param_type = inputs[i]->param_type();
  187. if (i >= real_input_num) {
  188. MS_LOG(INFO) << "Input index:" << i << "is out of real_input_num:" << real_input_num;
  189. continue;
  190. }
  191. auto type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, i);
  192. auto format = kOpFormat_DEFAULT;
  193. if (param_type == "dynamic") {
  194. if (!dyn_input_sizes.empty()) {
  195. for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  196. kernel_info_index++;
  197. inputs_type.emplace_back(type_id);
  198. inputs_format.emplace_back(format);
  199. }
  200. dyn_input_idx++;
  201. }
  202. } else if (param_type == "required") {
  203. kernel_info_index++;
  204. inputs_type.emplace_back(type_id);
  205. inputs_format.emplace_back(format);
  206. } else {
  207. if (kernel_info_index < real_input_num) {
  208. MS_LOG(INFO) << "Input type is optional, input index is :" << kernel_info_index;
  209. kernel_info_index++;
  210. inputs_type.emplace_back(type_id);
  211. inputs_format.emplace_back(format);
  212. }
  213. }
  214. }
  215. builder->SetInputsDeviceType(inputs_type);
  216. builder->SetInputsFormat(inputs_format);
  217. }
  218. void SetTidyOutputsInfo(const std::shared_ptr<AnfNode> &anf_node,
  219. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
  220. const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
  221. std::vector<TypeId> outputs_type;
  222. std::vector<std::string> outputs_format;
  223. auto real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
  224. size_t output_idx = 0;
  225. for (const auto output : outputs) {
  226. MS_EXCEPTION_IF_NULL(output);
  227. if (output_idx >= real_output_num) {
  228. continue;
  229. }
  230. size_t output_num = 0;
  231. if (output->param_type() == "dynamic") {
  232. if (outputs.size() > 1) {
  233. MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
  234. }
  235. output_num = real_output_num;
  236. } else if (output->param_type() == "required") {
  237. output_num = 1;
  238. } else {
  239. if (output_idx < real_output_num) {
  240. MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
  241. output_num = 1;
  242. }
  243. }
  244. for (size_t i = 0; i < output_num; i++) {
  245. auto type_id = AnfAlgo::GetOutputInferDataType(anf_node, output_idx);
  246. outputs_type.emplace_back(type_id);
  247. outputs_format.emplace_back(kOpFormat_DEFAULT);
  248. output_idx++;
  249. }
  250. }
  251. builder->SetOutputsDeviceType(outputs_type);
  252. builder->SetOutputsFormat(outputs_format);
  253. }
  254. void GenTidyKernelBuildInfo(const std::shared_ptr<AnfNode> &anf_node,
  255. const std::vector<std::shared_ptr<OpIOInfo>> &inputs,
  256. const std::vector<std::shared_ptr<OpIOInfo>> &outputs) {
  257. auto builder_tmp = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  258. builder_tmp->SetKernelType(TBE_KERNEL);
  259. SetTidyInputsInfo(anf_node, builder_tmp, inputs);
  260. SetTidyOutputsInfo(anf_node, builder_tmp, outputs);
  261. AnfAlgo::SetSelectKernelBuildInfo(builder_tmp->Build(), anf_node.get());
  262. }
  263. void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
  264. const std::shared_ptr<OpInfo> op_info_new_ptr) {
  265. std::vector<std::shared_ptr<OpIOInfo>> inputs_static = op_info_ptr->inputs_ptr();
  266. std::vector<std::shared_ptr<OpIOInfo>> outputs_static = op_info_ptr->outputs_ptr();
  267. std::vector<std::shared_ptr<OpIOInfo>> inputs_dyn;
  268. std::vector<std::shared_ptr<OpIOInfo>> outputs_dyn;
  269. if ((op_info_ptr->imply_type() == kTBE) && (!mindspore::opt::IsNopNode(kernel_node->cast<AnfNodePtr>()))) {
  270. // 1. create tidy kernelBuildInfo in order to generate json for calling op_select_format
  271. auto anf_node = kernel_node->cast<std::shared_ptr<AnfNode>>();
  272. auto kernel_build_info_ptr = AnfAlgo::GetSelectKernelBuildInfo(anf_node);
  273. GenTidyKernelBuildInfo(kernel_node, inputs_static, outputs_static);
  274. // 2.get dynamic format from op_impl
  275. std::string res_json_str;
  276. auto context_ptr = MsContext::GetInstance();
  277. MS_EXCEPTION_IF_NULL(context_ptr);
  278. if (context_ptr->execution_mode() != kPynativeMode) {
  279. res_json_str = OpSelectFormat(kernel_node);
  280. }
  281. if (!res_json_str.empty()) {
  282. (void)ParseDynamicFormatJson(res_json_str, &inputs_dyn, &outputs_dyn);
  283. }
  284. if (inputs_static.size() != inputs_dyn.size()) {
  285. inputs_dyn.clear();
  286. }
  287. if (outputs_static.size() != outputs_dyn.size()) {
  288. outputs_dyn.clear();
  289. }
  290. // 3. resume kernel node's SelectKernelBuildInfo
  291. // As it has been replaced by GenTidyKernelBuildInfo in order to call python func
  292. AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_ptr, anf_node.get());
  293. }
  294. // 4.replace by dynamic format and dtype
  295. if (inputs_dyn.empty() && outputs_dyn.empty()) {
  296. MS_LOG(INFO) << "Dynamic select format response is empty, use static register info.";
  297. op_info_new_ptr->set_inputs_ptr(inputs_static);
  298. op_info_new_ptr->set_outputs_ptr(outputs_static);
  299. } else {
  300. MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
  301. for (size_t i = 0; i < inputs_static.size(); i++) {
  302. inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
  303. }
  304. for (size_t j = 0; j < outputs_static.size(); j++) {
  305. outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
  306. }
  307. op_info_new_ptr->set_inputs_ptr(inputs_dyn);
  308. op_info_new_ptr->set_outputs_ptr(outputs_dyn);
  309. }
  310. // 5.copy other opinfo to new op_info_new
  311. op_info_new_ptr->set_op_name(op_info_ptr->op_name());
  312. op_info_new_ptr->set_imply_type(op_info_ptr->imply_type());
  313. op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
  314. }
  315. bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  316. size_t builder_idex, const std::vector<int> &dyn_input_sizes,
  317. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  318. MS_EXCEPTION_IF_NULL(builder);
  319. std::vector<TypeId> inputs_device_type;
  320. std::vector<std::string> inputs_format;
  321. size_t dyn_input_idx = 0;
  322. size_t kernel_info_index = 0;
  323. MS_EXCEPTION_IF_NULL(inputs[0]);
  324. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  325. for (const auto &input : inputs) {
  326. MS_EXCEPTION_IF_NULL(input);
  327. std::string param_type = input->param_type();
  328. std::vector<std::string> dtypes = input->dtypes();
  329. std::vector<std::string> formats = input->formats();
  330. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  331. MS_LOG(ERROR) << "Set input kernel builder info, dtyps size != formats size.";
  332. return false;
  333. }
  334. if (param_type == "dynamic") {
  335. if (dyn_input_sizes.empty()) {
  336. MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
  337. return false;
  338. }
  339. for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  340. kernel_info_index++;
  341. auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
  342. inputs_device_type.push_back(type_id);
  343. inputs_format.push_back(formats[builder_idex]);
  344. }
  345. dyn_input_idx++;
  346. } else if (param_type == "required") {
  347. kernel_info_index++;
  348. auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
  349. inputs_device_type.push_back(type_id);
  350. inputs_format.push_back(formats[builder_idex]);
  351. } else {
  352. if (kernel_info_index < real_input_num) {
  353. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index;
  354. kernel_info_index++;
  355. auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
  356. inputs_device_type.push_back(type_id);
  357. inputs_format.push_back(formats[builder_idex]);
  358. }
  359. }
  360. }
  361. builder->SetInputsDeviceType(inputs_device_type);
  362. builder->SetInputsFormat(inputs_format);
  363. return true;
  364. }
  365. bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
  366. const size_t &real_output_num,
  367. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  368. // not now but in the next we need to support dynamic output case
  369. MS_EXCEPTION_IF_NULL(builder);
  370. size_t output_idx = 0;
  371. std::vector<TypeId> outputs_device_type;
  372. std::vector<std::string> outputs_format;
  373. MS_EXCEPTION_IF_NULL(outputs[0]);
  374. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  375. for (const auto &output : outputs) {
  376. MS_EXCEPTION_IF_NULL(output);
  377. if (output_idx >= real_output_num) {
  378. MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!";
  379. continue;
  380. }
  381. size_t output_num = 0;
  382. if (output->param_type() == "dynamic") {
  383. if (outputs.size() > 1) {
  384. MS_LOG(EXCEPTION) << "Dynamic output is unsupported multi output!";
  385. }
  386. output_num = real_output_num;
  387. } else if (output->param_type() == "required") {
  388. output_num = 1;
  389. } else {
  390. if (output_idx < real_output_num) {
  391. MS_LOG(INFO) << "Set output kernel builder info, output type is optional, output index is " << output_idx;
  392. output_num = 1;
  393. }
  394. }
  395. for (size_t i = 0; i < output_num; i++) {
  396. std::vector<std::string> dtypes = output->dtypes();
  397. std::vector<std::string> formats = output->formats();
  398. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  399. MS_LOG(ERROR) << "Set output kernel builder info, dtyps size != formats size.";
  400. return false;
  401. }
  402. auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]);
  403. outputs_device_type.push_back(type_id);
  404. outputs_format.push_back(formats[builder_idex]);
  405. output_idx++;
  406. }
  407. }
  408. builder->SetOutputsFormat(outputs_format);
  409. builder->SetOutputsDeviceType(outputs_device_type);
  410. return true;
  411. }
  412. void SetKernelBuildCommonInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder,
  413. Processor processor, const std::shared_ptr<const OpInfo> &op_info_ptr) {
  414. MS_EXCEPTION_IF_NULL(builder);
  415. MS_EXCEPTION_IF_NULL(op_info_ptr);
  416. builder->SetProcessor(processor);
  417. std::string fusion_type = op_info_ptr->fusion_type();
  418. if (tbe::GetFusionType(fusion_type) != UNKNOWN_FUSION_TYPE) {
  419. builder->SetFusionType(tbe::GetFusionType(fusion_type));
  420. }
  421. builder->SetKernelType(TBE_KERNEL);
  422. }
  423. bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr,
  424. std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  425. MS_EXCEPTION_IF_NULL(kernel_node);
  426. MS_EXCEPTION_IF_NULL(kernel_info_list);
  427. size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  428. size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  429. std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  430. std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  431. std::vector<int> dyn_input_sizes;
  432. auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  433. MS_EXCEPTION_IF_NULL(primitive);
  434. if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
  435. dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  436. }
  437. if (inputs.size() > 0) {
  438. MS_EXCEPTION_IF_NULL(inputs[0]);
  439. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  440. for (size_t j = 0; j < kernel_info_cnt; j++) {
  441. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  442. MS_EXCEPTION_IF_NULL(builder);
  443. SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
  444. if (!SetKernelBuilderInputInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
  445. MS_LOG(ERROR) << "Parse kernel metadata, set inputs kernel builder info failed.";
  446. return false;
  447. }
  448. if (outputs.size() > 0) {
  449. if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
  450. MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
  451. return false;
  452. }
  453. }
  454. kernel_info_list->push_back(builder->Build());
  455. }
  456. } else if (outputs.size() > 0) {
  457. MS_EXCEPTION_IF_NULL(outputs[0]);
  458. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  459. for (size_t j = 0; j < kernel_info_cnt; j++) {
  460. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  461. MS_EXCEPTION_IF_NULL(builder);
  462. SetKernelBuildCommonInfo(builder, Processor::AICORE, op_info_ptr);
  463. if (!SetKernelBuilderOutputInfo(outputs, j, real_output_num, builder)) {
  464. MS_LOG(ERROR) << "Parse kernel metadata, set outputs kernel builder info failed.";
  465. return false;
  466. }
  467. kernel_info_list->push_back(builder->Build());
  468. }
  469. }
  470. return true;
  471. }
  472. bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
  473. const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
  474. kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
  475. kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
  476. kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
  477. // if format is default, it remarkes support all format
  478. if (kOpFormatList.find(format) == kOpFormatList.end()) {
  479. MS_LOG(EXCEPTION) << "Got the unknown format " << format;
  480. }
  481. if (format == kOpFormat_DEFAULT) {
  482. return true;
  483. }
  484. // if shape size is 0, the shape will be a scalar
  485. if (shape.empty()) {
  486. return true;
  487. }
  488. if (shape.size() > kShapeSupportFormatMap.size()) {
  489. return false;
  490. }
  491. if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
  492. return true;
  493. }
  494. return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
  495. }
  496. bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
  497. MS_EXCEPTION_IF_NULL(kernel_node);
  498. auto check_function = [](const std::vector<size_t> &shape, const std::string &format) -> bool {
  499. if (!IsShapeMatchFormat(shape, format)) {
  500. return false;
  501. }
  502. return true;
  503. };
  504. for (size_t index = 0; index < kernel_build_info.GetOutputNum(); ++index) {
  505. auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, index);
  506. if (!check_function(output_shape, kernel_build_info.GetOutputFormat(index))) {
  507. return false;
  508. }
  509. }
  510. for (size_t index = 0; index < kernel_build_info.GetInputNum(); ++index) {
  511. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
  512. if (!check_function(input_shape, kernel_build_info.GetInputFormat(index))) {
  513. return false;
  514. }
  515. }
  516. if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
  517. return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
  518. AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
  519. }
  520. return true;
  521. }
  522. void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
  523. MS_EXCEPTION_IF_NULL(kernel_node);
  524. MS_EXCEPTION_IF_NULL(kernel_info_list);
  525. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> parse_info_list;
  526. std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
  527. auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
  528. if (op_info_ptr == nullptr) {
  529. return;
  530. }
  531. // dynamic get op format and dtype and replace opinfo
  532. auto op_info_new_ptr = std::make_shared<OpInfo>();
  533. ReplaceByDynamicFormatDtype(kernel_node, op_info_ptr, op_info_new_ptr);
  534. if (!ParseMetadata(kernel_node, op_info_new_ptr, &parse_info_list)) {
  535. MS_LOG(INFO) << "Tbe parsed metadata of op[" << op_name << "] failed.";
  536. return;
  537. }
  538. auto context_ptr = MsContext::GetInstance();
  539. MS_EXCEPTION_IF_NULL(context_ptr);
  540. for (auto parse_info : parse_info_list) {
  541. if (context_ptr->execution_mode() == kPynativeMode) {
  542. kernel_info_list->push_back(parse_info);
  543. } else {
  544. if (IsValidKernelInfo(kernel_node, *(parse_info))) {
  545. if (CheckSupported(kernel_node, parse_info)) {
  546. kernel_info_list->push_back(parse_info);
  547. } else {
  548. MS_LOG(INFO) << "CheckSupported Failed for TBE op" << op_name << " kernel info.";
  549. }
  550. }
  551. }
  552. }
  553. if (kernel_info_list->empty()) {
  554. MS_LOG(DEBUG) << "Tbe dose not have op [" << op_name << "].";
  555. }
  556. }
  557. } // namespace kernel
  558. } // namespace mindspore