You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/kernel_compiler/common_utils.h"
  17. #include <unordered_map>
  18. #include <map>
  19. #include <iostream>
  20. #include <utility>
  21. #include <fstream>
  22. #include <algorithm>
  23. #include <thread>
  24. #include "nlohmann/json.hpp"
  25. #include "backend/session/anf_runtime_algorithm.h"
  26. #include "utils/ms_utils.h"
  27. #include "ir/manager.h"
  28. #include "ir/meta_tensor.h"
  29. #include "base/core_ops.h"
  30. #include "ir/graph_utils.h"
  31. namespace mindspore {
  32. namespace kernel {
// String constants for the reduce-axis attribute handling in GetReduceAttrAxis.
constexpr char kAxis[] = "axis";
constexpr char kTypeInt32[] = "Int32";
// dtype name -> TypeId lookup used by DtypeToTypeId.
// Note: "float" is treated as an alias of float32 here.
const std::unordered_map<std::string, TypeId> type_id_maps = {
  {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},
};
// TypeId -> dtype name, used by TypeId2String. Not an exact inverse of
// type_id_maps: kNumberTypeFloat maps back to "float".
const std::map<TypeId, std::string> type_id_str_map = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "bool"},
};
// dtype name -> short form (e.g. "float32" -> "f32"), used by Dtype2ShortType.
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
};
  59. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  60. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  61. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  62. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  63. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  64. };
