You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 31 kB

5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "backend/kernel_compiler/common_utils.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <map>
#include <thread>
#include <unordered_map>
#include <utility>
#include "nlohmann/json.hpp"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "ir/manager.h"
#include "ir/meta_tensor.h"
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "utils/ms_context.h"
  32. namespace mindspore {
  33. namespace kernel {
  34. constexpr char kAxis[] = "axis";
  35. constexpr char kTypeInt32[] = "Int32";
  36. const std::unordered_map<std::string, TypeId> type_id_maps = {
  37. {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
  38. {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  39. {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8},
  40. {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32},
  41. {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
  42. {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
  43. {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
  44. {"bool", TypeId::kNumberTypeBool},
  45. };
  46. const std::map<TypeId, std::string> type_id_str_map = {
  47. {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  48. {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"},
  49. {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"},
  50. {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"},
  51. {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
  52. {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
  53. {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
  54. {TypeId::kNumberTypeBool, "bool"},
  55. };
  56. const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  57. {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  58. {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
  59. };
  60. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  61. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  62. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  63. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  64. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  65. };
  66. const std::unordered_map<std::string, FusionType> fusion_type_maps = {
  67. {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
  68. {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
  69. };
  70. void KernelMeta::Initialize(int pid) {
  71. if (pid == -1) {
  72. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
  73. } else {
  74. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/";
  75. }
  76. // remove old kernel cache
  77. RemoveKernelCache();
  78. #if defined(_WIN32) || defined(_WIN64)
  79. auto ret = mkdir(kernel_meta_path_.c_str());
  80. #else
  81. auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
  82. #endif
  83. if (ret != 0) {
  84. MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later";
  85. }
  86. initialized_ = true;
  87. }
// Deletes every file in the kernel cache directory and then the directory
// itself. A missing directory is treated as "nothing to clean".
void KernelMeta::RemoveKernelCache() {
  DIR *dir = opendir(kernel_meta_path_.c_str());
  if (dir == nullptr) {
    return;
  }
  struct dirent *entry;
  // readdir also yields "." and ".."; remove() on those fails and its
  // return value is deliberately ignored.
  while ((entry = readdir(dir)) != nullptr) {
    std::string kernel_file = entry->d_name;
    std::string kernel_file_realpath = kernel_meta_path_ + kernel_file;
    (void)remove(kernel_file_realpath.c_str());
  }
  (void)closedir(dir);
  (void)rmdir(kernel_meta_path_.c_str());
}
  102. std::string KernelMeta::Search(const std::string &kernel_name) const {
  103. if (!initialized_) {
  104. return "";
  105. }
  106. auto iter = kernel_meta_map_.find(kernel_name);
  107. if (iter == kernel_meta_map_.end()) {
  108. return "";
  109. } else {
  110. return iter->second;
  111. }
  112. }
  113. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  114. if (!initialized_) {
  115. return false;
  116. }
  117. kernel_meta_map_[kernel_name] = kernel_json;
  118. return true;
  119. }
  120. bool CheckCache(const std::string &kernel_name) {
  121. // check cache.
  122. KernelMeta *bin_map = KernelMeta::GetInstance();
  123. if (bin_map == nullptr) {
  124. MS_LOG(DEBUG) << "kernel cache is invalid.";
  125. return false;
  126. }
  127. std::string kernel_json = bin_map->Search(kernel_name);
  128. bool ret = (!kernel_json.empty());
  129. if (ret) {
  130. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
  131. } else {
  132. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
  133. }
  134. return ret;
  135. }
  136. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  137. // search cache.
  138. KernelMeta *bin_map = KernelMeta::GetInstance();
  139. if (bin_map == nullptr) {
  140. MS_LOG(DEBUG) << "kernel cache is invalid.";
  141. return nullptr;
  142. }
  143. std::string kernel_json = bin_map->Search(kernel_name);
  144. if (!kernel_json.empty()) {
  145. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  146. // just a tmp solution.
  147. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  148. MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
  149. return nullptr;
  150. } else {
  151. return kernel_pack;
  152. }
  153. } else {
  154. MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "].";
  155. return nullptr;
  156. }
  157. }
  158. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  159. MS_LOG(INFO) << "kernel name:" << kernel_name << ", processr:" << processor;
  160. KernelMeta *bin_map = KernelMeta::GetInstance();
  161. std::string kernel_json;
  162. if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
  163. kernel_json = kCceKernelMeta;
  164. } else {
  165. kernel_json = bin_map->kernel_meta_path();
  166. }
  167. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  168. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  169. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  170. MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
  171. return nullptr;
  172. }
  173. if (bin_map == nullptr) {
  174. MS_LOG(DEBUG) << "kernel cache is invalid.";
  175. return nullptr;
  176. }
  177. if (bin_map->Insert(kernel_name, kernel_json)) {
  178. MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "].";
  179. }
  180. return kernel_pack;
  181. }
  182. TypeId DtypeToTypeId(const std::string &dtypes) {
  183. auto iter = type_id_maps.find(dtypes);
  184. if (iter != type_id_maps.end()) {
  185. return iter->second;
  186. } else {
  187. MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  188. }
  189. }
  190. std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
  191. auto iter = type_id_str_map.find(type_id);
  192. if (iter == type_id_str_map.end()) {
  193. if (!unknown_as_default) {
  194. MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
  195. }
  196. return "float32";
  197. }
  198. return iter->second;
  199. }
  200. std::string Dtype2ShortType(const std::string &dtypes) {
  201. auto iter = dtype_shortdtype_map_.find(dtypes);
  202. if (iter != dtype_shortdtype_map_.end()) {
  203. return iter->second;
  204. } else {
  205. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  206. }
  207. }
  208. size_t GetDtypeNbyte(const std::string &dtypes) {
  209. auto iter = dtype_nbyte_map.find(dtypes);
  210. if (iter != dtype_nbyte_map.end()) {
  211. return iter->second;
  212. } else {
  213. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  214. }
  215. }
// Fills `builder` with the device types and formats of every input for the
// kernel-info variant selected by `builder_idex`.
// - real_input_num: inputs actually present on the node; optional inputs
//   beyond this count are skipped.
// - dyn_input_sizes: element count per dynamic input (taken from the node's
//   attribute); must be non-empty whenever an input is "dynamic".
// Returns false when the registration data is inconsistent.
bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
                               size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
                               const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  MS_EXCEPTION_IF_NULL(builder);
  std::vector<TypeId> inputs_device_type;
  std::vector<std::string> inputs_format;
  size_t dyn_input_idx = 0;
  size_t kernel_info_index = 0;
  MS_EXCEPTION_IF_NULL(inputs[0]);
  // Every input descriptor carries one dtype/format per kernel-info variant;
  // the first input defines the expected variant count.
  size_t kernel_info_cnt = inputs[0]->dtypes().size();
  for (const auto &input : inputs) {
    MS_EXCEPTION_IF_NULL(input);
    std::string param_type = input->param_type();
    std::vector<std::string> dtypes = input->dtypes();
    std::vector<std::string> formats = input->formats();
    if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
      MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size.";
      return false;
    }
    if (param_type == "dynamic") {
      if (dyn_input_sizes.empty()) {
        MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
        return false;
      }
      // A dynamic input expands into dyn_input_sizes[dyn_input_idx] entries,
      // all sharing the registered dtype/format of this descriptor.
      for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
        kernel_info_index++;
        auto type_id = DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
      }
      dyn_input_idx++;
    } else if (param_type == "required") {
      kernel_info_index++;
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      inputs_device_type.push_back(type_id);
      inputs_format.push_back(formats[builder_idex]);
    } else {
      // Optional input: consumed only while the node still has real inputs.
      if (kernel_info_index < real_input_num) {
        MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
        kernel_info_index++;
        auto type_id = DtypeToTypeId(dtypes[builder_idex]);
        inputs_device_type.push_back(type_id);
        inputs_format.push_back(formats[builder_idex]);
      }
    }
  }
  builder->SetInputsDeviceType(inputs_device_type);
  builder->SetInputsFormat(inputs_format);
  return true;
}
// Fills `builder` with the device types and formats of every output for the
// kernel-info variant selected by `builder_idex`.
// A "dynamic" output must be the only registered output and expands into
// `real_output_num` entries; "required" outputs add one entry each; optional
// outputs add one entry only while real outputs remain.
// Returns false when the registration data is inconsistent.
bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
                                const size_t &real_output_num,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  // not now but in the next we need to support dynamic output case
  MS_EXCEPTION_IF_NULL(builder);
  size_t output_idx = 0;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::string> outputs_format;
  MS_EXCEPTION_IF_NULL(outputs[0]);
  // Every output descriptor carries one dtype/format per kernel-info variant.
  size_t kernel_info_cnt = outputs[0]->dtypes().size();
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output_idx >= real_output_num) {
      // Extra registered outputs beyond the node's real ones are skipped.
      MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
      continue;
    }
    size_t output_num = 0;
    if (output->param_type() == "dynamic") {
      if (outputs.size() > 1) {
        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
      }
      output_num = real_output_num;
    } else if (output->param_type() == "required") {
      output_num = 1;
    } else {
      if (output_idx < real_output_num) {
        MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
        output_num = 1;
      }
    }
    for (size_t i = 0; i < output_num; i++) {
      std::vector<std::string> dtypes = output->dtypes();
      std::vector<std::string> formats = output->formats();
      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
        MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
        return false;
      }
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      outputs_device_type.push_back(type_id);
      outputs_format.push_back(formats[builder_idex]);
      output_idx++;
    }
  }
  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_device_type);
  return true;
}
  313. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  314. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  315. MS_EXCEPTION_IF_NULL(builder);
  316. MS_EXCEPTION_IF_NULL(op_info_ptr);
  317. auto imply_type = op_info_ptr->imply_type();
  318. builder->SetProcessor(processor);
  319. std::string fusion_type = op_info_ptr->fusion_type();
  320. auto iter = fusion_type_maps.find(fusion_type);
  321. if (iter != fusion_type_maps.end()) {
  322. builder->SetFusionType(iter->second);
  323. } else {
  324. if (imply_type == kAKG) {
  325. MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type;
  326. }
  327. }
  328. if (imply_type == kAKG) {
  329. builder->SetKernelType(AKG_KERNEL);
  330. } else if (imply_type == kAICPU) {
  331. builder->SetKernelType(AICPU_KERNEL);
  332. } else {
  333. builder->SetKernelType(TBE_KERNEL);
  334. }
  335. }
