You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 42 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/common_utils.h"
  17. #include <unordered_map>
  18. #include <map>
  19. #include <bitset>
  20. #include <iostream>
  21. #include <utility>
  22. #include <fstream>
  23. #include <algorithm>
  24. #include <thread>
  25. #include "nlohmann/json.hpp"
  26. #include "backend/common/session/anf_runtime_algorithm.h"
  27. #include "include/common/utils/anfalgo.h"
  28. #include "utils/file_utils.h"
  29. #include "utils/ms_utils.h"
  30. #include "ir/manager.h"
  31. #include "ir/meta_tensor.h"
  32. #include "base/core_ops.h"
  33. #include "ir/graph_utils.h"
  34. #include "utils/ms_context.h"
  35. #include "utils/trace_base.h"
  36. #include "mindspore/ccsrc/debug/common.h"
  37. namespace mindspore {
  38. namespace kernel {
  39. constexpr char kAxis[] = "axis";
  40. constexpr char kTypeInt32[] = "Int32";
  41. constexpr auto kStridedSliceMaxDims = 8;
  42. const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  43. {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  44. {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
  45. };
  46. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  47. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  48. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  49. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  50. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  51. {"complex64", sizeof(float) * 2}};
  52. // Define all patterns here for different schedule
  53. const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
  54. {FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
  55. {FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
  56. {FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
  57. {FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
  58. {FusionType::ELEMWISE, "ElemWise"},
  59. {FusionType::PURE_BROADCAST, "PureBroadcast"},
  60. {FusionType::COMMREDUCE, "CommReduce"},
  61. {FusionType::SEGMENT, "Segment"},
  62. {FusionType::INPLACE, "Inplace"},
  63. {FusionType::MATMUL, "Matmul"},
  64. {FusionType::MATMUL_V2, "Matmul_v2"},
  65. {FusionType::GEMM, "GEMM"},
  66. {FusionType::CONV, "Convolution"},
  67. {FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
  68. {FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
  69. {FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
  70. {FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
  71. {FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
  72. {FusionType::OPAQUE, "Opaque"},
  73. {FusionType::BN_REDUCE, "bn_reduce"},
  74. {FusionType::BN_UPDATE, "bn_update"},
  75. {FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
  76. {FusionType::L2_NORMALIZE, "l2_normalize"},
  77. {FusionType::SOFTMAX, "softmax_pattern"},
  78. {FusionType::L2_LOSS, "l2_loss"},
  79. {FusionType::ASCEND_QUANT, "quant"},
  80. {FusionType::ASCEND_DEQUANT, "dequant"},
  81. {FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
  82. {FusionType::STRIDED_READ, "strided_read"},
  83. {FusionType::STRIDED_WRITE, "strided_write"},
  84. {FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
  85. {FusionType::ASCEND_REQUANT, "requant"},
  86. {FusionType::ASCEND_REQUANT_S16, "requant_s16"},
  87. {FusionType::MAX_POOL, "MaxPool"},
  88. {FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
  89. {FusionType::CONV3D, "Conv3d"},
  90. {FusionType::POOL2D, "Pool2d"},
  91. {FusionType::POOL3D, "Pool3d"},
  92. {FusionType::READ_SELECT, "read_select"},
  93. {FusionType::WRITE_SELECT, "write_select"},
  94. {FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
  95. {FusionType::DILATION_PATTERN, "dilation"},
  96. {FusionType::BROAD_CAST, "Broadcast"},
  97. {FusionType::BATCH_MATMUL, "BatchMatmul"},
  98. {FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
  99. {FusionType::DROPOUT_DOMASKV3D, "DropOutDoMaskV3D"},
  100. {FusionType::UNKNOWN_FUSION_TYPE, ""}};
  101. std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> GetAlignments(const std::string &alignment) {
  102. auto alignment_iter = MatrixDiag::AlignmentMap.find(alignment);
  103. if (alignment_iter == MatrixDiag::AlignmentMap.end()) {
  104. MS_LOG(EXCEPTION) << "For current kernel, input alignment is invalid: " << alignment
  105. << ". please limit it to {RIGHT_LEFT, LEFT_RIGHT, RIGHT_RIGHT, LEFT_LEFT}";
  106. }
  107. return alignment_iter->second;
  108. }
  109. int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols,
  110. const std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> &alignment) {
  111. bool right_align_super_diagonal = (alignment.first == MatrixDiag::RIGHT);
  112. bool right_align_sub_diagonal = (alignment.second == MatrixDiag::RIGHT);
  113. const bool right_align =
  114. (diag_index >= 0 && right_align_super_diagonal) || (diag_index <= 0 && right_align_sub_diagonal);
  115. const int diag_len = std::min(inner_rows + std::min(0, diag_index), inner_cols - std::max(0, diag_index));
  116. const int offset = (right_align) ? (max_diag_len - diag_len) : 0;
  117. return offset;
  118. }
  119. std::string GetFusionNameByType(const kernel::FusionType &type) {
  120. auto iter = fusion_type_name_maps.find(type);
  121. if (iter == fusion_type_name_maps.end()) {
  122. MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
  123. }
  124. return iter->second;
  125. }
  126. FusionType GetFusionTypeByName(const std::string &name) {
  127. std::string fusion_name_upper = name;
  128. transform(fusion_name_upper.begin(), fusion_name_upper.end(), fusion_name_upper.begin(), ::toupper);
  129. auto iter =
  130. std::find_if(fusion_type_name_maps.begin(), fusion_type_name_maps.end(), [&fusion_name_upper](const auto &it) {
  131. std::string name_upper = it.second;
  132. transform(name_upper.begin(), name_upper.end(), name_upper.begin(), ::toupper);
  133. return fusion_name_upper == name_upper;
  134. });
  135. if (iter == fusion_type_name_maps.end()) {
  136. MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
  137. }
  138. return iter->first;
  139. }
  140. std::string GetCompilerCachePath() {
  141. static std::string config_path = "";
  142. if (config_path != "") {
  143. return config_path;
  144. }
  145. const char *value = ::getenv(kCOMPILER_CACHE_PATH);
  146. if (value == nullptr) {
  147. config_path = "./";
  148. } else {
  149. config_path = std::string(value);
  150. FileUtils::CreateNotExistDirs(config_path);
  151. if (config_path[config_path.length() - 1] != '/') {
  152. config_path += "/";
  153. }
  154. }
  155. return config_path;
  156. }
  157. void KernelMeta::Initialize() {
  158. auto config_path = GetCompilerCachePath();
  159. kernel_meta_path_ = config_path + std::string(kAkgKernelMeta);
  160. FileUtils::CreateNotExistDirs(kernel_meta_path_);
  161. initialized_ = true;
  162. }
  163. std::string KernelMeta::Search(const std::string &kernel_name) const {
  164. if (!initialized_) {
  165. return "";
  166. }
  167. auto iter = kernel_meta_map_.find(kernel_name);
  168. if (iter == kernel_meta_map_.end()) {
  169. return "";
  170. } else {
  171. return iter->second;
  172. }
  173. }
  174. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  175. if (!initialized_) {
  176. return false;
  177. }
  178. kernel_meta_map_[kernel_name] = kernel_json;
  179. return true;
  180. }
  181. bool CheckCache(const std::string &kernel_name) {
  182. // check cache.
  183. KernelMeta *bin_map = KernelMeta::GetInstance();
  184. if (bin_map == nullptr) {
  185. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
  186. return false;
  187. }
  188. std::string kernel_json = bin_map->Search(kernel_name);
  189. bool ret = (!kernel_json.empty());
  190. if (ret) {
  191. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
  192. } else {
  193. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
  194. }
  195. return ret;
  196. }
  197. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  198. // search cache.
  199. KernelMeta *bin_map = KernelMeta::GetInstance();
  200. if (bin_map == nullptr) {
  201. MS_LOG(DEBUG) << "kernel cache is invalid, kernel_name: " << kernel_name;
  202. return nullptr;
  203. }
  204. std::string kernel_json = bin_map->Search(kernel_name);
  205. if (!kernel_json.empty()) {
  206. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  207. // just a tmp solution.
  208. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  209. MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
  210. return nullptr;
  211. } else {
  212. return kernel_pack;
  213. }
  214. } else {
  215. MS_LOG(INFO) << "The cache kernel not found[" << kernel_name << "].";
  216. return nullptr;
  217. }
  218. }
  219. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  220. MS_LOG(INFO) << "Insert cache for kernel:" << kernel_name << ", processr:" << processor;
  221. KernelMeta *bin_map = KernelMeta::GetInstance();
  222. std::string kernel_json = bin_map->kernel_meta_path();
  223. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  224. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  225. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  226. MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
  227. return nullptr;
  228. }
  229. if (bin_map == nullptr) {
  230. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name :" << kernel_name;
  231. return nullptr;
  232. }
  233. if (bin_map->Insert(kernel_name, kernel_json)) {
  234. MS_LOG(INFO) << "Kernel insert cache success[" << kernel_json << "], kernel name[" << kernel_name << "].";
  235. }
  236. return kernel_pack;
  237. }
  238. TypeId DtypeToTypeId(const std::string &dtypes) {
  239. if (dtypes == "float") {
  240. return TypeId::kNumberTypeFloat32;
  241. }
  242. if (dtypes.empty()) {
  243. return TypeId::kMetaTypeNone;
  244. }
  245. return StringToTypeId(dtypes);
  246. }
  247. std::string Dtype2ShortType(const std::string &dtype) {
  248. auto iter = dtype_shortdtype_map_.find(dtype);
  249. if (iter != dtype_shortdtype_map_.end()) {
  250. return iter->second;
  251. } else {
  252. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  253. }
  254. }
  255. size_t GetDtypeNbyte(const std::string &dtype) {
  256. auto iter = dtype_nbyte_map.find(dtype);
  257. if (iter != dtype_nbyte_map.end()) {
  258. return iter->second;
  259. } else {
  260. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  261. }
  262. }
  263. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  264. size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
  265. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  266. MS_EXCEPTION_IF_NULL(builder);
  267. std::vector<TypeId> inputs_device_type;
  268. std::vector<std::string> inputs_format;
  269. size_t dyn_input_idx = 0;
  270. size_t kernel_info_index = 0;
  271. MS_EXCEPTION_IF_NULL(inputs[0]);
  272. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  273. for (const auto &input : inputs) {
  274. MS_EXCEPTION_IF_NULL(input);
  275. std::string param_type = input->param_type();
  276. std::vector<std::string> dtypes = input->dtypes();
  277. std::vector<std::string> formats = input->formats();
  278. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  279. MS_LOG(DEBUG) << "Set input kernel builder info failed, dtyps size != formats size. dtypes size: "
  280. << dtypes.size() << ", formats size : " << formats.size();
  281. return false;
  282. }
  283. if (param_type == "dynamic") {
  284. if (dyn_input_sizes.empty()) {
  285. MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes's size is 0 when param_type is dynamic";
  286. return false;
  287. }
  288. for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  289. kernel_info_index++;
  290. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  291. inputs_device_type.push_back(type_id);
  292. inputs_format.push_back(formats[builder_idex]);
  293. }
  294. dyn_input_idx++;
  295. } else if (param_type == "required") {
  296. kernel_info_index++;
  297. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  298. inputs_device_type.push_back(type_id);
  299. inputs_format.push_back(formats[builder_idex]);
  300. } else {
  301. if (kernel_info_index < real_input_num) {
  302. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  303. kernel_info_index++;
  304. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  305. inputs_device_type.push_back(type_id);
  306. inputs_format.push_back(formats[builder_idex]);
  307. }
  308. }
  309. }
  310. builder->SetInputsDeviceType(inputs_device_type);
  311. builder->SetInputsFormat(inputs_format);
  312. return true;
  313. }
  314. bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
  315. const size_t &real_output_num,
  316. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  317. // not now but in the next we need to support dynamic output case
  318. MS_EXCEPTION_IF_NULL(builder);
  319. size_t output_idx = 0;
  320. std::vector<TypeId> outputs_device_type;
  321. std::vector<std::string> outputs_format;
  322. MS_EXCEPTION_IF_NULL(outputs[0]);
  323. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  324. for (const auto &output : outputs) {
  325. MS_EXCEPTION_IF_NULL(output);
  326. if (output_idx >= real_output_num) {
  327. MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
  328. continue;
  329. }
  330. size_t output_num = 0;
  331. if (output->param_type() == "dynamic") {
  332. if (outputs.size() > 1) {
  333. MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
  334. }
  335. output_num = real_output_num;
  336. } else if (output->param_type() == "required") {
  337. output_num = 1;
  338. } else {
  339. if (output_idx < real_output_num) {
  340. MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
  341. output_num = 1;
  342. }
  343. }
  344. for (size_t i = 0; i < output_num; i++) {
  345. std::vector<std::string> dtypes = output->dtypes();
  346. std::vector<std::string> formats = output->formats();
  347. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  348. MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
  349. return false;
  350. }
  351. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  352. outputs_device_type.push_back(type_id);
  353. outputs_format.push_back(formats[builder_idex]);
  354. output_idx++;
  355. }
  356. }
  357. builder->SetOutputsFormat(outputs_format);
  358. builder->SetOutputsDeviceType(outputs_device_type);
  359. return true;
  360. }
  361. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  362. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  363. MS_EXCEPTION_IF_NULL(builder);
  364. MS_EXCEPTION_IF_NULL(op_info_ptr);
  365. auto imply_type = op_info_ptr->imply_type();
  366. builder->SetProcessor(processor);
  367. std::string fusion_name = op_info_ptr->fusion_type();
  368. auto fusion_type = GetFusionTypeByName(fusion_name);
  369. builder->SetFusionType(fusion_type);
  370. if (imply_type == kAKG) {
  371. builder->SetKernelType(AKG_KERNEL);
  372. } else if (imply_type == kGPU) {
  373. builder->SetKernelType(GPU_KERNEL);
  374. } else if (imply_type == kCPU) {
  375. builder->SetKernelType(CPU_KERNEL);
  376. } else if (imply_type == kAICPU) {
  377. builder->SetKernelType(AICPU_KERNEL);
  378. } else {
  379. builder->SetKernelType(TBE_KERNEL);
  380. }
  381. }
  382. bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
  383. std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  384. MS_EXCEPTION_IF_NULL(kernel_node);
  385. MS_EXCEPTION_IF_NULL(kernel_info_list);
  386. size_t real_input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  387. size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  388. std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  389. std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  390. std::vector<int64_t> dyn_input_sizes;
  391. auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel_node);
  392. MS_EXCEPTION_IF_NULL(primitive);
  393. auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
  394. if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
  395. dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  396. }
  397. if (inputs.size() > 0) {
  398. if (inputs[0] == nullptr) {
  399. MS_LOG(EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
  400. }
  401. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  402. for (size_t j = 0; j < kernel_info_cnt; j++) {
  403. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  404. MS_EXCEPTION_IF_NULL(builder);
  405. SetKernelBuildInfo(builder, processor, op_info_ptr);
  406. if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
  407. MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
  408. return false;
  409. }
  410. if (outputs.size() > 0) {
  411. if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
  412. MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
  413. return false;
  414. }
  415. }
  416. kernel_info_list->push_back(builder->Build());
  417. }
  418. } else if (outputs.size() > 0) {
  419. if (outputs[0] == nullptr) {
  420. MS_LOG(EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
  421. }
  422. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  423. for (size_t j = 0; j < kernel_info_cnt; j++) {
  424. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  425. MS_EXCEPTION_IF_NULL(builder);
  426. SetKernelBuildInfo(builder, processor, op_info_ptr);
  427. if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
  428. MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
  429. return false;
  430. }
  431. kernel_info_list->push_back(builder->Build());
  432. }
  433. } else {
  434. if (processor == AICPU) {
  435. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  436. MS_EXCEPTION_IF_NULL(builder);
  437. SetKernelBuildInfo(builder, processor, op_info_ptr);
  438. kernel_info_list->push_back(builder->Build());
  439. }
  440. }
  441. return true;
  442. }
  443. void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  444. std::string path = base_path + json_name + kInfoSuffix;
  445. auto realpath = Common::CreatePrefixPath(path);
  446. if (!realpath.has_value()) {
  447. MS_LOG(ERROR) << "Get real path failed, path=" << path;
  448. return;
  449. }
  450. ChangeFileMode(realpath.value(), S_IWUSR);
  451. std::ofstream filewrite(realpath.value());
  452. if (!filewrite.is_open()) {
  453. MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
  454. return;
  455. }
  456. filewrite << info << std::endl;
  457. filewrite.close();
  458. ChangeFileMode(realpath.value(), S_IRUSR);
  459. }
  460. Processor GetProcessor(const string &processor) {
  461. if (processor == kProcessorAiCore) return Processor::AICORE;
  462. if (processor == kProcessorAiCpu) return Processor::AICPU;
  463. if (processor == kProcessorCuda) return Processor::CUDA;
  464. MS_LOG(DEBUG) << "Unknown processor type.";
  465. return Processor::UNKNOWN;
  466. }
  467. std::string GetProcessor(const AnfNodePtr &anf_node) {
  468. MS_EXCEPTION_IF_NULL(anf_node);
  469. std::string device;
  470. switch (AnfAlgo::GetProcessor(anf_node)) {
  471. case Processor::AICORE:
  472. device = kProcessorAiCore;
  473. break;
  474. case Processor::AICPU:
  475. device = kProcessorAiCpu;
  476. break;
  477. case Processor::CUDA:
  478. device = kProcessorCuda;
  479. break;
  480. default:
  481. MS_LOG(DEBUG) << "Unknown processor type.";
  482. break;
  483. }
  484. return device;
  485. }
  486. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  487. if (shape_a.size() != shape_b.size()) {
  488. return false;
  489. }
  490. for (size_t i = 0; i < shape_a.size(); ++i) {
  491. if (shape_a[i] != shape_b[i]) {
  492. return false;
  493. }
  494. }
  495. return true;
  496. }
  497. std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
  498. const std::vector<AnfNodePtr> &input_list,
  499. const std::vector<AnfNodePtr> &output_list) {
  500. std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  501. for (size_t i = 0; i < output_list.size(); ++i) {
  502. auto const &output = output_list[i];
  503. MS_EXCEPTION_IF_NULL(output);
  504. bool found = false;
  505. auto pree_node = common::AnfAlgo::VisitKernel(output, 0);
  506. auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
  507. if (pos != std::end(node_list)) {
  508. output_index.push_back(pree_node);
  509. continue;
  510. }
  511. auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
  512. if (ret != std::end(input_list)) {
  513. output_index.push_back(std::make_pair(pree_node.first, 0));
  514. found = true;
  515. }
  516. if (!found) {
  517. MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
  518. << output->func_graph()->ToString() << "] found no related kernel info.";
  519. }
  520. }
  521. return output_index;
  522. }
  523. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  524. MS_EXCEPTION_IF_NULL(node_list);
  525. MS_EXCEPTION_IF_NULL(func_graph);
  526. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  527. for (auto const &node : node_lists) {
  528. if (!AnfUtils::IsRealKernel(node) || !node->isa<CNode>()) {
  529. continue;
  530. }
  531. auto cnode = node->cast<CNodePtr>();
  532. MS_EXCEPTION_IF_NULL(cnode);
  533. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  534. node_list->push_back(node);
  535. }
  536. }
  537. }
  538. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  539. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  540. MS_EXCEPTION_IF_NULL(func_graph);
  541. MS_EXCEPTION_IF_NULL(node_list);
  542. MS_EXCEPTION_IF_NULL(input_list);
  543. GetValidKernelNodes(func_graph, node_list);
  544. auto parameters = func_graph->parameters();
  545. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  546. GetFuncGraphOutputNodes(func_graph, output_list);
  547. }
  548. void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  549. MS_EXCEPTION_IF_NULL(func_graph);
  550. MS_EXCEPTION_IF_NULL(output_list);
  551. auto func_output = func_graph->output();
  552. MS_EXCEPTION_IF_NULL(func_output);
  553. if (func_output->isa<CNode>()) {
  554. // multi output.
  555. auto cnode = func_output->cast<CNodePtr>();
  556. MS_EXCEPTION_IF_NULL(cnode);
  557. auto input0 = cnode->input(kAnfPrimitiveIndex);
  558. MS_EXCEPTION_IF_NULL(input0);
  559. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  560. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  561. auto input_node = cnode->input(input_idx);
  562. MS_EXCEPTION_IF_NULL(input_node);
  563. if (input_node->isa<CNode>() && common::AnfAlgo::GetInputTensorNum(input_node) == 0) {
  564. continue;
  565. }
  566. output_list->push_back(common::AnfAlgo::VisitKernel(input_node, 0).first);
  567. }
  568. } else {
  569. // single output.
  570. output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
  571. }
  572. } else {
  573. // single output.
  574. output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
  575. }
  576. }
  577. bool IsWeightBoundary(const AnfNodePtr &node) {
  578. if (node->isa<ValueNode>()) {
  579. return true;
  580. }
  581. if (node->isa<Parameter>() && common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  582. return true;
  583. }
  584. return false;
  585. }
  586. std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
  587. if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || common::AnfAlgo::GetOutputTensorNum(cnode) != 1) {
  588. MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output."
  589. << trace::DumpSourceLines(cnode);
  590. }
  591. std::vector<int64_t> axis;
  592. auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  593. auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
  594. MS_EXCEPTION_IF_NULL(primitive);
  595. auto axis_attr = primitive->GetAttr(kAxis);
  596. if (axis_attr == nullptr) {
  597. MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
  598. return std::vector<int64_t>();
  599. }
  600. std::vector<int64_t> axis_list;
  601. if (axis_attr->isa<Int64Imm>()) {
  602. (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
  603. } else {
  604. axis_list = GetValue<std::vector<int64_t>>(axis_attr);
  605. }
  606. for (const auto &elem : axis_list) {
  607. if (elem < 0) {
  608. (void)axis.emplace_back(input_shape.size() + elem);
  609. } else {
  610. (void)axis.emplace_back(elem);
  611. }
  612. }
  613. common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  614. return axis;
  615. }