// Fusion-type name -> FusionType enum, used by SetKernelBuildInfo.
// "CONVLUTION" (sic) matches the spelling of the FusionType::CONVLUTION
// enumerator declared elsewhere in the project; do not "fix" one without
// the other.
const std::unordered_map<std::string, FusionType> fusion_type_maps = {
  {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
  {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
};
  69. void KernelMeta::Initialize(int pid) {
  70. if (pid == -1) {
  71. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
  72. } else {
  73. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/";
  74. }
  75. // remove old kernel cache
  76. RemoveKernelCache();
  77. #if defined(_WIN32) || defined(_WIN64)
  78. auto ret = mkdir(kernel_meta_path_.c_str());
  79. #else
  80. auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
  81. #endif
  82. if (ret != 0) {
  83. MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later";
  84. }
  85. initialized_ = true;
  86. }
  87. void KernelMeta::RemoveKernelCache() {
  88. DIR *dir = opendir(kernel_meta_path_.c_str());
  89. if (dir == nullptr) {
  90. return;
  91. }
  92. struct dirent *entry;
  93. while ((entry = readdir(dir)) != nullptr) {
  94. std::string kernel_file = entry->d_name;
  95. std::string kernel_file_realpath = kernel_meta_path_ + kernel_file;
  96. (void)remove(kernel_file_realpath.c_str());
  97. }
  98. (void)closedir(dir);
  99. (void)rmdir(kernel_meta_path_.c_str());
  100. }
  101. std::string KernelMeta::Search(const std::string &kernel_name) const {
  102. if (!initialized_) {
  103. return "";
  104. }
  105. auto iter = kernel_meta_map_.find(kernel_name);
  106. if (iter == kernel_meta_map_.end()) {
  107. return "";
  108. } else {
  109. return iter->second;
  110. }
  111. }
  112. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  113. if (!initialized_) {
  114. return false;
  115. }
  116. kernel_meta_map_[kernel_name] = kernel_json;
  117. return true;
  118. }
  119. bool CheckCache(const std::string &kernel_name) {
  120. // check cache.
  121. KernelMeta *bin_map = KernelMeta::GetInstance();
  122. if (bin_map == nullptr) {
  123. MS_LOG(DEBUG) << "kernel cache is invalid.";
  124. return false;
  125. }
  126. std::string kernel_json = bin_map->Search(kernel_name);
  127. bool ret = (!kernel_json.empty());
  128. if (ret) {
  129. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
  130. } else {
  131. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
  132. }
  133. return ret;
  134. }
  135. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  136. // search cache.
  137. KernelMeta *bin_map = KernelMeta::GetInstance();
  138. if (bin_map == nullptr) {
  139. MS_LOG(DEBUG) << "kernel cache is invalid.";
  140. return nullptr;
  141. }
  142. std::string kernel_json = bin_map->Search(kernel_name);
  143. if (!kernel_json.empty()) {
  144. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  145. // just a tmp solution.
  146. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  147. MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
  148. return nullptr;
  149. } else {
  150. return kernel_pack;
  151. }
  152. } else {
  153. MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "].";
  154. return nullptr;
  155. }
  156. }
  157. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  158. MS_LOG(INFO) << "kernel name:" << kernel_name << ", processr:" << processor;
  159. KernelMeta *bin_map = KernelMeta::GetInstance();
  160. std::string kernel_json;
  161. if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
  162. kernel_json = kCceKernelMeta;
  163. } else {
  164. kernel_json = bin_map->kernel_meta_path();
  165. }
  166. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  167. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  168. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  169. MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
  170. return nullptr;
  171. }
  172. if (bin_map == nullptr) {
  173. MS_LOG(DEBUG) << "kernel cache is invalid.";
  174. return nullptr;
  175. }
  176. if (bin_map->Insert(kernel_name, kernel_json)) {
  177. MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "].";
  178. }
  179. return kernel_pack;
  180. }
  181. TypeId DtypeToTypeId(const std::string &dtypes) {
  182. auto iter = type_id_maps.find(dtypes);
  183. if (iter != type_id_maps.end()) {
  184. return iter->second;
  185. } else {
  186. MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  187. }
  188. }
  189. std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
  190. auto iter = type_id_str_map.find(type_id);
  191. if (iter == type_id_str_map.end()) {
  192. if (!unknown_as_default) {
  193. MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
  194. }
  195. return "float32";
  196. }
  197. return iter->second;
  198. }
  199. std::string Dtype2ShortType(const std::string &dtypes) {
  200. auto iter = dtype_shortdtype_map_.find(dtypes);
  201. if (iter != dtype_shortdtype_map_.end()) {
  202. return iter->second;
  203. } else {
  204. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  205. }
  206. }
  207. size_t GetDtypeNbyte(const std::string &dtypes) {
  208. auto iter = dtype_nbyte_map.find(dtypes);
  209. if (iter != dtype_nbyte_map.end()) {
  210. return iter->second;
  211. } else {
  212. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  213. }
  214. }
  215. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  216. size_t builder_idex, const std::vector<int> &dyn_input_sizes,
  217. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  218. MS_EXCEPTION_IF_NULL(builder);
  219. std::vector<TypeId> inputs_device_type;
  220. std::vector<std::string> inputs_format;
  221. size_t dyn_input_idx = 0;
  222. size_t kernel_info_index = 0;
  223. MS_EXCEPTION_IF_NULL(inputs[0]);
  224. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  225. for (const auto &input : inputs) {
  226. MS_EXCEPTION_IF_NULL(input);
  227. std::string param_type = input->param_type();
  228. std::vector<std::string> dtypes = input->dtypes();
  229. std::vector<std::string> formats = input->formats();
  230. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  231. MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size.";
  232. return false;
  233. }
  234. if (param_type == "dynamic") {
  235. if (dyn_input_sizes.empty()) {
  236. MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
  237. return false;
  238. }
  239. for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  240. kernel_info_index++;
  241. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  242. inputs_device_type.push_back(type_id);
  243. inputs_format.push_back(formats[builder_idex]);
  244. }
  245. dyn_input_idx++;
  246. } else if (param_type == "required") {
  247. kernel_info_index++;
  248. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  249. inputs_device_type.push_back(type_id);
  250. inputs_format.push_back(formats[builder_idex]);
  251. } else {
  252. if (kernel_info_index < real_input_num) {
  253. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  254. kernel_info_index++;
  255. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  256. inputs_device_type.push_back(type_id);
  257. inputs_format.push_back(formats[builder_idex]);
  258. }
  259. }
  260. }
  261. builder->SetInputsDeviceType(inputs_device_type);
  262. builder->SetInputsFormat(inputs_format);
  263. return true;
  264. }
