You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

common_utils.cc 31 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/kernel_compiler/common_utils.h"
  17. #include <unordered_map>
  18. #include <map>
  19. #include <iostream>
  20. #include <utility>
  21. #include <fstream>
  22. #include <algorithm>
  23. #include <thread>
  24. #include "nlohmann/json.hpp"
  25. #include "backend/session/anf_runtime_algorithm.h"
  26. #include "utils/file_utils.h"
  27. #include "utils/ms_utils.h"
  28. #include "ir/manager.h"
  29. #include "ir/meta_tensor.h"
  30. #include "base/core_ops.h"
  31. #include "ir/graph_utils.h"
  32. #include "utils/ms_context.h"
  33. #include "utils/trace_base.h"
  34. #include "mindspore/ccsrc/debug/common.h"
  35. namespace mindspore {
  36. namespace kernel {
  37. constexpr char kAxis[] = "axis";
  38. constexpr char kTypeInt32[] = "Int32";
  39. const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  40. {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  41. {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
  42. };
  43. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  44. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  45. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  46. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  47. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  48. {"complex64", sizeof(float) * 2}};
// Define all patterns here for different schedule
// Lookup table between FusionType enum values and the pattern-name strings
// used in op-info registration. Consumed by GetFusionNameByType (forward) and
// GetFusionTypeByName (reverse, case-insensitive). UNKNOWN_FUSION_TYPE maps
// to the empty string.
const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
  {FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
  {FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
  {FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
  {FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
  {FusionType::ELEMWISE, "ElemWise"},
  {FusionType::PURE_BROADCAST, "PureBroadcast"},
  {FusionType::COMMREDUCE, "CommReduce"},
  {FusionType::SEGMENT, "Segment"},
  {FusionType::INPLACE, "Inplace"},
  {FusionType::MATMUL, "Matmul"},
  {FusionType::MATMUL_V2, "Matmul_v2"},
  {FusionType::GEMM, "GEMM"},
  {FusionType::CONV, "Convolution"},
  {FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
  {FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
  {FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
  {FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
  {FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
  {FusionType::OPAQUE, "Opaque"},
  {FusionType::BN_REDUCE, "bn_reduce"},
  {FusionType::BN_UPDATE, "bn_update"},
  {FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
  {FusionType::L2_NORMALIZE, "l2_normalize"},
  {FusionType::SOFTMAX, "softmax_pattern"},
  {FusionType::L2_LOSS, "l2_loss"},
  {FusionType::ASCEND_QUANT, "quant"},
  {FusionType::ASCEND_DEQUANT, "dequant"},
  {FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
  {FusionType::STRIDED_READ, "strided_read"},
  {FusionType::STRIDED_WRITE, "strided_write"},
  {FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
  {FusionType::ASCEND_REQUANT, "requant"},
  {FusionType::ASCEND_REQUANT_S16, "requant_s16"},
  {FusionType::MAX_POOL, "MaxPool"},
  {FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
  {FusionType::CONV3D, "Conv3d"},
  {FusionType::POOL2D, "Pool2d"},
  {FusionType::POOL3D, "Pool3d"},
  {FusionType::READ_SELECT, "read_select"},
  {FusionType::WRITE_SELECT, "write_select"},
  {FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
  {FusionType::DILATION_PATTERN, "dilation"},
  {FusionType::BROAD_CAST, "Broadcast"},
  {FusionType::BATCH_MATMUL, "BatchMatmul"},
  {FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
  {FusionType::UNKNOWN_FUSION_TYPE, ""}};
  97. std::string GetFusionNameByType(const kernel::FusionType &type) {
  98. auto iter = fusion_type_name_maps.find(type);
  99. if (iter == fusion_type_name_maps.end()) {
  100. MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
  101. }
  102. return iter->second;
  103. }
  104. FusionType GetFusionTypeByName(const std::string &name) {
  105. std::string fusion_name_upper = name;
  106. transform(fusion_name_upper.begin(), fusion_name_upper.end(), fusion_name_upper.begin(), ::toupper);
  107. auto iter =
  108. std::find_if(fusion_type_name_maps.begin(), fusion_type_name_maps.end(), [&fusion_name_upper](const auto &it) {
  109. std::string name_upper = it.second;
  110. transform(name_upper.begin(), name_upper.end(), name_upper.begin(), ::toupper);
  111. return fusion_name_upper == name_upper;
  112. });
  113. if (iter == fusion_type_name_maps.end()) {
  114. MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
  115. }
  116. return iter->first;
  117. }
  118. std::string GetCompilerCachePath() {
  119. static std::string config_path = "";
  120. if (config_path != "") {
  121. return config_path;
  122. }
  123. const char *value = ::getenv(kCOMPILER_CACHE_PATH);
  124. if (value == nullptr) {
  125. config_path = "./";
  126. } else {
  127. config_path = std::string(value);
  128. FileUtils::CreateNotExistDirs(config_path);
  129. if (config_path[config_path.length() - 1] != '/') {
  130. config_path += "/";
  131. }
  132. }
  133. return config_path;
  134. }
  135. void KernelMeta::Initialize() {
  136. auto config_path = GetCompilerCachePath();
  137. kernel_meta_path_ = config_path + std::string(kAkgKernelMeta);
  138. FileUtils::CreateNotExistDirs(kernel_meta_path_);
  139. initialized_ = true;
  140. }
  141. std::string KernelMeta::Search(const std::string &kernel_name) const {
  142. if (!initialized_) {
  143. return "";
  144. }
  145. auto iter = kernel_meta_map_.find(kernel_name);
  146. if (iter == kernel_meta_map_.end()) {
  147. return "";
  148. } else {
  149. return iter->second;
  150. }
  151. }
  152. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  153. if (!initialized_) {
  154. return false;
  155. }
  156. kernel_meta_map_[kernel_name] = kernel_json;
  157. return true;
  158. }
  159. bool CheckCache(const std::string &kernel_name) {
  160. // check cache.
  161. KernelMeta *bin_map = KernelMeta::GetInstance();
  162. if (bin_map == nullptr) {
  163. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
  164. return false;
  165. }
  166. std::string kernel_json = bin_map->Search(kernel_name);
  167. bool ret = (!kernel_json.empty());
  168. if (ret) {
  169. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
  170. } else {
  171. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
  172. }
  173. return ret;
  174. }
  175. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  176. // search cache.
  177. KernelMeta *bin_map = KernelMeta::GetInstance();
  178. if (bin_map == nullptr) {
  179. MS_LOG(DEBUG) << "kernel cache is invalid, kernel_name: " << kernel_name;
  180. return nullptr;
  181. }
  182. std::string kernel_json = bin_map->Search(kernel_name);
  183. if (!kernel_json.empty()) {
  184. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  185. // just a tmp solution.
  186. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  187. MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
  188. return nullptr;
  189. } else {
  190. return kernel_pack;
  191. }
  192. } else {
  193. MS_LOG(INFO) << "The cache kernel not found[" << kernel_name << "].";
  194. return nullptr;
  195. }
  196. }
  197. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  198. MS_LOG(INFO) << "Insert cache for kernel:" << kernel_name << ", processr:" << processor;
  199. KernelMeta *bin_map = KernelMeta::GetInstance();
  200. std::string kernel_json;
  201. kernel_json = bin_map->kernel_meta_path();
  202. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  203. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  204. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  205. MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
  206. return nullptr;
  207. }
  208. if (bin_map == nullptr) {
  209. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name :" << kernel_name;
  210. return nullptr;
  211. }
  212. if (bin_map->Insert(kernel_name, kernel_json)) {
  213. MS_LOG(INFO) << "Kernel insert cache success[" << kernel_json << "], kernel name[" << kernel_name << "].";
  214. }
  215. return kernel_pack;
  216. }
  217. TypeId DtypeToTypeId(const std::string &dtypes) {
  218. if (dtypes == "float") {
  219. return TypeId::kNumberTypeFloat32;
  220. }
  221. if (dtypes.empty()) {
  222. return TypeId::kMetaTypeNone;
  223. }
  224. return StringToTypeId(dtypes);
  225. }
  226. std::string Dtype2ShortType(const std::string &dtype) {
  227. auto iter = dtype_shortdtype_map_.find(dtype);
  228. if (iter != dtype_shortdtype_map_.end()) {
  229. return iter->second;
  230. } else {
  231. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  232. }
  233. }
  234. size_t GetDtypeNbyte(const std::string &dtype) {
  235. auto iter = dtype_nbyte_map.find(dtype);
  236. if (iter != dtype_nbyte_map.end()) {
  237. return iter->second;
  238. } else {
  239. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  240. }
  241. }
  242. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  243. size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
  244. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  245. MS_EXCEPTION_IF_NULL(builder);
  246. std::vector<TypeId> inputs_device_type;
  247. std::vector<std::string> inputs_format;
  248. size_t dyn_input_idx = 0;
  249. size_t kernel_info_index = 0;
  250. MS_EXCEPTION_IF_NULL(inputs[0]);
  251. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  252. for (const auto &input : inputs) {
  253. MS_EXCEPTION_IF_NULL(input);
  254. std::string param_type = input->param_type();
  255. std::vector<std::string> dtypes = input->dtypes();
  256. std::vector<std::string> formats = input->formats();
  257. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  258. MS_LOG(DEBUG) << "Set input kernel builder info failed, dtyps size != formats size. dtypes size: "
  259. << dtypes.size() << ", formats size : " << formats.size();
  260. return false;
  261. }
  262. if (param_type == "dynamic") {
  263. if (dyn_input_sizes.empty()) {
  264. MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes's size is 0 when param_type is dynamic";
  265. return false;
  266. }
  267. for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  268. kernel_info_index++;
  269. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  270. inputs_device_type.push_back(type_id);
  271. inputs_format.push_back(formats[builder_idex]);
  272. }
  273. dyn_input_idx++;
  274. } else if (param_type == "required") {
  275. kernel_info_index++;
  276. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  277. inputs_device_type.push_back(type_id);
  278. inputs_format.push_back(formats[builder_idex]);
  279. } else {
  280. if (kernel_info_index < real_input_num) {
  281. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  282. kernel_info_index++;
  283. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  284. inputs_device_type.push_back(type_id);
  285. inputs_format.push_back(formats[builder_idex]);
  286. }
  287. }
  288. }
  289. builder->SetInputsDeviceType(inputs_device_type);
  290. builder->SetInputsFormat(inputs_format);
  291. return true;
  292. }
// Fills the builder's output device types and formats from column
// `builder_idex` of every output's (dtypes, formats) table.
// Returns false when a table's column count is inconsistent.
bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
                                const size_t &real_output_num,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  // not now but in the next we need to support dynamic output case
  MS_EXCEPTION_IF_NULL(builder);
  size_t output_idx = 0;
  std::vector<TypeId> outputs_device_type;
  std::vector<std::string> outputs_format;
  // NOTE(review): outputs[0] is dereferenced without an empty() check —
  // callers (ParseMetadata) only call this with outputs.size() > 0.
  MS_EXCEPTION_IF_NULL(outputs[0]);
  size_t kernel_info_cnt = outputs[0]->dtypes().size();
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output_idx >= real_output_num) {
      // All real outputs already consumed; further registered outputs are skipped.
      MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
      continue;
    }
    // How many real outputs this registered output stands for.
    size_t output_num = 0;
    if (output->param_type() == "dynamic") {
      // A dynamic output covers all remaining real outputs and must be the
      // only registered output.
      if (outputs.size() > 1) {
        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
      }
      output_num = real_output_num;
    } else if (output->param_type() == "required") {
      output_num = 1;
    } else {
      // Optional output: emitted only while real outputs remain.
      if (output_idx < real_output_num) {
        MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
        output_num = 1;
      }
    }
    for (size_t i = 0; i < output_num; i++) {
      std::vector<std::string> dtypes = output->dtypes();
      std::vector<std::string> formats = output->formats();
      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
        MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
        return false;
      }
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      outputs_device_type.push_back(type_id);
      outputs_format.push_back(formats[builder_idex]);
      output_idx++;
    }
  }
  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_device_type);
  return true;
}
  340. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  341. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  342. MS_EXCEPTION_IF_NULL(builder);
  343. MS_EXCEPTION_IF_NULL(op_info_ptr);
  344. auto imply_type = op_info_ptr->imply_type();
  345. builder->SetProcessor(processor);
  346. std::string fusion_name = op_info_ptr->fusion_type();
  347. auto fusion_type = GetFusionTypeByName(fusion_name);
  348. builder->SetFusionType(fusion_type);
  349. if (imply_type == kAKG) {
  350. builder->SetKernelType(AKG_KERNEL);
  351. } else if (imply_type == kGPU) {
  352. builder->SetKernelType(GPU_KERNEL);
  353. } else if (imply_type == kCPU) {
  354. builder->SetKernelType(CPU_KERNEL);
  355. } else if (imply_type == kAICPU) {
  356. builder->SetKernelType(AICPU_KERNEL);
  357. } else {
  358. builder->SetKernelType(TBE_KERNEL);
  359. }
  360. }
// Expands an op-info registration entry into one KernelBuildInfo per
// (dtype, format) column and appends them to kernel_info_list.
// Returns false when the registration tables are inconsistent.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int64_t> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    // Number of real inputs hidden behind each dynamic input parameter.
    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    // One KernelBuildInfo per column of the first input's dtype table.
    if (inputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
    }
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // No inputs registered: drive the expansion from the outputs instead.
    if (outputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
    }
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // Neither inputs nor outputs registered: only AICPU ops get a build info.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
  422. void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  423. std::string path = base_path + json_name + kInfoSuffix;
  424. auto realpath = Common::CreatePrefixPath(path);
  425. if (!realpath.has_value()) {
  426. MS_LOG(ERROR) << "Get real path failed, path=" << path;
  427. return;
  428. }
  429. ChangeFileMode(realpath.value(), S_IWUSR);
  430. std::ofstream filewrite(realpath.value());
  431. if (!filewrite.is_open()) {
  432. MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
  433. return;
  434. }
  435. filewrite << info << std::endl;
  436. filewrite.close();
  437. ChangeFileMode(realpath.value(), S_IRUSR);
  438. }
  439. Processor GetProcessor(const string &processor) {
  440. if (processor == kProcessorAiCore) return Processor::AICORE;
  441. if (processor == kProcessorAiCpu) return Processor::AICPU;
  442. if (processor == kProcessorCuda) return Processor::CUDA;
  443. MS_LOG(DEBUG) << "Unknown processor type.";
  444. return Processor::UNKNOWN;
  445. }
  446. std::string GetProcessor(const AnfNodePtr &anf_node) {
  447. MS_EXCEPTION_IF_NULL(anf_node);
  448. std::string device;
  449. switch (AnfAlgo::GetProcessor(anf_node)) {
  450. case Processor::AICORE:
  451. device = kProcessorAiCore;
  452. break;
  453. case Processor::AICPU:
  454. device = kProcessorAiCpu;
  455. break;
  456. case Processor::CUDA:
  457. device = kProcessorCuda;
  458. break;
  459. default:
  460. MS_LOG(DEBUG) << "Unknown processor type.";
  461. break;
  462. }
  463. return device;
  464. }
  465. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  466. if (shape_a.size() != shape_b.size()) {
  467. return false;
  468. }
  469. for (size_t i = 0; i < shape_a.size(); ++i) {
  470. if (shape_a[i] != shape_b[i]) {
  471. return false;
  472. }
  473. }
  474. return true;
  475. }
  476. int Sign(float x) {
  477. if (x > 0) {
  478. return 1;
  479. }
  480. if (x < 0) {
  481. return -1;
  482. }
  483. return 0;
  484. }
  485. int GetReductionInt(const std::string &reduction) {
  486. if (reduction == "none") {
  487. return 0;
  488. } else if (reduction == "sum") {
  489. return 2;
  490. } else {
  491. // reduction = 'mean'
  492. return 1;
  493. }
  494. }
  495. std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
  496. const std::vector<AnfNodePtr> &input_list,
  497. const std::vector<AnfNodePtr> &output_list) {
  498. std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  499. for (size_t i = 0; i < output_list.size(); ++i) {
  500. auto const &output = output_list[i];
  501. MS_EXCEPTION_IF_NULL(output);
  502. bool found = false;
  503. auto pree_node = AnfAlgo::VisitKernel(output, 0);
  504. auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
  505. if (pos != std::end(node_list)) {
  506. output_index.push_back(pree_node);
  507. continue;
  508. }
  509. auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
  510. if (ret != std::end(input_list)) {
  511. output_index.push_back(std::make_pair(pree_node.first, 0));
  512. found = true;
  513. }
  514. if (!found) {
  515. MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
  516. << output->func_graph()->ToString() << "] found no related kernel info.";
  517. }
  518. }
  519. return output_index;
  520. }
  521. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  522. MS_EXCEPTION_IF_NULL(node_list);
  523. MS_EXCEPTION_IF_NULL(func_graph);
  524. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  525. for (auto const &node : node_lists) {
  526. if (!AnfUtils::IsRealKernel(node) || !node->isa<CNode>()) {
  527. continue;
  528. }
  529. auto cnode = node->cast<CNodePtr>();
  530. MS_EXCEPTION_IF_NULL(cnode);
  531. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  532. node_list->push_back(node);
  533. }
  534. }
  535. }
  536. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  537. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  538. MS_EXCEPTION_IF_NULL(func_graph);
  539. MS_EXCEPTION_IF_NULL(node_list);
  540. MS_EXCEPTION_IF_NULL(input_list);
  541. GetValidKernelNodes(func_graph, node_list);
  542. auto parameters = func_graph->parameters();
  543. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  544. GetFuncGraphOutputNodes(func_graph, output_list);
  545. }
  546. void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  547. MS_EXCEPTION_IF_NULL(func_graph);
  548. MS_EXCEPTION_IF_NULL(output_list);
  549. auto func_output = func_graph->output();
  550. MS_EXCEPTION_IF_NULL(func_output);
  551. if (func_output->isa<CNode>()) {
  552. // multi output.
  553. auto cnode = func_output->cast<CNodePtr>();
  554. MS_EXCEPTION_IF_NULL(cnode);
  555. auto input0 = cnode->input(kAnfPrimitiveIndex);
  556. MS_EXCEPTION_IF_NULL(input0);
  557. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  558. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  559. auto input_node = cnode->input(input_idx);
  560. MS_EXCEPTION_IF_NULL(input_node);
  561. if (input_node->isa<CNode>() && AnfAlgo::GetInputTensorNum(input_node) == 0) {
  562. continue;
  563. }
  564. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  565. }
  566. } else {
  567. // single output.
  568. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  569. }
  570. } else {
  571. // single output.
  572. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  573. }
  574. }
  575. bool IsWeightBoundary(const AnfNodePtr &node) {
  576. if (node->isa<ValueNode>()) {
  577. return true;
  578. }
  579. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  580. return true;
  581. }
  582. return false;
  583. }