// Builds the candidate KernelBuildInfo list for `kernel_node` from the
// registered op information. Each "kernel info variant" j (one registered
// dtype/format combination per io) yields one KernelBuildInfo.
// Returns false when the registration data is inconsistent with the node.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int64_t> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  // Element counts of dynamic inputs, when the node carries the attribute.
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    // Variant count is defined by the first input's dtype list.
    MS_EXCEPTION_IF_NULL(inputs[0]);
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed.";
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // No inputs registered: variants are driven by the first output instead.
    MS_EXCEPTION_IF_NULL(outputs[0]);
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // No registered ios at all: only AICPU still gets a bare build info.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
// Writes `info` into "<base_path><json_name>.info" and then chmods the file
// to owner-read-only (S_IRUSR). All failures are logged at DEBUG level and
// otherwise swallowed: saving json info is best-effort.
void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  char real_path[PATH_MAX] = {0};
  std::string path = base_path + json_name + kInfoSuffix;
  if (path.size() > PATH_MAX) {
    MS_LOG(DEBUG) << "file path " << path << " is too long.";
    return;
  }
  std::ofstream filewrite;
  filewrite.open(path);
  if (!filewrite.is_open()) {
    return;
  }
  filewrite << info << std::endl;
  filewrite.close();
  // Resolve to an absolute path before chmod (platform-specific API).