// Fill builder's output device types/formats from the op registry info,
// selecting column builder_idex of each output's dtype/format table.
// real_output_num is the node's actual output count; a "dynamic" output
// expands to that many entries and must be the only output. Returns false
// when the registry tables are inconsistent.
bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
                                const size_t &real_output_num,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  // not now but in the next we need to support dynamic output case
  MS_EXCEPTION_IF_NULL(builder);
  size_t output_idx = 0;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::string> outputs_format;
  // NOTE(review): outputs[0] assumes a non-empty list — ParseMetadata guards
  // with outputs.size() > 0; confirm no other callers exist.
  MS_EXCEPTION_IF_NULL(outputs[0]);
  size_t kernel_info_cnt = outputs[0]->dtypes().size();
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output_idx >= real_output_num) {
      // Registry lists more outputs than the node actually has; skip extras.
      MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
      continue;
    }
    size_t output_num = 0;
    if (output->param_type() == "dynamic") {
      // A dynamic output consumes all real outputs, so it cannot coexist
      // with other registered outputs.
      if (outputs.size() > 1) {
        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
      }
      output_num = real_output_num;
    } else if (output->param_type() == "required") {
      output_num = 1;
    } else {
      // Optional output: contributes an entry only while real outputs remain.
      if (output_idx < real_output_num) {
        MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
        output_num = 1;
      }
    }
    for (size_t i = 0; i < output_num; i++) {
      std::vector<std::string> dtypes = output->dtypes();
      std::vector<std::string> formats = output->formats();
      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
        MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
        return false;
      }
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      outputs_device_type.push_back(type_id);
      outputs_format.push_back(formats[builder_idex]);
      output_idx++;
    }
  }
  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_device_type);
  return true;
}
  312. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  313. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  314. MS_EXCEPTION_IF_NULL(builder);
  315. MS_EXCEPTION_IF_NULL(op_info_ptr);
  316. auto imply_type = op_info_ptr->imply_type();
  317. builder->SetProcessor(processor);
  318. std::string fusion_type = op_info_ptr->fusion_type();
  319. auto iter = fusion_type_maps.find(fusion_type);
  320. if (iter != fusion_type_maps.end()) {
  321. builder->SetFusionType(iter->second);
  322. } else {
  323. if (imply_type == kAKG) {
  324. MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type;
  325. }
  326. }
  327. if (imply_type == kAKG) {
  328. builder->SetKernelType(AKG_KERNEL);
  329. } else if (imply_type == kAICPU) {
  330. builder->SetKernelType(AICPU_KERNEL);
  331. } else {
  332. builder->SetKernelType(TBE_KERNEL);
  333. }
  334. }
// Build one KernelBuildInfo per column of the op registry's dtype/format
// tables and append them to kernel_info_list. Returns false when input or
// output builder info cannot be set; true otherwise (including the
// no-input/no-output case, which emits a bare entry only for AICPU).
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    MS_EXCEPTION_IF_NULL(inputs[0]);
    // Each dtype column of the registry describes one candidate build info.
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed.";
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // Output-only op: iterate the output table's columns instead.
    MS_EXCEPTION_IF_NULL(outputs[0]);
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // No registered inputs or outputs at all: only AICPU gets an entry here.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
  391. void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  392. char real_path[PATH_MAX] = {0};
  393. std::string path = base_path + json_name + kInfoSuffix;
  394. if (path.size() > PATH_MAX) {
  395. MS_LOG(DEBUG) << "file path " << path << " is too long.";
  396. return;
  397. }
  398. std::ofstream filewrite;
  399. filewrite.open(path);
  400. if (!filewrite.is_open()) {
  401. return;
  402. }
  403. filewrite << info << std::endl;
  404. filewrite.close();
  405. #if defined(_WIN32) || defined(_WIN64)
  406. if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) {
  407. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  408. return;
  409. }
  410. #else
  411. if (nullptr == realpath(path.c_str(), real_path)) {
  412. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  413. return;
  414. }
  415. #endif
  416. MS_LOG(INFO) << "real path is :" << real_path;
  417. if (chmod(real_path, S_IRUSR) == -1) {
  418. MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail.";
  419. }
  420. }
  421. Processor GetProcessor(const string &processor) {
  422. if (processor == kProcessorAiCore) return Processor::AICORE;
  423. if (processor == kProcessorAiCpu) return Processor::AICPU;
  424. if (processor == kProcessorCuda) return Processor::CUDA;
  425. MS_LOG(DEBUG) << "Unknown processor type.";
  426. return Processor::UNKNOWN;
  427. }
  428. std::string GetProcessor(const AnfNodePtr &anf_node) {
  429. MS_EXCEPTION_IF_NULL(anf_node);
  430. std::string device;
  431. switch (AnfAlgo::GetProcessor(anf_node)) {
  432. case Processor::AICORE:
  433. device = kProcessorAiCore;
  434. break;
  435. case Processor::AICPU:
  436. device = kProcessorAiCpu;
  437. break;
  438. case Processor::CUDA:
  439. device = kProcessorCuda;
  440. break;
  441. default:
  442. MS_LOG(DEBUG) << "Unknown processor type.";
  443. break;
  444. }
  445. return device;
  446. }
  447. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  448. if (shape_a.size() != shape_b.size()) {
  449. return false;
  450. }
  451. for (size_t i = 0; i < shape_a.size(); ++i) {
  452. if (shape_a[i] != shape_b[i]) {
  453. return false;
  454. }
  455. }
  456. return true;
  457. }
  458. int Sign(float x) {
  459. if (x > 0) {
  460. return 1;
  461. }
  462. if (x < 0) {
  463. return -1;
  464. }
  465. return 0;
  466. }
  467. std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  468. MS_EXCEPTION_IF_NULL(anf_node);
  469. if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
  470. MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
  471. }
  472. auto cnode = anf_node->cast<CNodePtr>();
  473. if (cnode == nullptr) {
  474. return AnfAlgo::VisitKernel(anf_node, 0);
  475. } else {
  476. return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
  477. }
  478. }