// Normalizes StridedSlice parameters to exactly kStridedSliceMaxDims entries:
// pads input_shape with trailing 1s, clamps begin/end into valid per-dimension
// ranges (wrapping negative, i.e. from-the-end, indices), and pads
// begin/end/stride with whole-dimension defaults.
// Throws when begin/end/stride lengths disagree or exceed the input rank.
void FillEmptyDims(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
                   std::vector<int64_t> *stride, std::vector<size_t> *input_shape) {
  std::vector<int64_t> &_begin = *begin;
  std::vector<int64_t> &_end = *end;
  std::vector<int64_t> &_stride = *stride;
  std::vector<size_t> &_input_shape = *input_shape;
  if (_begin.size() != _end.size() || _begin.size() != _stride.size() || _begin.size() > _input_shape.size()) {
    MS_LOG(EXCEPTION) << "For '" << common::AnfAlgo::GetCNodeName(kernel_node)
                      << "', the length of 'begin', 'stride' and 'end' should be equal "
                         "and less than or equal to the dimension of 'input_x', but got the length of 'begin': "
                      << _begin.size() << ", the length of 'stride': " << _stride.size()
                      << ", the length of 'end': " << _end.size()
                      << ", the dimension of 'input_x': " << _input_shape.size();
  }
  for (size_t i = 0; i < kStridedSliceMaxDims; i++) {
    // Pad the shape first so the defaults below can safely read _input_shape[i].
    if (i >= _input_shape.size()) {
      _input_shape.push_back(1);
    }
    if (i < _begin.size()) {
      int64_t dim = SizeToLong(_input_shape[i]);
      // Wrap a negative begin to begin + dim, then clamp into [0, dim - 1].
      _begin[i] = std::min(_begin[i] < 0 ? std::max(_begin[i] + dim, static_cast<int64_t>(0)) : _begin[i], dim - 1);
    } else {
      _begin.push_back(0);
    }
    if (i < _end.size()) {
      int64_t dim = SizeToLong(_input_shape[i]);
      // Wrap a negative end; clamp into [-1, dim] (-1 means "before the first
      // element", the stop position for negative strides).
      _end[i] = std::max(_end[i] < 0 ? _end[i] + dim : std::min(_end[i], dim), static_cast<int64_t>(-1));
    } else {
      // Shape was already padded above, so the ternary's false branch is only a
      // defensive fallback.
      _end.push_back(i < _input_shape.size() ? SizeToLong(_input_shape[i]) : 1);
    }
    if (i >= _stride.size()) {
      _stride.push_back(1);
    }
  }
}
  651. std::vector<bool> Dec2Bin(const int64_t &mask) {
  652. auto mask_str = std::bitset<kStridedSliceMaxDims>(mask).to_string();
  653. int64_t dim_idx = 0;
  654. std::vector<bool> result(kStridedSliceMaxDims, false);
  655. for (int64_t i = mask_str.size() - 1; i >= 0; i--) {
  656. if (mask_str[i] == '1') {
  657. result[dim_idx] = true;
  658. }
  659. dim_idx++;
  660. }
  661. return result;
  662. }
  663. void ComputeBeginMask(const CNodePtr &kernel_node, std::vector<int64_t> *begin, const std::vector<int64_t> &stride,
  664. const std::vector<size_t> &input_shape) {
  665. std::vector<int64_t> &_begin = *begin;
  666. auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
  667. auto begin_mask = Dec2Bin(begin_mask_int);
  668. for (size_t i = 0; i < begin_mask.size(); i++) {
  669. if (i < kStridedSliceMaxDims && begin_mask[i]) {
  670. _begin[i] = stride[i] < 0 ? SizeToLong(input_shape[i]) - 1 : 0;
  671. }
  672. }
  673. }
  674. void ComputeEndMask(const CNodePtr &kernel_node, std::vector<int64_t> *end, const std::vector<int64_t> &stride,
  675. const std::vector<size_t> &input_shape) {
  676. std::vector<int64_t> &_end = *end;
  677. auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
  678. auto end_mask = Dec2Bin(end_mask_int);
  679. for (size_t j = 0; j < end_mask.size(); j++) {
  680. if (j < kStridedSliceMaxDims && end_mask[j]) {
  681. _end[j] = stride[j] < 0 ? -1 : SizeToLong(input_shape[j]);
  682. }
  683. }
  684. }
  685. void ComputeEllipsisMask(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
  686. std::vector<int64_t> *stride, const std::vector<size_t> &input_shape) {
  687. std::vector<int64_t> &_begin = *begin;
  688. std::vector<int64_t> &_end = *end;
  689. std::vector<int64_t> &_stride = *stride;
  690. auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
  691. auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
  692. for (size_t k = 0; k < ellipsis_mask.size(); k++) {
  693. if (k < kStridedSliceMaxDims && ellipsis_mask[k]) {
  694. _begin[k] = 0;
  695. _end[k] = SizeToLong(input_shape[k]);
  696. _stride[k] = 1;
  697. }
  698. }
  699. }
  700. void ComputNewAxisMask(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
  701. std::vector<int64_t> *stride, const std::vector<size_t> &input_shape) {
  702. std::vector<int64_t> &_begin = *begin;
  703. std::vector<int64_t> &_end = *end;
  704. std::vector<int64_t> &_stride = *stride;
  705. auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
  706. auto new_axis_mask = Dec2Bin(new_axis_mask_int);
  707. for (size_t l = 0; l < new_axis_mask.size(); l++) {
  708. if (l < kStridedSliceMaxDims && new_axis_mask[l]) {
  709. _begin[l] = 0;
  710. _end[l] = SizeToLong(input_shape[l]);
  711. _stride[l] = 1;
  712. }
  713. }
  714. }