#if defined(_WIN32) || defined(_WIN64)
  if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) {
    MS_LOG(DEBUG) << "dir " << path << " does not exit.";
    return;
  }
#else
  if (nullptr == realpath(path.c_str(), real_path)) {
    MS_LOG(DEBUG) << "dir " << path << " does not exit.";
    return;
  }
#endif
  MS_LOG(INFO) << "real path is :" << real_path;
  if (chmod(real_path, S_IRUSR) == -1) {
    MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail.";
  }
}
  422. Processor GetProcessor(const string &processor) {
  423. if (processor == kProcessorAiCore) return Processor::AICORE;
  424. if (processor == kProcessorAiCpu) return Processor::AICPU;
  425. if (processor == kProcessorCuda) return Processor::CUDA;
  426. MS_LOG(DEBUG) << "Unknown processor type.";
  427. return Processor::UNKNOWN;
  428. }
  429. std::string GetProcessor(const AnfNodePtr &anf_node) {
  430. MS_EXCEPTION_IF_NULL(anf_node);
  431. std::string device;
  432. switch (AnfAlgo::GetProcessor(anf_node)) {
  433. case Processor::AICORE:
  434. device = kProcessorAiCore;
  435. break;
  436. case Processor::AICPU:
  437. device = kProcessorAiCpu;
  438. break;
  439. case Processor::CUDA:
  440. device = kProcessorCuda;
  441. break;
  442. default:
  443. MS_LOG(DEBUG) << "Unknown processor type.";
  444. break;
  445. }
  446. return device;
  447. }
  448. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  449. if (shape_a.size() != shape_b.size()) {
  450. return false;
  451. }
  452. for (size_t i = 0; i < shape_a.size(); ++i) {
  453. if (shape_a[i] != shape_b[i]) {
  454. return false;
  455. }
  456. }
  457. return true;
  458. }
  459. int Sign(float x) {
  460. if (x > 0) {
  461. return 1;
  462. }
  463. if (x < 0) {
  464. return -1;
  465. }
  466. return 0;
  467. }
  468. std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  469. MS_EXCEPTION_IF_NULL(anf_node);
  470. if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
  471. MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
  472. }
  473. auto cnode = anf_node->cast<CNodePtr>();
  474. if (cnode == nullptr) {
  475. return AnfAlgo::VisitKernel(anf_node, 0);
  476. } else {
  477. return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
  478. }
  479. }