// For each graph input, locate the kernel node that consumes it and the slot
// it feeds. Returns, per input, (consumer, (slot, offset)): for a plain input
// the slot is the consumer's 0-based input index and offset is 0; for a
// dynamic input the slot is the dynamic-group index and offset is the
// position inside that group. Throws when an input has no user or no
// consumer among node_list.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    bool found = false;
    // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }
    // Scan every (user node, input position) pair until one of the users is
    // a node from node_list.
    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        if (anf_node != input_user.first) {
          continue;
        }
        std::vector<int> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
        }
        if (dyn_input_sizes.empty()) {
          // -1 converts the CNode input position (which counts the primitive
          // at slot 0) to a 0-based tensor-input index.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        } else {
          // Map the flat input position onto (dynamic group, offset in group)
          // by walking the cumulative group sizes.
          int used_as_idx = input_user.second - 1;
          int accum_idx = 0;
          size_t dyn_i = 0;
          for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
            accum_idx += dyn_input_sizes[dyn_i];
            if (used_as_idx < accum_idx) {
              input_index.push_back(std::make_pair(
                anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
              break;
            }
          }
          // Early loop exit (dyn_i short of end) means a group matched.
          if (dyn_i != dyn_input_sizes.size()) {
            found = true;
            break;
          }
        }
      }
      if (found) {
        break;
      }
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}
// For each graph output, resolve the (producing node, output index) pair.
// The producer is looked up first among node_list (kernel nodes, keeping the
// real output index), then among input_list (pass-through graph inputs,
// reported with index 0). Throws when neither list contains the producer.
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
                                                          const std::vector<AnfNodePtr> &input_list,
                                                          const std::vector<AnfNodePtr> &output_list) {
  std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  for (size_t i = 0; i < output_list.size(); ++i) {
    auto const &output = output_list[i];
    MS_EXCEPTION_IF_NULL(output);
    bool found = false;
    // VisitKernel peels tuple-getitem/depend wrappers to the real producer.
    auto pree_node = AnfAlgo::VisitKernel(output, 0);
    auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
    if (pos != std::end(node_list)) {
      output_index.push_back(pree_node);
      continue;
    }
    auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
    if (ret != std::end(input_list)) {
      output_index.push_back(std::make_pair(pree_node.first, 0));
      found = true;
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
                                  << output->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return output_index;
}
  565. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  566. MS_EXCEPTION_IF_NULL(node_list);
  567. MS_EXCEPTION_IF_NULL(func_graph);
  568. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  569. for (auto const &node : node_lists) {
  570. if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
  571. continue;
  572. }
  573. auto cnode = node->cast<CNodePtr>();
  574. MS_EXCEPTION_IF_NULL(cnode);
  575. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  576. node_list->push_back(node);
  577. }
  578. }
  579. }
  580. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  581. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  582. MS_EXCEPTION_IF_NULL(func_graph);
  583. MS_EXCEPTION_IF_NULL(node_list);
  584. MS_EXCEPTION_IF_NULL(input_list);
  585. GetValidKernelNodes(func_graph, node_list);
  586. auto parameters = func_graph->parameters();
  587. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  588. GetFuncGraphOutputNodes(func_graph, output_list);
  589. }
  590. void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  591. MS_EXCEPTION_IF_NULL(func_graph);
  592. MS_EXCEPTION_IF_NULL(output_list);
  593. auto func_output = func_graph->output();
  594. MS_EXCEPTION_IF_NULL(func_output);
  595. if (func_output->isa<CNode>()) {
  596. // multi output.
  597. auto cnode = func_output->cast<CNodePtr>();
  598. MS_EXCEPTION_IF_NULL(cnode);
  599. auto input0 = cnode->input(kAnfPrimitiveIndex);
  600. MS_EXCEPTION_IF_NULL(input0);
  601. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  602. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  603. auto input_node = cnode->input(input_idx);
  604. MS_EXCEPTION_IF_NULL(input_node);
  605. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  606. }
  607. } else {
  608. // single output.
  609. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  610. }
  611. } else {
  612. // single output.
  613. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  614. }
  615. }