// Applies the StridedSlice 'shrink_axis_mask': every flagged dimension is
// collapsed to the single element at 'begin' by forcing an interval of length
// one in the direction implied by begin vs. end.
void ComputShrinkAxisMask(const CNodePtr &kernel_node, const std::vector<int64_t> &begin, std::vector<int64_t> *end,
                          std::vector<int64_t> *stride) {
  std::vector<int64_t> &_end = *end;
  std::vector<int64_t> &_stride = *stride;
  auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
  auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
  for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
    if (m < kStridedSliceMaxDims && shrink_axis_mask[m]) {
      // NOTE: the stride line deliberately reads the _end[m] value written on
      // the previous line, so the direction is derived from the updated end.
      _end[m] = _end[m] > begin[m] ? begin[m] + 1 : begin[m] - 1;
      _stride[m] = _end[m] > begin[m] ? 1 : -1;
    }
  }
}
// Applies every StridedSlice mask to (begin, end, stride) in precedence order:
// begin/end masks first, then ellipsis, new-axis, and finally shrink-axis —
// each step may overwrite values produced by the previous one, so the call
// order matters.
void ParseStrideSliceMasks(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
                           std::vector<int64_t> *stride, const std::vector<size_t> &input_shape) {
  ComputeBeginMask(kernel_node, begin, *stride, input_shape);
  ComputeEndMask(kernel_node, end, *stride, input_shape);
  ComputeEllipsisMask(kernel_node, begin, end, stride, input_shape);
  ComputNewAxisMask(kernel_node, begin, end, stride, input_shape);
  ComputShrinkAxisMask(kernel_node, *begin, end, stride);
}
  736. std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  737. MS_EXCEPTION_IF_NULL(anf_node);
  738. std::string processor = kProcessorUnknown;
  739. auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  740. MS_EXCEPTION_IF_NULL(kernel_info);
  741. auto build_info = kernel_info->select_kernel_build_info();
  742. // we may call this before kernel select.
  743. if (build_info == nullptr) {
  744. return processor;
  745. }
  746. switch (build_info->processor()) {
  747. case Processor::AICORE:
  748. processor = kProcessorAiCore;
  749. break;
  750. case Processor::AICPU:
  751. processor = kProcessorAiCpu;
  752. break;
  753. case Processor::CUDA:
  754. processor = kProcessorCuda;
  755. break;
  756. default:
  757. MS_LOG(ERROR) << "Unknown processor type.";
  758. break;
  759. }
  760. return processor;
  761. }
  762. Processor GetProcessorFromContext() {
  763. kernel::Processor processor = kernel::Processor::UNKNOWN;
  764. auto context_ptr = MsContext::GetInstance();
  765. MS_EXCEPTION_IF_NULL(context_ptr);
  766. auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  767. if (device_info == kGPUDevice) {
  768. processor = kernel::Processor::CUDA;
  769. } else if (device_info == kAscendDevice) {
  770. processor = kernel::Processor::AICORE;
  771. } else if (device_info == kCPUDevice) {
  772. processor = kernel::Processor::CPU;
  773. }
  774. return processor;
  775. }
  776. std::string GetStrProcessorFromContext() {
  777. auto processor = GetProcessorFromContext();
  778. string str_processor = kernel::kProcessorUnknown;
  779. if (processor == kernel::Processor::CUDA) {
  780. str_processor = kernel::kProcessorCuda;
  781. } else if (processor == kernel::Processor::AICORE) {
  782. str_processor = kernel::kProcessorAiCore;
  783. } else if (processor == kernel::Processor::CPU) {
  784. str_processor = kernel::kProcessorCpu;
  785. }
  786. return str_processor;
  787. }
  788. float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  789. return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
  790. : in_size / static_cast<float>(out_size);
  791. }
  792. float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }
  793. void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
  794. CachedInterpolation *interpolation) {
  795. interpolation[out_size].lower = 0;
  796. interpolation[out_size].upper = 0;
  797. for (size_t i = 0; i <= out_size - 1; ++i) {
  798. const float in = ScaleGrid(i, scale);
  799. const float in_f = std::floor(in);
  800. interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
  801. interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
  802. interpolation[i].lerp = in - in_f;
  803. }
  804. }
  805. bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int64_t *size_i) {
  806. MS_EXCEPTION_IF_NULL(type_ptr);
  807. size_t type_byte = GetTypeByte(type_ptr);
  808. if (type_byte == 0) {
  809. return false;
  810. }
  811. for (size_t j = 0; j < shape.size(); j++) {
  812. size_i[0] = LongMulWithOverflowCheck(size_i[0], static_cast<int64_t>(shape[j]));
  813. }
  814. size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
  815. return true;
  816. }
  817. void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape) {
  818. MS_EXCEPTION_IF_NULL(long_shape);
  819. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong);
  820. }