// For every graph input, finds the kernel node that consumes it and at
// which slot, returning (node, (input index, offset)). For a non-dynamic
// consumer the pair is (flat input slot, 0); for a consumer carrying the
// kAttrDynInputSizes attribute the flat slot is mapped back to
// (dynamic-input index, offset inside that dynamic input).
// Raises ArgumentError when an input has no consumer among `node_list`.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    bool found = false;
    // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }
    // Scan the users of this input until one of them is a node in node_list.
    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        if (anf_node != input_user.first) {
          continue;
        }
        std::vector<int64_t> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int64_t>>(prim->GetAttr(kAttrDynInputSizes));
        }
        if (dyn_input_sizes.empty()) {
          // input_user.second counts CNode input 0 (the primitive): -1 gives
          // the tensor-input slot.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        } else {
          // Walk the dynamic-input sizes to locate which dynamic input the
          // flat slot belongs to, and the offset inside it.
          int used_as_idx = input_user.second - 1;
          int accum_idx = 0;
          size_t dyn_i = 0;
          for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
            accum_idx += dyn_input_sizes[dyn_i];
            if (used_as_idx < accum_idx) {
              input_index.push_back(std::make_pair(
                anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
              break;
            }
          }
          // The inner loop breaking early means a match was recorded.
          if (dyn_i != dyn_input_sizes.size()) {
            found = true;
            break;
          }
        }
      }
      if (found) {
        break;
      }
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}
// Maps every graph output to (producing node, output index). The producer
// is resolved through VisitKernel and must be either a kernel from
// `node_list` or a graph input from `input_list` (recorded with index 0).
// Raises ArgumentError otherwise.
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
                                                          const std::vector<AnfNodePtr> &input_list,
                                                          const std::vector<AnfNodePtr> &output_list) {
  std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  for (size_t i = 0; i < output_list.size(); ++i) {
    auto const &output = output_list[i];
    MS_EXCEPTION_IF_NULL(output);
    bool found = false;
    auto pree_node = AnfAlgo::VisitKernel(output, 0);
    auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
    if (pos != std::end(node_list)) {
      output_index.push_back(pree_node);
      continue;
    }
    // An output fed straight from a graph input keeps output index 0.
    auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
    if (ret != std::end(input_list)) {
      output_index.push_back(std::make_pair(pree_node.first, 0));
      found = true;
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
                                  << output->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return output_index;
}
  566. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  567. MS_EXCEPTION_IF_NULL(node_list);
  568. MS_EXCEPTION_IF_NULL(func_graph);
  569. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  570. for (auto const &node : node_lists) {
  571. if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
  572. continue;
  573. }
  574. auto cnode = node->cast<CNodePtr>();
  575. MS_EXCEPTION_IF_NULL(cnode);
  576. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  577. node_list->push_back(node);
  578. }
  579. }
  580. }
  581. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  582. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  583. MS_EXCEPTION_IF_NULL(func_graph);
  584. MS_EXCEPTION_IF_NULL(node_list);
  585. MS_EXCEPTION_IF_NULL(input_list);
  586. GetValidKernelNodes(func_graph, node_list);
  587. auto parameters = func_graph->parameters();
  588. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  589. GetFuncGraphOutputNodes(func_graph, output_list);
  590. }
  591. void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  592. MS_EXCEPTION_IF_NULL(func_graph);
  593. MS_EXCEPTION_IF_NULL(output_list);
  594. auto func_output = func_graph->output();
  595. MS_EXCEPTION_IF_NULL(func_output);
  596. if (func_output->isa<CNode>()) {
  597. // multi output.
  598. auto cnode = func_output->cast<CNodePtr>();
  599. MS_EXCEPTION_IF_NULL(cnode);
  600. auto input0 = cnode->input(kAnfPrimitiveIndex);
  601. MS_EXCEPTION_IF_NULL(input0);
  602. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  603. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  604. auto input_node = cnode->input(input_idx);
  605. MS_EXCEPTION_IF_NULL(input_node);
  606. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  607. }
  608. } else {
  609. // single output.
  610. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  611. }
  612. } else {
  613. // single output.
  614. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  615. }
  616. }