// Reads the "axis" attribute of a single-input/single-output reduce node,
// normalizes negative axes against the input rank, writes the normalized list
// back onto the node (kAttrAxis) and returns it.
// Returns an empty vector (after an ERROR log) when the attr is missing;
// raises when the node is not single-input/single-output.
std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
  if (AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputTensorNum(cnode) != 1) {
    MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output."
                      << trace::DumpSourceLines(cnode);
  }
  std::vector<int64_t> axis;
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  auto axis_attr = primitive->GetAttr(kAxis);
  if (axis_attr == nullptr) {
    MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
    return std::vector<int64_t>();
  }
  // The attr may be a single scalar axis or a list of axes.
  std::vector<int64_t> axis_list;
  if (axis_attr->isa<Int64Imm>()) {
    (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
  } else {
    axis_list = GetValue<std::vector<int64_t>>(axis_attr);
  }
  for (const auto &elem : axis_list) {
    if (elem < 0) {
      // Negative axis counts back from the last dimension.
      (void)axis.emplace_back(input_shape.size() + elem);
    } else {
      (void)axis.emplace_back(elem);
    }
  }
  // Persist the normalized axes on the node for downstream passes.
  AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  return axis;
}
  614. std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  615. MS_EXCEPTION_IF_NULL(anf_node);
  616. std::string processor = kProcessorUnknown;
  617. auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  618. MS_EXCEPTION_IF_NULL(kernel_info);
  619. auto build_info = kernel_info->select_kernel_build_info();
  620. // we may call this before kernel select.
  621. if (build_info == nullptr) {
  622. return processor;
  623. }
  624. switch (build_info->processor()) {
  625. case Processor::AICORE:
  626. processor = kProcessorAiCore;
  627. break;
  628. case Processor::AICPU:
  629. processor = kProcessorAiCpu;
  630. break;
  631. case Processor::CUDA:
  632. processor = kProcessorCuda;
  633. break;
  634. default:
  635. MS_LOG(ERROR) << "Unknown processor type.";
  636. break;
  637. }
  638. return processor;
  639. }
  640. Processor GetProcessorFromContext() {
  641. kernel::Processor processor = kernel::Processor::UNKNOWN;
  642. auto context_ptr = MsContext::GetInstance();
  643. MS_EXCEPTION_IF_NULL(context_ptr);
  644. auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  645. if (device_info == kGPUDevice) {
  646. processor = kernel::Processor::CUDA;
  647. } else if (device_info == kAscendDevice) {
  648. processor = kernel::Processor::AICORE;
  649. } else if (device_info == kCPUDevice) {
  650. processor = kernel::Processor::CPU;
  651. }
  652. return processor;
  653. }
  654. std::string GetStrProcessorFromContext() {
  655. auto processor = GetProcessorFromContext();
  656. string str_processor = kernel::kProcessorUnknown;
  657. if (processor == kernel::Processor::CUDA) {
  658. str_processor = kernel::kProcessorCuda;
  659. } else if (processor == kernel::Processor::AICORE) {
  660. str_processor = kernel::kProcessorAiCore;
  661. } else if (processor == kernel::Processor::CPU) {
  662. str_processor = kernel::kProcessorCpu;
  663. }
  664. return str_processor;
  665. }
  666. float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  667. return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
  668. : in_size / static_cast<float>(out_size);
  669. }
  670. float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }
  671. void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
  672. CachedInterpolation *interpolation) {
  673. interpolation[out_size].lower = 0;
  674. interpolation[out_size].upper = 0;
  675. for (size_t i = 0; i <= out_size - 1; ++i) {
  676. const float in = ScaleGrid(i, scale);
  677. const float in_f = std::floor(in);
  678. interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
  679. interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
  680. interpolation[i].lerp = in - in_f;
  681. }
  682. }
  683. bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int64_t *size_i) {
  684. MS_EXCEPTION_IF_NULL(type_ptr);
  685. size_t type_byte = GetTypeByte(type_ptr);
  686. if (type_byte == 0) {
  687. return false;
  688. }
  689. for (size_t j = 0; j < shape.size(); j++) {
  690. size_i[0] = LongMulWithOverflowCheck(size_i[0], static_cast<int64_t>(shape[j]));
  691. }
  692. size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
  693. return true;
  694. }
  695. void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape) {
  696. MS_EXCEPTION_IF_NULL(long_shape);
  697. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong);
  698. }
  699. void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  700. const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape) {
  701. if (start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) {
  702. MS_LOG(EXCEPTION)
  703. << "TensorCopySlices requires the length of begin, stride and end must be equal and less than input dimension.";
  704. }
  705. size_t size = start.size();
  706. for (size_t i = 0; i < size; ++i) {
  707. if (stop[i] <= start[i]) {
  708. MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << " ," << step[i] << ")";
  709. }
  710. // Operator need to be generalized in the future. Only support to copy continuous memory now.
  711. if (step[i] != 1) {
  712. MS_LOG(EXCEPTION) << "The element in step only support 1, but got:" << step;
  713. }
  714. }
  715. size_t slice_pos = size;
  716. for (size_t i = 0; i < size; ++i) {
  717. if (stop[i] - start[i] > 1) {
  718. slice_pos = i;
  719. break;
  720. }
  721. }
  722. for (size_t i = slice_pos + 1; i < size; ++i) {
  723. if (stop[i] - start[i] != input_shape[i]) {
  724. MS_LOG(EXCEPTION) << "Only support copy continuous memory now. For example tensor[0, 0:100] is fine, "
  725. "but tensor[0:100, 0] is not supported.";
  726. }
  727. }
  728. }
  729. size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
  730. const std::vector<int64_t> &stop) {
  731. for (size_t i = 0; i < start.size(); ++i) {
  732. if (stop[i] - start[i] != 1) {
  733. return SizetMulWithOverflowCheck(LongToSize(stop[i] - start[i]), LongToSize(dim_offset[i]));
  734. }
  735. }
  736. return LongToSize(dim_offset[start.size() - 1]);
  737. }
  738. std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape) {
  739. std::vector<int64_t> dim_offset;
  740. int64_t offset = 1;
  741. for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) {
  742. dim_offset.push_back(offset);
  743. offset = offset * (*iter);
  744. }
  745. std::reverse(dim_offset.begin(), dim_offset.end());
  746. return dim_offset;
  747. }
  748. size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  749. const std::vector<int64_t> &dim_offset) {
  750. size_t size = start.size();
  751. size_t offset = 0;
  752. for (size_t i = 0; i < size; ++i) {
  753. offset += SizetMulWithOverflowCheck(LongToSize(dim_offset[i]), LongToSize(start[i]));
  754. if (stop[i] - start[i] != 1) {
  755. break;
  756. }
  757. }
  758. return offset;
  759. }
  760. size_t UnitSizeInBytes(const mindspore::TypeId &t) {
  761. size_t bytes = 0;
  762. switch (t) {
  763. case kNumberTypeBool:
  764. case kNumberTypeInt8:
  765. case kNumberTypeUInt8:
  766. bytes = sizeof(int8_t);
  767. break;
  768. case kNumberTypeInt16:
  769. case kNumberTypeUInt16:
  770. case kNumberTypeFloat16:
  771. bytes = sizeof(int16_t);
  772. break;
  773. case kNumberTypeInt:
  774. case kNumberTypeUInt:
  775. case kNumberTypeInt32:
  776. case kNumberTypeUInt32:
  777. case kNumberTypeFloat:
  778. case kNumberTypeFloat32:
  779. bytes = sizeof(int32_t);
  780. break;
  781. case kNumberTypeUInt64:
  782. case kNumberTypeInt64:
  783. case kNumberTypeFloat64:
  784. bytes = sizeof(int64_t);
  785. break;
  786. default:
  787. MS_LOG(EXCEPTION) << "Invalid types " << t;
  788. break;
  789. }
  790. return bytes;
  791. }
  792. } // namespace kernel
  793. } // namespace mindspore