// Validates (start, stop, step) for TensorCopySlices against 'input_shape':
// lengths must agree and not exceed the input rank, every interval must be
// non-empty, step must be 1, and only the outermost sliced dimension may be
// partial so the copied region stays contiguous. Throws on any violation.
void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
                     const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape) {
  if (start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) {
    MS_LOG(EXCEPTION)
      << "TensorCopySlices requires the length of begin, stride and end must be equal and less than input dimension.";
  }
  size_t size = start.size();
  for (size_t i = 0; i < size; ++i) {
    // Each sliced interval must be non-empty (stop > start).
    if (stop[i] <= start[i]) {
      MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << " ," << step[i] << ")";
    }
    // Operator need to be generalized in the future. Only support to copy continuous memory now.
    if (step[i] != 1) {
      MS_LOG(EXCEPTION) << "The element in step only support 1, but got:" << step;
    }
  }
  // Find the outermost dimension that is actually sliced (extent > 1)...
  size_t slice_pos = size;
  for (size_t i = 0; i < size; ++i) {
    if (stop[i] - start[i] > 1) {
      slice_pos = i;
      break;
    }
  }
  // ...every dimension after it must be copied in full, otherwise the selected
  // region is not a contiguous block of memory.
  for (size_t i = slice_pos + 1; i < size; ++i) {
    if (stop[i] - start[i] != input_shape[i]) {
      MS_LOG(EXCEPTION) << "Only support copy continuous memory now. For example tensor[0, 0:100] is fine, "
                           "but tensor[0:100, 0] is not supported.";
    }
  }
}
// Returns the element count of the contiguous region selected by [start, stop):
// the first dimension with extent > 1 determines the size via its dim_offset;
// if every extent is 1 a single innermost element is copied.
// NOTE(review): assumes 'start' is non-empty and start.size() <= dim_offset.size()
// (CheckSliceValid enforces the slice shape) — an empty 'start' would index
// dim_offset[-1]. Confirm all callers validate first.
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
                   const std::vector<int64_t> &stop) {
  for (size_t i = 0; i < start.size(); ++i) {
    if (stop[i] - start[i] != 1) {
      return SizetMulWithOverflowCheck(LongToSize(stop[i] - start[i]), LongToSize(dim_offset[i]));
    }
  }
  return LongToSize(dim_offset[start.size() - 1]);
}
  860. std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape) {
  861. std::vector<int64_t> dim_offset;
  862. int64_t offset = 1;
  863. for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) {
  864. dim_offset.push_back(offset);
  865. offset = offset * (*iter);
  866. }
  867. std::reverse(dim_offset.begin(), dim_offset.end());
  868. return dim_offset;
  869. }
  870. size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  871. const std::vector<int64_t> &dim_offset) {
  872. size_t size = start.size();
  873. size_t offset = 0;
  874. for (size_t i = 0; i < size; ++i) {
  875. offset += SizetMulWithOverflowCheck(LongToSize(dim_offset[i]), LongToSize(start[i]));
  876. if (stop[i] - start[i] != 1) {
  877. break;
  878. }
  879. }
  880. return offset;
  881. }