// Extracts the scalar value of constant-tensor input `input_idx` of
// `anf_node` into (*node_json)["value"]. Only single-element float32 /
// float16 / int32 tensors are supported; returns false for anything else.
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // +1 skips the primitive stored at CNode input 0.
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  }
  auto input_node = cnode->input(input_idx + 1);
  if (!IsValueNode<tensor::Tensor>(input_node)) {
    return false;
  }
  auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  if (tensor == nullptr) {
    return false;
  }
  auto type_id = tensor->data_type();
  auto *data = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  if (tensor->DataSize() > 1) {
    // not const tensor.
    MS_LOG(WARNING) << "Not take value of tensor whose datasize greater than 1, [" << input_node->DebugString(2) << "]";
    return false;
  }
  if (type_id == kFloat32->type_id()) {
    float *val = static_cast<float *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = val[0];
    MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
    return true;
  } else if (type_id == kFloat16->type_id()) {
    // float16 is widened to float for json serialization.
    float16 *val = static_cast<float16 *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = static_cast<float>(val[0]);
    MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
    return true;
  } else if (type_id == kInt32->type_id()) {
    int *val = static_cast<int *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = val[0];
    MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
    return true;
  }
  MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  return false;
}
  664. bool IsWeightBoundary(const AnfNodePtr &node) {
  665. if (node->isa<ValueNode>()) {
  666. return true;
  667. }
  668. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  669. return true;
  670. }
  671. return false;
  672. }
  673. std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
  674. if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
  675. AnfAlgo::GetInputTensorNum(cnode) != 1) {
  676. MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
  677. << "] is not single input or single output ";
  678. }
  679. std::vector<int64_t> axis;
  680. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  681. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  682. MS_EXCEPTION_IF_NULL(primitive);
  683. auto axis_attr = primitive->GetAttr(kAxis);
  684. if (axis_attr == nullptr) {
  685. MS_LOG(ERROR) << "This node doesn't have axie attr.";
  686. return std::vector<int64_t>();
  687. }
  688. std::vector<int64_t> axis_list;
  689. if (axis_attr->isa<Int64Imm>()) {
  690. (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
  691. } else {
  692. axis_list = GetValue<std::vector<int64_t>>(axis_attr);
  693. }
  694. for (const auto &elem : axis_list) {
  695. if (elem < 0) {
  696. (void)axis.emplace_back(input_shape.size() + elem);
  697. } else {
  698. (void)axis.emplace_back(elem);
  699. }
  700. }
  701. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  702. return axis;
  703. }
  704. std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  705. MS_EXCEPTION_IF_NULL(anf_node);
  706. std::string processor = kProcessorUnknown;
  707. auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  708. MS_EXCEPTION_IF_NULL(kernel_info);
  709. auto build_info = kernel_info->select_kernel_build_info();
  710. // we may call this before kernel select.
  711. if (build_info == nullptr) {
  712. return processor;
  713. }
  714. switch (build_info->processor()) {
  715. case Processor::AICORE:
  716. processor = kProcessorAiCore;
  717. break;
  718. case Processor::AICPU:
  719. processor = kProcessorAiCpu;
  720. break;
  721. case Processor::CUDA:
  722. processor = kProcessorCuda;
  723. break;
  724. default:
  725. MS_LOG(ERROR) << "Unknown processor type.";
  726. break;
  727. }
  728. return processor;
  729. }
  730. Processor GetProcessorFromContext() {
  731. kernel::Processor processor = kernel::Processor::UNKNOWN;
  732. auto context_ptr = MsContext::GetInstance();
  733. MS_EXCEPTION_IF_NULL(context_ptr);
  734. auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  735. if (device_info == kGPUDevice) {
  736. processor = kernel::Processor::CUDA;
  737. } else if (device_info == kAscendDevice) {
  738. processor = kernel::Processor::AICORE;
  739. }
  740. return processor;
  741. }
  742. std::string GetStrProcessorFromContext() {
  743. auto processor = GetProcessorFromContext();
  744. string str_processor = kernel::kProcessorUnknown;
  745. if (processor == kernel::Processor::CUDA) {
  746. str_processor = kernel::kProcessorCuda;
  747. } else if (processor == kernel::Processor::AICORE) {
  748. str_processor = kernel::kProcessorAiCore;
  749. }
  750. return str_processor;
  751. }
  752. float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  753. return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
  754. : in_size / static_cast<float>(out_size);
  755. }
  756. float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }
  757. void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
  758. CachedInterpolation *interpolation) {
  759. interpolation[out_size].lower = 0;
  760. interpolation[out_size].upper = 0;
  761. for (size_t i = 0; i <= out_size - 1; ++i) {
  762. const float in = ScaleGrid(i, scale);
  763. const float in_f = std::floor(in);
  764. interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
  765. interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
  766. interpolation[i].lerp = in - in_f;
  767. }
  768. }
  769. } // namespace kernel
  770. } // namespace mindspore