// If input input_idx of anf_node is a single-element constant tensor of type
// float32, float16 or int32, store its value under (*node_json)["value"] and
// return true. Returns false for non-constant inputs, null tensors,
// multi-element tensors and unsupported dtypes; throws when input_idx is out
// of range.
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // +1 because input(0) holds the primitive.
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  }
  auto input_node = cnode->input(input_idx + 1);
  if (!IsValueNode<tensor::Tensor>(input_node)) {
    return false;
  }
  auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  if (tensor == nullptr) {
    return false;
  }
  auto type_id = tensor->data_type();
  auto *data = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  if (tensor->DataSize() > 1) {
    // not const tensor.
    MS_LOG(WARNING) << "Not take value of tensor whose datasize greater than 1, [" << input_node->DebugString(2) << "]";
    return false;
  }
  if (type_id == kFloat32->type_id()) {
    float *val = static_cast<float *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = val[0];
    MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
    return true;
  } else if (type_id == kFloat16->type_id()) {
    float16 *val = static_cast<float16 *>(data);
    MS_EXCEPTION_IF_NULL(val);
    // float16 has no direct json representation; widen to float.
    (*node_json)["value"] = static_cast<float>(val[0]);
    MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
    return true;
  } else if (type_id == kInt32->type_id()) {
    int *val = static_cast<int *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = val[0];
    MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
    return true;
  }
  MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  return false;
}
  663. void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
  664. MS_EXCEPTION_IF_NULL(func_graph);
  665. MS_EXCEPTION_IF_NULL(node_list);
  666. auto output = func_graph->output();
  667. MS_EXCEPTION_IF_NULL(output);
  668. if (AnfAlgo::IsRealKernel(output)) {
  669. // single output.
  670. node_list->push_back(std::make_pair(output, 0));
  671. return;
  672. } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
  673. auto output_cnode = output->cast<CNodePtr>();
  674. MS_EXCEPTION_IF_NULL(output_cnode);
  675. // multi output.
  676. auto &inputs = output_cnode->inputs();
  677. for (size_t i = 1; i < inputs.size(); ++i) {
  678. auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
  679. node_list->push_back(in_with_idx);
  680. }
  681. return;
  682. }
  683. MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
  684. << " of graph: " << func_graph->ToString();
  685. }
  686. bool IsWeightBoundary(const AnfNodePtr &node) {
  687. if (node->isa<ValueNode>()) {
  688. return true;
  689. }
  690. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  691. return true;
  692. }
  693. return false;
  694. }
// Read the reduce node's "axis" attribute, normalize negative axes against
// the input rank, write the result back onto the node as kAttrAxis, and
// return it. Returns an empty vector (with an ERROR log) when the attribute
// is missing; throws when the node does not look like a reduce op.
std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
  if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
      AnfAlgo::GetInputTensorNum(cnode) != 1) {
    MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
                      << "] is not single input or single output ";
  }
  std::vector<int> axis;
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto axis_attr = primitive->GetAttr(kAxis);
  if (axis_attr == nullptr) {
    MS_LOG(ERROR) << "This node does't have axie attr.";
    return std::vector<int>();
  }
  auto type = axis_attr->type();
  MS_EXCEPTION_IF_NULL(type);
  std::vector<int> axis_list;
  // A scalar Int32 attribute is wrapped into a single-element list so both
  // forms are handled uniformly below.
  if (type->ToString() == kTypeInt32) {
    axis_list.emplace_back(GetValue<int>(axis_attr));
  } else {
    axis_list = GetValue<std::vector<int>>(axis_attr);
  }
  for (const auto &elem : axis_list) {
    if (elem < 0) {
      // Negative axes count from the back: -1 -> rank - 1.
      axis.emplace_back(input_shape.size() + elem);
    } else {
      axis.emplace_back(elem);
    }
  }
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  return axis;
}
  728. std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  729. MS_EXCEPTION_IF_NULL(anf_node);
  730. std::string processor = kProcessorUnknown;
  731. auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  732. MS_EXCEPTION_IF_NULL(kernel_info);
  733. auto build_info = kernel_info->select_kernel_build_info();
  734. // we may call this before kernel select.
  735. if (build_info == nullptr) {
  736. return processor;
  737. }
  738. switch (build_info->processor()) {
  739. case Processor::AICORE:
  740. processor = kProcessorAiCore;
  741. break;
  742. case Processor::AICPU:
  743. processor = kProcessorAiCpu;
  744. break;
  745. case Processor::CUDA:
  746. processor = kProcessorCuda;
  747. break;
  748. default:
  749. MS_LOG(ERROR) << "Unknown processor type.";
  750. break;
  751. }
  752. return processor;
  753. }
  754. } // namespace kernel
  755. } // namespace mindspore