// Returns the width in bytes of one element of scalar type 't'.
// Throws for types without a whole-byte width (kNumberTypeInt4) and for any
// other unsupported TypeId.
size_t UnitSizeInBytes(const mindspore::TypeId &t) {
  size_t bytes = 0;
  switch (t) {
    // 1-byte scalars (bool is stored as a single byte here).
    case kNumberTypeBool:
    case kNumberTypeInt8:
    case kNumberTypeUInt8:
      bytes = sizeof(int8_t);
      break;
    // 2-byte scalars (float16 shares the width of int16).
    case kNumberTypeInt16:
    case kNumberTypeUInt16:
    case kNumberTypeFloat16:
      bytes = sizeof(int16_t);
      break;
    // 4-byte scalars.
    case kNumberTypeInt:
    case kNumberTypeUInt:
    case kNumberTypeInt32:
    case kNumberTypeUInt32:
    case kNumberTypeFloat:
    case kNumberTypeFloat32:
      bytes = sizeof(int32_t);
      break;
    // 8-byte scalars.
    case kNumberTypeUInt64:
    case kNumberTypeInt64:
    case kNumberTypeFloat64:
      bytes = sizeof(int64_t);
      break;
    // Int4 has no per-element byte size; reject it together with unknown types.
    case kNumberTypeInt4:
    default:
      MS_LOG(EXCEPTION) << "Invalid types " << t;
  }
  return bytes;
}
// Appends one input descriptor (dtype + format). Returns *this for chaining.
KernelAttr &KernelAttr::AddInputAttr(const TypeId &ms_type, const std::string &format) {
  input_type_.emplace_back(ms_type, format);
  return *this;
}
// Appends one output descriptor (dtype + format). Returns *this for chaining.
KernelAttr &KernelAttr::AddOutputAttr(const TypeId &ms_type, const std::string &format) {
  output_type_.emplace_back(ms_type, format);
  return *this;
}
// Sets the "all same" flag: when true, a single descriptor stands for every
// input/output index (MatchKernelAttr then compares each index against
// descriptor 0). Returns *this for chaining.
// NOTE(review): 'all_same' is a bool passed by const reference; the signature
// must match the header declaration, so it is left unchanged.
KernelAttr &KernelAttr::AddAllSameAttr(const bool &all_same) {
  all_same_ = all_same;
  return *this;
}
// Records an output->input reference pairing in the out-in ref map
// (presumably marking that the output aliases/reuses the given input —
// confirm against the KernelAttr declaration). Returns *this for chaining.
KernelAttr &KernelAttr::AddOutInRef(size_t output_index, size_t input_index) {
  out_in_ref_map_[output_index] = input_index;
  return *this;
}
// Replaces the whole input descriptor list with a copy of 'addr_list'.
void KernelAttr::SetInputAttrList(const std::vector<DataType> &addr_list) {
  input_type_.assign(addr_list.begin(), addr_list.end());
}
  933. std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr) {
  934. std::stringstream ss;
  935. ss << "[Kernel Attr] all same: " << kernel_attr.GetAllSame();
  936. size_t input_num = kernel_attr.GetInputSize();
  937. if (input_num > 0) {
  938. ss << ", input(";
  939. for (size_t i = 0; i < input_num; ++i) {
  940. ss << TypeIdLabel(kernel_attr.GetInputAttr(i).first);
  941. if (i != input_num - 1) {
  942. ss << ",";
  943. }
  944. }
  945. ss << ") ";
  946. }
  947. size_t output_num = kernel_attr.GetOutputSize();
  948. if (output_num > 0) {
  949. ss << ", output(";
  950. for (size_t i = 0; i < output_num; ++i) {
  951. ss << TypeIdLabel(kernel_attr.GetOutputAttr(i).first);
  952. if (i != output_num - 1) {
  953. ss << ",";
  954. }
  955. }
  956. ss << ").";
  957. }
  958. return os << ss.str();
  959. }
// Finds the first entry of 'kernel_attr_list' whose input/output dtypes and
// formats all match 'kernel_attr'. Returns {true, index} on a match and
// {false, 0} otherwise.
std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
                                        const std::vector<KernelAttr> &kernel_attr_list) {
  // kernel_attr should not be all same. If so, then return false.
  if (kernel_attr.GetAllSame()) {
    return std::make_pair(false, 0);
  }
  auto input_num = kernel_attr.GetInputSize();
  auto output_num = kernel_attr.GetOutputSize();
  for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
    const auto &cur_kernel_attr = kernel_attr_list[index];
    auto cur_input_num = cur_kernel_attr.GetInputSize();
    auto cur_output_num = cur_kernel_attr.GetOutputSize();
    // An "all same" candidate matches any arity, so the size check only applies
    // to candidates with explicit per-index descriptors.
    if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
      continue;
    }
    bool mis_match = false;
    for (size_t i = 0; i < input_num; ++i) {
      // For an "all same" candidate every index compares against descriptor 0.
      auto dtype = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
      auto format = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
      if (kernel_attr.GetInputAttr(i).first != dtype || kernel_attr.GetInputAttr(i).second != format) {
        mis_match = true;
        break;
      }
    }
    if (mis_match) {
      continue;
    }
    for (size_t i = 0; i < output_num; ++i) {
      auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
      auto format = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
      if (kernel_attr.GetOutputAttr(i).first != dtype || kernel_attr.GetOutputAttr(i).second != format) {
        mis_match = true;
        break;
      }
    }
    if (!mis_match) {
      return std::make_pair(true, index);
    }
  }
  return std::make_pair(false, 0);
}
  1001. KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
  1002. MS_EXCEPTION_IF_NULL(build_info);
  1003. KernelAttr kernel_attr;
  1004. for (size_t i = 0; i < build_info->GetInputNum(); i++) {
  1005. kernel_attr.AddInputAttr(build_info->GetInputDeviceType(i), build_info->GetInputFormat(i));
  1006. }
  1007. for (size_t j = 0; j < build_info->GetOutputNum(); j++) {
  1008. kernel_attr.AddOutputAttr(build_info->GetOutputDeviceType(j), build_info->GetOutputFormat(j));
  1009. }
  1010. return kernel_attr;
  1011. }
  1012. KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node) {
  1013. auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  1014. return GetKernelAttrFromBuildInfo(build_info);
  1015. }
  1016. } // namespace kernel
  1017. } // namespace mindspore